Python 实现支持连续值的 ID3 决策树代码
时间: 2025-06-30 18:01:45 浏览: 12
### 实现支持连续值的ID3决策树的Python代码
以下是一个支持连续值的ID3决策树的Python实现。该代码扩展了传统的ID3算法,使其能够处理连续特征。通过寻找最佳分割点来处理连续值。
```python
import numpy as np
import pandas as pd
from collections import Counter
class ID3DecisionTree:
    """ID3 decision tree classifier that supports continuous features.

    Continuous values are handled by scanning every unique value of each
    feature as a candidate binary split (``feature <= threshold``) and
    keeping the split with the highest information gain.

    The learned tree is a nested structure of dicts of the form
    ``{"idx <= value": [left_subtree, right_subtree]}`` with bare class
    labels at the leaves.
    """

    def __init__(self, min_samples_split=2, max_depth=10):
        # Minimum number of samples a node needs before a split is attempted.
        self.min_samples_split = min_samples_split
        # Hard recursion-depth limit; guards against runaway tree growth.
        self.max_depth = max_depth
        # Learned model; set by fit().
        self.tree = None

    def entropy(self, y):
        """Return the Shannon entropy (base 2) of label array *y*.

        Uses ``np.unique(..., return_counts=True)`` instead of
        ``np.bincount`` so labels are not required to be small
        non-negative integers.
        """
        _, counts = np.unique(y, return_counts=True)
        ps = counts / counts.sum()
        # counts are all > 0 by construction, so log2 is safe.
        return -np.sum(ps * np.log2(ps))

    def best_split(self, X, y):
        """Find the best binary split over all features.

        Every unique value of every feature is tried as a ``<=``
        threshold.

        Returns:
            (best_gain, split_idx, split_value); ``(-1, None, None)``
            when no threshold separates the samples.
        """
        best_gain = -1
        split_idx, split_value = None, None
        _, n_features = X.shape
        base_entropy = self.entropy(y)
        n = len(y)
        for idx in range(n_features):
            values = X[:, idx]
            for threshold in np.unique(values):
                # Vectorized mask instead of a per-sample Python loop.
                left_mask = values <= threshold
                n_left = int(left_mask.sum())
                if n_left == 0 or n_left == n:
                    # Degenerate split: every sample on one side.
                    continue
                p = n_left / n
                gain = (base_entropy
                        - p * self.entropy(y[left_mask])
                        - (1 - p) * self.entropy(y[~left_mask]))
                if gain > best_gain:
                    best_gain = gain
                    split_idx = idx
                    split_value = threshold
        return best_gain, split_idx, split_value

    def build_tree(self, X, y, depth=0):
        """Recursively build the tree; returns a nested dict or a leaf label."""
        n_samples, _ = X.shape
        # Stop: depth limit, too few samples, or node already pure.
        if (depth >= self.max_depth
                or n_samples < self.min_samples_split
                or len(np.unique(y)) == 1):
            return Counter(y).most_common(1)[0][0]
        best_gain, split_idx, split_value = self.best_split(X, y)
        if best_gain == -1:
            # No valid split exists; fall back to the majority label.
            return Counter(y).most_common(1)[0][0]
        left_indices = np.where(X[:, split_idx] <= split_value)[0]
        right_indices = np.where(X[:, split_idx] > split_value)[0]
        question = "{} <= {}".format(split_idx, split_value)
        left_tree = self.build_tree(X[left_indices, :], y[left_indices], depth + 1)
        right_tree = self.build_tree(X[right_indices, :], y[right_indices], depth + 1)
        if left_tree == right_tree:
            # Both children predict the same label: collapse to a leaf.
            return left_tree
        return {question: [left_tree, right_tree]}

    def fit(self, X, y):
        """Train the tree on feature matrix *X* and label vector *y*.

        Inputs are coerced to numpy arrays so plain Python lists work too
        (without this, ``y[left_indices]`` fails on list labels because
        fancy indexing requires an ndarray).
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        self.tree = self.build_tree(X, y)

    def predict_sample(self, x, tree):
        """Route one sample *x* down *tree* and return the leaf label."""
        if not isinstance(tree, dict):
            return tree
        question = list(tree.keys())[0]
        feature_idx, _, value = self.parse_question(question)
        answer = tree[question][0] if x[feature_idx] <= value else tree[question][1]
        if not isinstance(answer, dict):
            return answer
        return self.predict_sample(x, answer)

    def parse_question(self, question):
        """Parse a node key "idx <= value" into (int idx, "<=", float value)."""
        feature_idx, value = question.split(" <= ")
        return int(feature_idx), "<=", float(value)

    def predict(self, X):
        """Return a list of predicted labels, one per row of *X*."""
        return [self.predict_sample(x, self.tree) for x in np.asarray(X, dtype=float)]
```
此代码中,`best_split` 函数用于寻找最佳分割点以处理连续值。通过计算信息增益[^1],选择最佳特征和对应的分割值。构建树的过程递归地进行,直到达到最大深度或节点中的样本数不足以继续分裂[^4]。
###
阅读全文
相关推荐


















