决策树 CART处理连续型数据的 python实现

参考的原文地址,原文还有对CART的分析。
这个代码用的是误差平方和最小,对连续值进行了二分,可以处理0,1离散值,不能‘好,不好’这种字符串表示特征。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
import numpy as np
import pickle
# =============================================================================
# 创建CART树(回归树或模型树)
# 输入:数据集data,叶子节点形式leafType:regressLeaf(回归树)、modelLeaf(模型树)
# 损失函数errType:误差平方和也分为regressLeaf和modelLeaf
# 用户自定义阈值参数:误差减少的阈值,子样本集应包含的最少样本个数
# =============================================================================
def CreateCart(data,leafType=regressLeaf,errType=regressErr,threshold=(1,4)):

    #寻找最优特征与最优切分点
    feature,value = ChooseBest(data,leafType,errType,threshold)
    #停止条件:当结点的样本个数小于阈值,或基尼指数小于阈值,或者没有更多特征时停止
    if feature == None:
        return value
    returnTree = {}
    returnTree['bestSplitFeature'] = feature
    returnTree['bestSplitValue'] = value
    leftSet,rightSet = binarySplit(data,feature,value)
    returnTree['left'] = leftSet
    returnTree['right'] = rightSet
    return returnTree

def ChooseBest(data,leafType,errType,threshold):
    thresholdErr, thresholdSamples = threshold[0],threshold[1]
    #数据中输出相同,不需要继续划分树,输出平均值
    #需要将np array转换成list再做set
    if len(set(data[:,-1].T.tolist()[0])) == 1:
        #回归树返回叶子平均值,模型树返回系数
        return None,leafType(data)
    m,n = data.shape()
    #分别处理回归树、模型树的err计算方法
    Err = errType(data)
    bestErr, bestFindex, bestFval = np.Inf,0,0
    #对于每个特征,根据每个取值将data划分成两个子集,计算两个子集的err
    #取所有可能划分点中,err最小的特征和结点,作为划分点
    #同时,保证划分的子集中样本个数大于阈值thresholdSamples
    #这种划分方式适用于连续型数据,标称型数据如果不能比大小则要用id3的方式划分
    for findex in range(n-1):
        for fval in data[:,findex]:
            left,right = binarySplit(data,findex,fval)
            #阈值判断
            if (left.shape()[0]<thresholdSamples) or (right.shape()[0]<thresholdSamples):
                continue
            temerr = errType(left) + errType(right)
            #更新最小误差
            if temerr < bestErr:
                bestErr, bestFindex, bestFval = temerr,findex,fval
    #检验所选最优划分点的误差,与未划分时的差值是否小于阈值thresholdErr
    if Err - bestErr < thresholdErr:
        return None, leafType(data)
    
    return bestFindex,bestFval
            
def binarySplit(data,findex, fval):
    #np.nonzero选取符合条件的非零值所在的每一纬度的下标
    left = data[np.nonzero(data[:,findex] <= fval)[0],:]
    right = data[np.nonzero(data[:,findex] > fval)[0],:]
    return left, right
# =============================================================================
# 回归树、模型树对应的不同处理函数
# =============================================================================
#输出叶结点平均值
def regressLeaf(data):
    #数据集输出列即最后一列的平均值
    return np.mean(data[:,-1])
#输出系数
def modelLeaf(data):
    w,x,y = linearSolve(data)
    return w
#y=kx+b写成矩阵形式
def linearSolve(data):
    m,n = np.shape(data)
    #mat转换成矩阵,方便后面计算
    x,y = np.mat(np.ones((m,n))),np.mat(np.ones((m,1)))
    x[:,1:n] = data[:,0:(n-1)]
    y = data[:,-1]
    xTx = x.T*x
    #计算行列式,判断是否能够取逆
    if np.linalg.det(xTx) == 0:
        #抛出异常
        raise NameError('matrix cannot do inverse,try increasing the second value of threshold')
    else:
        w = xTx.I*(x.T*y)
        return w,x,y
