决策树挑出好西瓜

本文详细介绍了如何使用决策树算法ID3、C4.5和CART对西瓜数据集进行建模,包括信息熵、信息增益和增益率的原理,以及sk-learn库的具体实现步骤。通过实例展示了决策树在西瓜品质预测中的应用和可视化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、决策树

1.概念

决策树是一种基于树结构来进行决策的分类算法,我们希望从给定的训练数据集学得一个模型(即决策树),用该模型对新样本分类。决策树可以非常直观展现分类的过程和结果,一旦模型构建成功,对新样本的分类效率也相当高。

2. 信息熵(information entropy)

假定当前样本集合D中第k类样本所占的比例为pk(k=1,,2,...,
|y|),则D的信息熵定义为:Ent(D) = -∑k=1 pk·log2 pk约定
若p=0,则log2 p=0)显然,Ent(D)值越小,D的纯度越高。因为
0<=pk<= 1,故log2 pk<=0,Ent(D)>=0. 极限情况下,考虑D中
样本同属于同一类,则此时的Ent(D)值为0(取到最小值)。当D
中样本都分别属于不同类别时,Ent(D)取到最大值log2 |y|.

3. 信息增益(information gain)

  假定离散属性a有V个可能的取值{a1,a2,...,aV}. 若使用a
  对样本集D进行分类,则会产生V个分支结点,记Dv为第v个分
  支结点包含的D中所有在属性a上取值为av的样本。不同分支结
  点样本数不同,我们给予分支结点不同的权重:|Dv|/|D|, 该
  权重赋予样本数较多的分支结点更大的影响、由此,用属性a对
  样本集D进行划分所获得的信息增益定义为:Gain(D,a) = Ent(D)-∑v=1 |Dv|/|D|·Ent(Dv)

其中,Ent(D)是数据集D划分前的信息熵,∑v=1 |Dv|/|D|·Ent(Dv)可以表示为划分后的信息熵。“前-后”的结果表明了本次划分所获得的信息熵减少量,也就是纯度的提升度。显然,Gain(D,a) 越大,获得的纯度提升越大,此次划分的效果越好。

