机器学习模型交叉验证脚本

机器学习模型交叉验证脚本

本文以阿里云机器学习平台上的 ps_smart (GBDT)算法为例,提供一个搜索最佳超参数的交叉验证任务的bash脚本。

机器学习模型超参数网格搜索脚本 提供了超参数网格搜索的能力。然而,当验证集的数量较少时,网格搜索的最优超参数非常容易过拟合,在实际的生产环境中,往往效果不如预期。为了缓解数据量少的问题,我们把网格搜索的Top N最优超参数保存下来,对这组超参数继续使用交叉验证的方式评估每组超参数对应的模型的实现效果指标。

本文提供的示例是一个LTV预测的回归任务,计算MAE、RMSE、WAPE 三个评估指标。

#!/bin/bash
#set -x
odps='.odpscmd/bin/odpscmd --config=odps_config.ini'
hyper_params_file='hyper_params.txt'

function log_info()
{
    if [ "$LOG_LEVEL" != "WARN" ] && [ "$LOG_LEVEL" != "ERROR" ]
    then
        echo "`date +"%Y-%m-%d %H:%M:%S"` [INFO] ($$)($USER): $*";
    fi
}

function prepare()
{
    log_info "function [$FUNCNAME] begin"
    if [ ! -d ".odpscmd" ]; then
        wget https://odps-repo.oss-cn-hangzhou.aliyuncs.com/odpscmd/latest/odpscmd_public.zip
        unzip -d .odpscmd odpscmd_public.zip
    fi
    log_info "function [$FUNCNAME] end"
}

