import numpy as np
from machine_learning.lib.tree_node import Node
'''Implementation of the decision tree base class. The class has four members:
root, max_depth, get_score and feature_sample_rate.
root stores the root node of the generated decision tree; the tree is uniquely
determined by its root.
max_depth is the depth limit on the tree built by the CART algorithm, specified
by the constructor.
get_score is the scoring function, passed in via the constructor: passing
variance implements CART for decision tree regression, and passing entropy
implements CART for decision tree classification.
feature_sample_rate is a number in the interval [0, 1], specified by the
constructor; it is used by the random forest algorithm and is not needed for a
plain decision tree.
'''
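# A minimal sketch of two scoring functions compatible with get_score. These
# names and bodies are illustrative assumptions, not part of the repo: both
# take the full label array y and an index collection idx, matching how
# get_score is called inside the class below.
def example_variance_score(y, idx):
    # Variance of the selected labels: the CART regression criterion.
    return np.var(np.asarray(y)[idx])

def example_entropy_score(y, idx):
    # Shannon entropy of the selected label distribution: the CART
    # classification criterion.
    _, counts = np.unique(np.asarray(y)[idx], return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))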
class DecisionTreeBase:
def __init__(self,max_depth,get_score,feature_sample_rate=1.0):
self.max_depth=max_depth
self.get_score=get_score
self.feature_sample_rate=feature_sample_rate
    # Partition the index set idx by the test X[i][j] <= theta.
    def split_data(self,j,theta,X,idx):
idx1,idx2=list(),list()
for i in idx:
if X[i][j]<=theta:
idx1.append(i)
else:
idx2.append(i)
return idx1,idx2
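    # For example, with X=[[1],[3],[5]], j=0, theta=3 and idx=[0,1,2], the call
    # returns ([0, 1], [2]): rows with X[i][j] <= theta go left, the rest right.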
    # Sample a random subset of the n feature (column) indices; the subset size
    # is controlled by feature_sample_rate (all n features when it is 1.0).
    def get_random_features(self,n):
shuffled=np.random.permutation(n)
size=int(self.feature_sample_rate*n)
return shuffled[:size]
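    # For example, n=4 with feature_sample_rate=0.5 yields a random 2-element
    # subset of the column indices, such as array([3, 0]).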
    # Search each sampled feature j and each candidate threshold theta for the
    # split minimizing the size-weighted score of the two children.
    def find_best_split(self,X,y,idx):
        m,n=X.shape
        best_score,best_j,best_theta=float("inf"),-1,float("inf")
        best_idx1,best_idx2=list(),list()
        selected_j=self.get_random_features(n)
        for j in selected_j:
            # Candidate thresholds: the distinct values of feature j among the
            # rows in idx.
            thetas=set([X[i][j] for i in idx])
            for theta in thetas:
                idx1,idx2=self.split_data(j,theta,X,idx)
                if min(len(idx1),len(idx2))==0:
                    continue
                score1,score2=self.get_score(y,idx1),self.get_score(y,idx2)
                # Weight each child's score by its share of the node's samples.
                w=1.0*len(idx1)/len(idx)
                score=w*score1+(1-w)*score2
                if score<best_score:
                    best_score,best_j,best_theta=score,j,theta
                    best_idx1,best_idx2=idx1,idx2
        # If no valid split exists, best_j stays -1 and best_score stays inf;
        # generate_tree then keeps the node as a leaf.
        return best_j,best_theta,best_idx1,best_idx2,best_score
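    # Weighted-score example for find_best_split: splitting a node of 10
    # samples into children of sizes 6 and 4 with scores 0.5 and 0.25 gives
    # w=0.6 and score = 0.6*0.5 + 0.4*0.25 = 0.4.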
    '''X and y are the feature matrix and labels of the full training set, and
    idx is a collection of row indices. generate_tree recursively builds a
    decision tree of depth at most d from the training samples indexed by idx,
    and returns the root node of the generated tree.'''
    def generate_tree(self,X,y,idx,d):
        r=Node()
        # Store the prediction at every node, so that either early return below
        # yields a usable leaf.
        r.p=np.average(y[idx],axis=0)
        if d==0 or len(idx)==1:
            return r
        j,theta,idx1,idx2,score=self.find_best_split(X,y,idx)
        current_score=self.get_score(y,idx)
        # Stop when the best split no longer improves on the node's own score.
        if score>=current_score:
            return r
        r.j,r.theta=j,theta
        r.left,r.right=self.generate_tree(X,y,idx1,d-1),self.generate_tree(X,y,idx2,d-1)
        return r
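    # Example for generate_tree: with X=np.array([[1],[2],[3],[4]]),
    # y=np.array([1.,1.,2.,2.]), idx=[0,1,2,3], d=2 and a variance score like
    # the sketch at the top of this file, the root splits at j=0, theta=2 and
    # both children become pure leaves with p=1.0 and p=2.0.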
    '''fit is the entry point of the CART algorithm. It calls generate_tree to
    build a decision tree of depth at most max_depth and stores the root node
    in the member root.'''
def fit(self,X,y):
self.root=self.generate_tree(X,y,range(len(X)),self.max_depth)
    # Walk the tree from node r, following the threshold tests until a leaf is
    # reached, and return that leaf's prediction.
    def get_prediction(self,r,x):
        if r.left is None and r.right is None:
            return r.p
        if x[r.j]<=r.theta:
            return self.get_prediction(r.left,x)
        else:
            return self.get_prediction(r.right,x)
    # Predict each row of X by routing it through the fitted tree.
    def predict(self,X):
y=list()
for i in range(len(X)):
y.append(self.get_prediction(self.root,X[i]))
return np.array(y)
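
# A minimal usage sketch under the assumptions above. It relies on
# example_variance_score defined near the top of this file, an illustrative
# stand-in for the repo's real regression criterion.
if __name__ == "__main__":
    X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
    y = np.array([1.0, 1.0, 1.0, 5.0, 5.0, 5.0])
    tree = DecisionTreeBase(max_depth=2, get_score=example_variance_score)
    tree.fit(X, y)
    print(tree.predict(np.array([[2.5], [11.5]])))  # expected output: [1. 5.]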