利用KNN算法对鸢尾花数据进行分类

本文介绍了一个基于鸢尾花数据集的KNN分类算法实现过程,包括数据预处理、KNN算法实现及正确率分析。通过Python代码演示了如何计算不同K值下的分类准确率,并展示了K值与准确率之间的关系图。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

鸢尾花数据集下载:

https://blog.csdn.net/wolflikeinnocence/article/details/90140587

KNN代码

'''
Iris数据集三类标签分别为Iris-setosa、Iris-versicolor、Iris-virginica
'''
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def k_nn(x):
    accurary = 0
    for i in range(150):
        count1 = 0
        count2 = 0
        count3 = 0
        prediction = 0
        distance = np.zeros((149, 2))
        # 取出一个测试集
        test = x[i].reshape(1, 5)
        # 把取出的测试集删掉,剩下的当作训练集
        train = np.delete(x, i, axis=0)
        # 把分类标签割掉
        test1 = test[:, 0:4]
        train1 = train[:, 0:4]
        # 测试集对每个训练集计算欧式距离
        for t in range(149):
            # 求欧式距离
            distance[t, 1] = np.linalg.norm(test1 - train1[t])
            # 存储标签
            distance[t, 0] = train[t, 4]
        # 按最后一列排序,即按照欧式距离进行排序
        order = distance[np.lexsort(distance.T)]
        # 类别计数
        for n in range(k):
            if order[n, 0] == 1:
                count1 += 1
            if order[n, 0] == 2:
                count2 += 1
            if order[n, 0] == 3:
                count3 += 1
        # 找出最大的类别数量
        if count1 >= count2 and count1 >= count3:
            prediction = 1
        if count2 >= count1 and count2 >= count3:
            prediction = 2
        if count3 >= count1 and count3 >= count2:
            prediction = 3
        # 预测正确
        if prediction == test[0, 4]:
            accurary += 1
    Accuracy = accurary / 150
    return Accuracy


iris = pd.read_csv('Iris.csv', header=None, sep=',')

x = iris.iloc[1:, 1:5]
x = np.array(x, dtype=float)
a = np.full((50, 1), 1)
b = np.full((50, 1), 2)
c = np.full((50, 1), 3)

Res = np.zeros(50)

d = np.vstack((a, b, c))
# 将数据集中的标签更换为1,2,3
x = np.hstack((x, d))

for m in range(50):
    k = m + 1
    Res[m] = k_nn(x)

# 画图
k_label = np.arange(1, 51, 1)
plt.xlabel('k Value')
plt.ylabel('Accuracy')
plt.ylim((.9, 1))
plt.plot(k_label, Res, 'b')
plt.show()

最后给大家看看不同的K值和正确率的关系图:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值