function gen_partition() {
    log_info "function [$FUNCNAME] begin"
    local n=$1
    local k=$2
    local i
    pt=""
    for ((i=0;i<$n;i++))
    do
        if [ "$i" -eq "$k" ]; then
            continue
        fi
        pt=${pt}",'"${i}"'"
    done
    exclude_pt=${pt#,}
    log_info "function [$FUNCNAME] end"
}

function prepare_cv_data() {
    log_info "function [$FUNCNAME] begin"
    $odps -e "CREATE TABLE IF NOT EXISTS ps_smart_ltv
    (
        mae DOUBLE,
        rmse DOUBLE,
        wape DOUBLE
    )
    PARTITIONED BY (pt STRING COMMENT '实验参数', k STRING);"

    $odps -e "CREATE TABLE IF NOT EXISTS userfeature_v2_googleplay_mergekv_freedom_day3_dataset
    (
        dt  STRING,
        uid STRING,
        kv  STRING,
        targetprice DOUBLE,
        ispay BIGINT
    )
    COMMENT '训练数据集'
    PARTITIONED BY (pt STRING COMMENT '分区')
    LIFECYCLE 7;"

    local n=10
    $odps -e "INSERT OVERWRITE TABLE userfeature_v2_googleplay_mergekv_freedom_day3_dataset PARTITION(pt)
    SELECT *
    FROM (
        SELECT dt,uid,kv,targetprice,ispay, FLOOR(rand() * ${n}) as pt
        FROM rg_ai_bj.tmp_userfeature_v2_googleplay_mergekv_freedom_day3_train_20220905_jp_m1
        UNION ALL
        SELECT dt,uid,replace(kv,',',' ') kv,targetprice,ispay, FLOOR(rand(20220826) * ${n}) as pt
        FROM rg_ai_bj.tmp_userfeature_v2_googleplay_mergekv_freedom_day3_test_20220905_jp_m1
    ) T;"

    local k
    for ((k=0;k<${n};k++))
    do
    {
        gen_partition $n $k
        $odps -e "INSERT OVERWRITE TABLE userfeature_v2_googleplay_mergekv_freedom_day3_dataset PARTITION(pt='exclude_${k}')
        SELECT \`(pt)?+.+\`
        FROM userfeature_v2_googleplay_mergekv_freedom_day3_dataset
        WHERE pt IN (${exclude_pt});"
    } &
    done
    wait
    log_info "function [$FUNCNAME] end"
}

function run_job() {
    log_info "function [$FUNCNAME] begin"
    local k_fold=$1
    local tree_count=$2
    local max_depth=$3
    local l1=$4
    local l2=$5
    local lr=$6
    local eps=$7
    local model=${tree_count}_${max_depth}_${l1/0./p}_${l2/0./p}_${lr/0./p}_${eps/0./p}
    log_info "run model: $model, k_fold: ${k_fold}"

    $odps -e "PAI -name ps_smart
    -project algo_public
    -DinputTableName='userfeature_v2_googleplay_mergekv_freedom_day3_dataset'
    -DinputTablePartitions='pt=exclude_${k_fold}'
    -DmodelName='smart_${k_fold}_${model}'
    -DoutputTableName='smart_table_${k_fold}_${model}'
    -DoutputImportanceTableName='smart_imp_${k_fold}_${model}'
    -DlabelColName='targetprice'
    -DfeatureColNames='kv'
    -DenableSparse='true'
    -Dobjective='reg:tweedie'
    -Dmetric='tweedie-nloglik'
    -DfeatureImportanceType='gain'
    -DtreeCount='${tree_count}'
    -DmaxDepth='${max_depth}'
    -Dshrinkage='${lr}'
    -Dl2='${l2}'
    -Dl1='${l1}'
    -Dlifecycle='31'
    -DsketchEps='${eps}'
    -DsampleRatio='1.0'
    -DfeatureRatio='1.0'
    -DbaseScore='0.0'
    -DminSplitLoss='0'
    "
    if [ $? -ne 0 ]; then
        return $?
    fi

    $odps -e "drop table if exists smart_output_${k_fold}_${model};"
    $odps -e "PAI -name prediction
    -project algo_public
    -DinputTableName='userfeature_v2_googleplay_mergekv_freedom_day3_dataset'
    -DinputTablePartitions='pt=${k_fold}'
    -DmodelName='smart_${k_fold}_${model}'
    -DoutputTableName='smart_output_${k_fold}_${model}'
    -DfeatureColNames='kv'
    -DappendColNames='targetprice'
    -DenableSparse='true'
    -DitemDelimiter=' '
    -Dlifecycle='128'
    "
    if [ $? -ne 0 ]; then
        return $?
    fi
    
    $odps -e "INSERT OVERWRITE TABLE ps_smart_ltv PARTITION(pt='${model}', k='${k_fold}')
    SELECT AVG(ABS(targetprice-prediction_result)) MAE,
        SQRT(AVG((targetprice-prediction_result)*(targetprice-prediction_result))) RMSE,
        SUM(ABS(targetprice-prediction_result))/SUM(ABS(targetprice)) WAPE
    FROM smart_output_${k_fold}_${model};"
    log_info "function [$FUNCNAME] end"
}


function run_cross_validation()
{
    log_info "function [$FUNCNAME] begin"
    local args=$@
    local tree_count=$1
    local max_depth=$2
    local l1=$3
    local l2=$4
    local lr=$5
    local eps=$6
    local model=${tree_count}_${max_depth}_${l1/0./p}_${l2/0./p}_${lr/0./p}_${eps/0./p}
 
    local n=10
    local i 
    for ((i=0;i<$n;i++))
    do
    {
        run_job ${i} $args  
    } &
    done
    wait


    $odps -e "
    INSERT OVERWRITE TABLE ps_smart_ltv PARTITION(pt='${model}', k='mean')
    select avg(MAE), avg(RMSE), avg(WAPE)
    from ps_smart_ltv
    where pt='${model}' and k!='mean';
    "
    log_info "function [$FUNCNAME] end"
}

function run_from_file()
{
    log_info "function [$FUNCNAME] begin"
    threadTask=1 #并发数
    fifoFile="test_fifo"
    rm -f ${fifoFile}
    mkfifo ${fifoFile}  #创建fifo管道
    exec 9<> ${fifoFile}
    rm -f ${fifoFile}
    # 预先向管道写入数据
    for ((i=0;i<${threadTask};i++))
    do
        echo "" >&9
    done
    
    log_info "wait all task finish,then exit!!!"
    while read line
    do
        read -u9
        {
            run_cross_validation $line
            echo "" >&9
        } &
    done < $1
    wait

    exec 9<&-  # 关闭文件描述符的读
    exec 9>&-  # 关闭文件描述符的写
    log_info "function [$FUNCNAME] end"
}

prepare
prepare_cv_data
run_from_file ${hyper_params_file}
#run_from_file $1

备注:请结合机器学习模型超参数网格搜索脚本使用,网格搜索的Top N最优超参数需要预先保存到hyper_params.txt文件中。

本文由mdnice多平台发布

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 160,026评论 4 364
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,655评论 1 296
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 109,726评论 0 244
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,204评论 0 213
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,558评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,731评论 1 222
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,944评论 2 314
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,698评论 0 203
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,438评论 1 246
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,633评论 2 247
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,125评论 1 260
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,444评论 3 255
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,137评论 3 238
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,103评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,888评论 0 197
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,772评论 2 276
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,669评论 2 271

推荐阅读更多精彩内容