I. Algorithm Overview
The basic idea of a decision tree is to arrive at a final answer through a series of conditional tests. The biggest drawback of kNN is that it cannot reveal the intrinsic meaning of the data; a decision tree, in contrast, stores what it learns in a form that is easy to read and carries real interpretive value.
The internal (non-leaf) nodes of the tree described here are chosen using entropy and information gain (look these concepts up if they are unfamiliar). By computing how the entropy changes before and after splitting the dataset on a given attribute, we select the feature with the largest information gain as the split node, and so build the tree step by step.
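For reference, these are the two standard textbook definitions involved, stated here for convenience: for a dataset D whose classes occur with proportions p_k, the entropy is

    H(D) = -\sum_k p_k \log_2 p_k

and the information gain of splitting on attribute A, where value v selects the subset D_v, is

    \mathrm{Gain}(D, A) = H(D) - \sum_v \frac{|D_v|}{|D|} H(D_v)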
II. General Workflow
1. Collect the data: any method or source.
2. Prepare the data: the tree construction works on nominal values, so continuous data must be discretized.
3. Analyze the data: once the tree is built, check that it matches expectations.
4. Train the algorithm: construct the decision tree.
5. Test the algorithm: decision trees are usually used for classification, so the error rate is a natural measure.
6. Use the algorithm: this step applies to any supervised learning task; here the tree is used for classification.
III. Pseudocode
1. Check whether construction is finished: if not, go to 2; if so, go to 3. There are two stopping conditions: all class labels in the current subset are identical, in which case return that label; or all features have been used yet the subset still contains more than one class, in which case return the class that appears most often.
2. Choose the feature with the largest information gain as the split node; for each distinct value of that feature, apply step 1 to the corresponding subset. (A worked example follows this list.)
3. Construction is finished; use the tree to classify test data.
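As a worked example of step 2, take the small sample dataset from section IV below (2 'yes' and 3 'no' labels; the numbers here are computed by hand from that dataset):

    H(D) = -(2/5)\log_2(2/5) - (3/5)\log_2(3/5) \approx 0.971
    split on 'no surfacing': (3/5)·H({yes,yes,no}) + (2/5)·H({no,no}) ≈ 0.551, gain ≈ 0.420
    split on 'flippers':     (4/5)·H({yes,yes,no,no}) + (1/5)·H({no}) = 0.800, gain ≈ 0.171

'no surfacing' has the larger gain, so it becomes the root of the tree.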
IV. Implementation and Example
from math import log
import operator

def calcShannonEnt(dataSet):
    '''Calculate the Shannon entropy of dataSet: tally the class labels
    in the last column, then sum -p * log2(p) over the distribution.'''
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
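# Sanity check: for the sample dataset myDat defined at the bottom of this
# script (2 'yes' vs 3 'no'), calcShannonEnt(myDat) returns about 0.9710.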
def splitDataSet(dataSet, axis, value):
    '''Split dataSet on the given feature index: return the rows whose
    feature at index axis equals value, with that column removed.'''
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
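# Example: splitDataSet(myDat, 0, 1) keeps the rows whose feature 0 equals 1
# and drops that column: [[1, 'yes'], [1, 'yes'], [0, 'no']].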
def chosenBestFeatureToSplit(dataSet):
    '''Choose the feature index with the largest information gain.'''
    numFeatures = len(dataSet[0]) - 1          # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain            # record the gain, not the raw entropy
            bestFeature = i
    return bestFeature
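# On the sample dataset the gains work out to about 0.420 for feature 0
# ('no surfacing') and 0.171 for feature 1 ('flippers'), so this returns 0.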
def majorityCnt(classList):
    '''Return the most frequent label in classList; used to decide the class
    of a leaf that is still impure after all features have been used.'''
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1),
                              reverse=True)
    return sortedClassCount[0][0]
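# Example: majorityCnt(['yes', 'no', 'no']) returns 'no'.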
def createTree(dataSet, labels):
    # Work on a copy so that modifying labels_t does not modify the
    # caller's labels list as well.
    labels_t = labels[:]
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]                    # all labels identical: leaf
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)          # features exhausted: majority vote
    bestFeat = chosenBestFeatureToSplit(dataSet)
    bestFeatLabel = labels_t[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels_t[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels_t[:]
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
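# On the sample dataset this recursion produces:
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}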
def storeTree(inputTree, filename):
    '''Serialize the tree to disk (pickle requires binary mode).'''
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

def grabTree(filename):
    '''Load a tree previously stored with storeTree.'''
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
def classify(inputTree, featLabels, testVec):
    '''Walk the tree: find which feature the root tests, follow the branch
    that matches testVec, and recurse until a leaf label is reached.'''
    firstStr = list(inputTree.keys())[0]       # dict keys are not indexable in Python 3
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel
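# Example: classify(myT, labels, [1, 0]) follows 'no surfacing'=1, then
# 'flippers'=0, and returns 'no'; the vector [1, 1] would return 'yes'.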
myDat = [[1, 1, 'yes'],
         [1, 1, 'yes'],
         [1, 0, 'no'],
         [0, 1, 'no'],
         [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
myT = createTree(myDat, labels)
print(myT)
fileName = "tmpTree.txt"
storeTree(myT, fileName)
print(grabTree(fileName))
print(classify(myT, labels, [1, 0]))
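Running the script prints the tree {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}} twice (once as built, once reloaded from tmpTree.txt), and then 'no' for the test vector [1, 0].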