#划分后左右数据集的总误差平方和
def regressErr(data):
    #回归树的叶子节点取的是均值
    #计算误差相当于(每个输出-均值)^2
    #就是方差
    #np.var的方差分母是(n-1)
    return np.var(data[:,-1])*(np.shape(data)[0]-1)

#模型树误差
def modelErr(data):
    w,x,y = linearSolve(data)
    y_pie = x * w
    return sum(np.power(y-y_pie,2))

# =============================================================================
# 剪枝
# =============================================================================
#从叶子向上,比较剪掉和不剪的误差
def prune(tree,test):
    #数据集没有数据
    if test.shape()[0] == 0: return getMean(tree)
    #向下递归到叶子
    if isTree(tree['left']) or isTree(tree['right']):
        testleft,testright = binarySplit(test,tree['bestSplitFeature'],tree['bestSplitFeatValue'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'],testleft)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'],testright)
    #找到叶子之后计算误差
    if not isTree(tree['left']) and not isTree(tree['right']):
        leftmean,rightmean = binarySplit(test,tree['bestSplitFeature'],tree['bestSplitFeatValue'])
        errno = sum(np.power(leftmean[:,-1]-tree['left'],2))+sum(np.power(rightmean[:,-1]-tree['right'],2))
        errMerge = sum(np.power(test[:,-1]-getMean(tree),2))
        if errMerge < errno:
            print 'merge'
            return getMean(tree)
        else:
            return tree
    else:
        return tree

def getMean(tree):
    #tree['left']保存左子树数据集,如果是叶子节点保存平均值
    if isTree(tree['left']):tree['left'] = getMean(tree['left'])
    if isTree(tree['right']):tree['right'] = getMean(tree['right'])
    return (tree['left']+tree['right'])/2.


#判断是否存在叶结点
def isTree(obj):
    #type(obj)返回的是obj的类型关键字
    #__name__将类型关键字转化为str
    return (type(obj).__name__=='dict')
# =============================================================================
# 预测
# =============================================================================
def createForeCast(tree,test,modelEval=regressEvaluation):
    m = len(test)
    y = np.mat(np.zeros((m,1)))
    for i in range(m):
        y = treeForeCast(tree,test[i],modelEval)
    return y

def treeForeCast(tree,test,modelEval=regressEvaluation):
    #到了叶子节点输出结果
    if not isTree(tree): return modelEval(tree,test)
    #向左子树递归
    if test[tree['bestSplitFeature']] <= tree['bestSplitFeature']:
        if isTree(tree['left']):
            return treeForeCast(tree['left'],test,modelEval)
        else:
            return modelEval(tree['left'],test)
    #向右子树递归
    else:
        if isTree(tree['right']):
            return treeForeCast(tree['right'],test,modelEval)
        else:
            return modelEval(tree['right'],test)
    
def regressEvaluation(tree,test):
    return float(tree)
# =============================================================================
# 数据加载
# =============================================================================

def regressData(filename):
    fr = open(filename)
    #持久化
    return pickle.load(fr)

if __name__ == '__main__':
    trainfilename = 'e:\\python\\ml\\trainDataset.txt'
    testfilename = 'e:\\python\\ml\\testDataset.txt'

    trainDataset = regressData(trainfilename)
    testDataset = regressData(testfilename)
    
    cartTree = CreateCart(trainDataset,threshold=(1,4))
    pruneTree = prune(cartTree,testDataset)
    y = createForeCast(cartTree,np.mat([0.3]),modelEval=regressEvaluation)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 157,298评论 4 360
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 66,701评论 1 290
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 107,078评论 0 237
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,687评论 0 202
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,018评论 3 286
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,410评论 1 211
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,729评论 2 310
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,412评论 0 194
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,124评论 1 239
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,379评论 2 242
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 31,903评论 1 257
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,268评论 2 251
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 32,894评论 3 233
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,014评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,770评论 0 192
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,435评论 2 269
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,312评论 2 260

推荐阅读更多精彩内容