from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
iris = load_iris()
k_range = range(1, 31)
k_error = []
x = iris.data
y = iris.target
k_sel = 0
min_error = 1
for k in k_range:
knn = KNeighborsClassifier(n_neighbors=k)
# cv参数决定数据集划分比例,6折交叉验证
scores = cross_val_score(knn, x, y, cv=6, scoring='accuracy')
k_error.append(1 - scores.mean())
if (1 - scores.mean() < min_error):
min_error = 1 - scores.mean()
k_sel = k
print("k:", k_sel, "min_error:", min_error)
plt.plot(k_range, k_error)
plt.xlabel('Value of K for KNN')
plt.ylabel('Error')
plt.show()
首先是数据层面的操作。通过load_iris()加载经典的鸢尾花数据集,该数据集包含150个样本,每个样本有花萼长度、花萼宽度、花瓣长度、花瓣宽度4个特征,对应山鸢尾、变色鸢尾、维吉尼亚鸢尾3类标签。代码中先查看了数据集的基础信息,如打印特征数据、标签、特征名称和目标类别名称,还利用train_test_split将数据集按7:3的比例划分为训练集与测试集,为后续模型训练和评估提供数据支撑,这一步是机器学习任务中“数据准备”阶段的关键环节,确保数据能被模型有效利用。
接着进入模型的训练与初步评估。构建KNeighborsClassifier实例(初始设置n_neighbors=5),调用fit方法用训练集数据训练模型,让模型学习“特征→标签”的对应关系。之后用predict方法对测试集特征进行预测,得到预测标签y_predict,再通过score方法计算模型在测试集上的准确率,以此评估模型初步的分类性能,这是“模型构建与评估”阶段,验证了KNN算法在鸢尾花数据集上的基本可用性。
最后是模型的优化——最优k值选择。KNN算法的核心参数k(近邻数量)对模型性能影响显著,为找到最优k,代码设置k的候选范围为1到30,通过cross_val_score进行6折交叉验证(将数据集拆分为6份,轮流作为训练和验证集),计算每个k值下模型的平均准确率,并转换为错误率存入列表。同时,在循环中实时跟踪最小错误率及对应的k值,最终确定当k=9时,模型错误率最低(约22.5%)。为更直观展示k与错误率的关系,还通过matplotlib绘制折线图,清晰呈现错误率随k值变化的趋势,这一“参数优化”阶段能让模型在特定数据集上的性能达到更优状态。
整体而言,这段代码完整覆盖了机器学习分类任务的核心步骤,从数据准备、模型构建、初步评估到参数优化,清晰展现了如何利用KNN算法解决鸢尾花分类问题,是理解和实践监督学习分类任务的很好范例,有助于初学者掌握机器学习的基本流程与方法。