机器学习经典算法2-决策树

本文介绍了决策树的基本理念,强调其相比knn算法在数据理解上的优势。通过信息增益和熵来构建决策树,并概述了算法的一般流程,包括数据收集、准备、分析、训练、测试和使用。此外,还提到了决策树构造的终止条件和伪代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、算法简要

         决策树的基本理念就是通过不断的条件筛选,从而得到最后的答案。knn算法最大的缺点就是无法给出数据的内在含义,而决策树则在数据形式非常容易理解,有一定的实际意义。

        这里所讲到的决策树非叶子节点的建立是依据信息增益和熵的概念,这个可以自己去查。通过计算按特定属性划分数据集前后发生的熵的变化,选择信息增益最大的特征属性作为分叉节点,从而一步一步进行决策树的构造。

二、算法一般流程

       1.收集数据:任意方法和途径

       2.准备数据:数据必须离散化

       3.分析数据:构造树完成后,检查图形是否符合预测

       4.训练算法:决策树的构造

       5.测试算法:一般将决策树用于分类,可以用错误率衡量

       6.使用算法:决策树可以用于任何监督学习算法

三、算法伪代码

      1.判断决策树是否构造完毕,否则执行2,是则执行3。有终止条件:所有的类标签完全相同,返回该类标签即可;使用完了所有的特征,仍然不能将数据集划分成包含唯一类别的分组,此时,返回出现次数较多的类别

      2.选择时的信息增益最大的特征作为分叉节点,对分叉节点的每个特定值,调用1

      3.决策树构造完毕,对测试数据进行分类

四、代码实现与示例

 

import matplotlib.pyplot as plt
from math import log
import operator
'''calculate shannon entropy:count different distribution first then calculate'''
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts={}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob*log(prob,2)
    return shannonEnt
'''split the dataSet with the given feature index, form new dataSet without the given feature'''
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
'''choose the besf feature index with the largest information gain'''
def chosenBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0])-1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob*calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = newEntropy
            bestFeature = i
    return bestFeature
'''return the majority of the list, for figure out the class with unpurity examples'''
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),
                              reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):
    '''if we modify labels_t, python will not labels simutaneously'''
    labels_t = labels[:]
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat = chosenBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels_t[bestFeat])
    featValues =[example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels_t[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()
def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)
def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel
myDat = [[1,1,'yes'],
         [1,1,'yes'],
         [1,0,'no'],
         [0,1,'no'],
         [0,1,'no']
         ]
labels = ['no surfacing','flippers']
myT = createTree(myDat, labels)
print myT
fileName="tmpTree.txt"
storeTree(myT,fileName)
print grabTree(fileName)
print classify(myT, labels, [1,0])


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

大胖5566

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值