4. 增益率(gain ratio)

  基于信息增益的最优属性划分原则——信息增益准则,对可取值
  数据较多的属性有所偏好。C4.5算法使用增益率替代信息增益来
  选择最优划分属性,增益率定义为:Gain_ratio(D,a) = Gain(D,a)/IV(a
  其中 IV(a) = -∑v=1 |Dv|/|D|·log2 |Dv|/|D|

称为属性a的固有值。属性a的可能取值数目越多(即V越大),则IV(a)的值通常会越大。这在一定程度上消除了对可取值数据较多的属性的偏好。
事实上,增益率准则对可取值数目较少的属性有所偏好,C4.5算法并不是直接使用增益率准则,而是先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。

下面,我们就对以下表格中的西瓜样本构建决策树模型。
在这里插入图片描述

5、代码实现

1.导入模块部分

#导入模块

import pandas as pd
import numpy as np
from collections import Counter
from math import log2

用pandas模块的read_excel()函数读取数据文本;用numpy模块将dataframe转换为list(列表);用Counter来完成计数;用math模块的log2函数计算对数。后边代码中会有对应体现。
2.数据获取与处理函数
#数据获取与处理

def getData(filePath):
    data = pd.read_excel(filePath)
    return data

def dataDeal(data):
    dataList = np.array(data).tolist()
    dataSet = [element[1:] for element in dataList]
    return dataSet

getData()通过pandas模块中的read_excel()函数读取样本数据。尝试过将数据文件保存为csv格式,但是对于中文的处理不是很好,所以选择了使用xls格式文件。

dataDeal()函数将dataframe转换为list,并且去掉了编号列。编号列并不是西瓜的属性,事实上,如果把它当做属性,会获得最大的信息增益。

这两个函数是完全可以合并为同一个函数的,但是因为我想分别使用data(dataframe结构,带属性标签)和dataSet(list)数据样本,所以分开写了两个函数。

3.获取属性名称

#获取属性名称

def getLabels(data):
    labels = list(data.columns)[1:-1]
    return labels

很简单,获取属性名称:纹理,色泽,根蒂,敲声,脐部,触感。
4.获取类别标记
#获取类别标记

def targetClass(dataSet):
    classification = set([element[-1] for element in dataSet])
    return classification

获取一个样本是否好瓜的标记(是与否)。

5.叶结点标记

#将分支结点标记为叶结点,选择样本数最多的类作为类标记

def majorityRule(dataSet):
    mostKind = Counter([element[-1] for element in dataSet]).most_common(1)
    majorityKind = mostKind[0][0]
    return majorityKind

6.计算信息熵

#计算信息熵

def infoEntropy(dataSet):
    classColumnCnt = Counter([element[-1] for element in dataSet])
    Ent = 0
    for symbol in classColumnCnt:
        p_k = classColumnCnt[symbol]/len(dataSet)
        Ent = Ent-p_k*log2(p_k)
    return Ent

7.子数据集构建

#子数据集构建

def makeAttributeData(dataSet,value,iColumn):
    attributeData = []
    for element in dataSet:
        if element[iColumn]==value:
            row = element[:iColumn]
            row.extend(element[iColumn+1:])
            attributeData.append(row)
    return attributeData

在某一个属性值下的数据,比如纹理为清晰的数据集。

8.计算信息增益

#计算信息增益

def infoGain(dataSet,iColumn):
    Ent = infoEntropy(dataSet)
    tempGain = 0.0
    attribute = set([element[iColumn] for element in dataSet])
    for value in attribute:
        attributeData = makeAttributeData(dataSet,value,iColumn)
        tempGain = tempGain+len(attributeData)/len(dataSet)*infoEntropy(attributeData)
        Gain = Ent-tempGain
    return Gain

9.选择最优属性

#选择最优属性

def selectOptimalAttribute(dataSet,labels):
    bestGain = 0
    sequence = 0
    for iColumn in range(0,len(labels)):#不计最后的类别列
        Gain = infoGain(dataSet,iColumn)
        if Gain>bestGain:
            bestGain = Gain
            sequence = iColumn
        print(labels[iColumn],Gain)
    return sequence

10.建立决策树

#建立决策树

def createTree(dataSet,labels):
    classification = targetClass(dataSet) #获取类别种类(集合去重)
    if len(classification) == 1:
        return list(classification)[0]
    if len(labels) == 1:
        return majorityRule(dataSet)#返回样本种类较多的类别
    sequence = selectOptimalAttribute(dataSet,labels)
    print(labels)
    optimalAttribute = labels[sequence]
    del(labels[sequence])
    myTree = {optimalAttribute:{}}
    attribute = set([element[sequence] for element in dataSet])
    for value in attribute:
        
        print(myTree)
        print(value)
        subLabels = labels[:]
        myTree[optimalAttribute][value] =  \
                createTree(makeAttributeData(dataSet,value,sequence),subLabels)
    return myTree

树本身并不复杂,采用递归的方式实现。

11.定义主函数

def main():
    filePath = 'watermelonData.xls'
    data = getData(filePath)
    dataSet = dataDeal(data)
    labels = getLabels(data)
    myTree = createTree(dataSet,labels)
    return myTree

12.生成树

if __name__ == '__main__':
    myTree = main()

在这里插入图片描述在这里插入图片描述

6.绘制决策树

#使用Matlotlib绘制决策树
import matplotlib.pyplot as plt

#设置文本框和箭头格式
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['font.family'] = 'sans-serif'

#画节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy = parentPt,\
    xycoords = "axes fraction", xytext = centerPt, textcoords = 'axes fraction',\
    va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)
    
#获取决策树的叶子节点数
def getNumLeafs(myTree):
    leafNumber = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if(type(secondDict[key]).__name__ == 'dict'):
            leafNumber = leafNumber + getNumLeafs(secondDict[key])
        else:
            leafNumber += 1
    return leafNumber

