Skip to the content.

《Machine Learning In Action》 中Decision Tree部分代码精简优化


Contact me:

Blog -> https://cugtyt.github.io/blog/index
Email -> cugtyt@qq.com
GitHub -> Cugtyt@GitHub


本系列博客主页及相关见此处
此处粘贴的代码来源为https://github.com/Jack-Cherish/Machine-Learning


1. calcShannonEnt函数,计算香农熵

原代码:

def calcShannonEnt(dataSet):
    numEntires = len(dataSet)                        #返回数据集的行数
    labelCounts = {}                                #保存每个标签(Label)出现次数的字典
    for featVec in dataSet:                            #对每组特征向量进行统计
        currentLabel = featVec[-1]                    #提取标签(Label)信息
        if currentLabel not in labelCounts.keys():    #如果标签(Label)没有放入统计次数的字典,添加进去
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1                #Label计数
    shannonEnt = 0.0                                #经验熵(香农熵)
    for key in labelCounts:                            #计算香农熵
        prob = float(labelCounts[key]) / numEntires    #选择该标签(Label)的概率
        shannonEnt -= prob * log(prob, 2)            #利用公式计算
    return shannonEnt                                #返回经验熵(香农熵)

问题:

简单的for循环可以使用列表推导实现,逻辑清晰易懂,使用Counter可以省去自己手动count。

参考修改:

def calcShannonEnt(dataSet):
    # Count labels
    label_count = Counter(data[-1] for data in dataSet)
    # Compute prob
    probs = [p[1] / len(dataSet) for p in label_count.items()]
    # Compute entropy and sum
    shannonEnt = sum([-p * log(p, 2) for p in probs])
    return shannonEnt

2. splitDataSet函数,对于指定的特征分割数据集

原代码:

def splitDataSet(dataSet, axis, value):       
    retDataSet = []                                        #创建返回的数据集列表
    for featVec in dataSet:                             #遍历数据集
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]                #去掉axis特征
            reducedFeatVec.extend(featVec[axis+1:])     #将符合条件的添加到返回的数据集
            retDataSet.append(reducedFeatVec)
    return retDataSet                                      #返回划分后的数据集

问题:

使用列表推导可以简化代码,逻辑更加清晰易懂,也更轻巧,注意,修改后的代码略有不同,原代码删除了一部分值,参考代码保留,因为这对于后面的处理没有影响。

参考修改:

def splitDataSet(dataSet, axis, value):
    retDataSet = [data for data in dataSet for i, v in enumerate(data) if i == axis and v == value]
    return retDataSet              

3. chooseBestFeatureToSplit函数,用于遍历所有特征,选择最好的分割特征

原代码:

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1                    #特征数量
    baseEntropy = calcShannonEnt(dataSet)                 #计算数据集的香农熵
    bestInfoGain = 0.0                                  #信息增益
    bestFeature = -1                                    #最优特征的索引值
    for i in range(numFeatures):                         #遍历所有特征
        #获取dataSet的第i个所有特征
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)                         #创建set集合{},元素不可重复
        newEntropy = 0.0                                  #经验条件熵
        for value in uniqueVals:                         #计算信息增益
            subDataSet = splitDataSet(dataSet, i, value)         #subDataSet划分后的子集
            prob = len(subDataSet) / float(len(dataSet))           #计算子集的概率
            newEntropy += prob * calcShannonEnt(subDataSet)     #根据公式计算经验条件熵
        infoGain = baseEntropy - newEntropy                     #信息增益
        print("第%d个特征的增益为%.3f" % (i, infoGain))            #打印每个特征的信息增益
        if (infoGain > bestInfoGain):                             #计算信息增益
            bestInfoGain = infoGain                             #更新信息增益,找到最大的信息增益
            bestFeature = i                                     #记录信息增益最大的特征的索引值
    return bestFeature                                             #返回信息增益最大的特征的

问题:

实现繁杂,可以使用列表推导简化,利用内置函数也能简化和加速代码。

参考修改:

def chooseBestFeatureToSplit(dataSet):
    base_entropy = calcShannonEnt(dataSet)
    best_info_gain = 0
    best_feature = -1
    # For loop each feature
    for i in range(len(dataSet[0]) - 1):
        feature_count = Counter([data[i] for data in dataSet])
        # Compute new entropy
        new_entropy = sum(feature[1] / float(len(dataSet)) * calcShannonEnt(splitDataSet(dataSet, i, feature[0])) \
                       for feature in feature_count.items())
        # Compute gain and update
        info_gain = base_entropy - new_entropy
        print('No. {0} feature info gain is {1:.3f}'.format(i, info_gain))
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature

4. majorityCnt函数,用于统计出现最多的类标签

原代码:

def majorityCnt(classList):
    classCount = {}
    for vote in classList:                                        #统计classList中每个元素出现的次数
        if vote not in classCount.keys():classCount[vote] = 0   
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)        #根据字典的值降序排序
    return sortedClassCount[0][0]                                #返回classList中出现次数最多的元素

问题:

实现繁杂,可以直接使用Counter简化。

参考修改:

def majorityCnt(classList):
    major_label = Counter(classList).most_common(1)[0]
    return major_label

5. 对类别的判断应该使用 is

原代码:

type(secondDict[key]).__name__ == 'dict'

问题:

实现繁杂,使用is可读性更好。

参考修改:

type(secondDict[key]) is dict

6. 从文件读取数据使用pandas的函数更有效

原代码:

if __name__ == '__main__':
    with open('lenses.txt', 'r') as fr:                                        #加载文件
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]        #处理文件
    lenses_target = []                                                        #提取每组数据的类别,保存在列表里
    for each in lenses:
        lenses_target.append(each[-1])

    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']            #特征标签       
    lenses_list = []                                                        #保存lenses数据的临时列表
    lenses_dict = {}                                                        #保存lenses数据的字典,用于生成pandas
    for each_label in lensesLabels:                                            #提取信息,生成字典
        for each in lenses:
            lenses_list.append(each[lensesLabels.index(each_label)])
        lenses_dict[each_label] = lenses_list
        lenses_list = []
    # print(lenses_dict)                                                        #打印字典信息
    lenses_pd = pd.DataFrame(lenses_dict)                                    #生成pandas.DataFrame
    print(lenses_pd)                                                        #打印pandas.DataFrame
    le = LabelEncoder()                                                        #创建LabelEncoder()对象,用于序列化            
    for col in lenses_pd.columns:                                            #为每一列序列化
        lenses_pd[col] = le.fit_transform(lenses_pd[col])
    print(lenses_pd)

问题:

实现繁杂,使用pandas的内置函数可以更方便处理。

参考修改:

if __name__ == '__main__':
    lenses = pd.DataFrame(pd.read_table('lenses.txt'))
    lenses.columns = ['age', 'prescript', 'astigmatic', 'tearRate', 'target']
    lenses = lenses.drop(['target'], axis=1)
    print(lenses) 
    le = LabelEncoder()
    for col in lenses.columns:
        lenses[col] = le.fit_transform(lenses[col])
    print(lenses)