#获取决策树的高度(递归)
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        #test to see if the nodes are dictonaires, if not they are leaf nodes
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

#在父子节点添加信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

#画树
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


#画布初始化
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

定义主函数

def main():
    #createPlot()
    print(getTreeDepth(myTree)) #输出数的深度
    print(getNumLeafs(myTree))  #输出数的叶子个数
    createPlot(myTree)
main()

在这里插入图片描述

二、用sk-learn库对西瓜数据集,分别进行ID3、C4.5和CART的算法代码实现。

1. 基于信息增益准则(I D 3 ID3ID3或C 4.5 C4.5C4.5)方法建立决策树

导入相关库

#导入相关库
import pandas as pd
import graphviz 
from sklearn.model_selection import train_test_split
from sklearn import tree

导入数据

f = open('watermalon.csv','r',encoding='utf-8')
data = pd.read_csv(f)

x = data[["色泽","根蒂","敲声","纹理","脐部","触感"]].copy()
y = data['好瓜'].copy()
print(data)

在这里插入图片描述

数据预处理
将特征值数值化

#将特征值数值化
x = x.copy()
for i in ["色泽","根蒂","敲声","纹理","脐部","触感"]:
    for j in range(len(x)):
        if(x[i][j] == "青绿" or x[i][j] == "蜷缩" or data[i][j] == "浊响" \
           or x[i][j] == "清晰" or x[i][j] == "凹陷" or x[i][j] == "硬滑"):
            x[i][j] = 1
        elif(x[i][j] == "乌黑" or x[i][j] == "稍蜷" or data[i][j] == "沉闷" \
           or x[i][j] == "稍糊" or x[i][j] == "稍凹" or x[i][j] == "软粘"):
            x[i][j] = 2
        else:
            x[i][j] = 3
            
y = y.copy()
for i in range(len(y)):
    if(y[i] == "是"):
        y[i] = int(1)
    else:
        y[i] = int(-1) 

将数据转换为DataFrame数据类型

#需要将数据x,y转化好格式,数据框dataframe,否则格式报错
x = pd.DataFrame(x).astype(int)
y = pd.DataFrame(y).astype(int)
print(x)
print(y)

在这里插入图片描述

划分训练集与测试集
将80%数据用于训练,20%数据用于测试

x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=0.2)
print(x_train)

在这里插入图片描述

建立模型并训练

#决策树学习
clf = tree.DecisionTreeClassifier(criterion="entropy")                    #实例化 
clf = clf.fit(x_train, y_train) 
score = clf.score(x_test, y_test)
print(score)

可视化决策树

feature_name = ["色泽","根蒂","敲声","纹理","脐部","触感"]
dot_data = tree.export_graphviz(clf                          
                                ,feature_names= feature_name                                
                                ,class_names=["好瓜","坏瓜"]                                
                                ,filled=True                                
                                ,rounded=True
                                ,out_file =None  
                                ) 
graph = graphviz.Source(dot_data) 
graph

2. 基于基尼指数(C A R T CARTCART)建立决策树

根据DecisionTreeClassifier函数参数解释可知,只需将criterion值改为"gini",即可基于基尼指数(C A R T CARTCART)建立决策树。

#决策树学习
clf = tree.DecisionTreeClassifier(criterion="gini")                    #实例化 
clf = clf.fit(x_train, y_train) 
score = clf.score(x_test, y_test)
print(score)

可视化(C A R T CARTCART)决策树

在这里插入图片描述

三、总结

该实验是在jupyter下实现针对西瓜数据集的ID3算法代码,并输出可视化结果。
用sk-learn库对西瓜数据集,分别进行ID3、C4.5和CART的算法代码实现。

四、参考链接

https://www.cnblogs.com/dennis-liucd/p/7905793.html,

https://blog.csdn.net/bjjoy2009/article/details/80884526

https://blog.csdn.net/YangMax1/article/details/120917236?spm=1001.2014.3001.5501

https://blog.csdn.net/qq_55691662/article/details/120569410

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值