KNN实现对digits数据集分类
KNN算法介绍:
数据的导入:
导入包
1 | from sklearn import datasets |
导入数据集
1 | # 手写字体的数据集导入 |
数据集介绍:
数据集情况:1797条数据
1 | data.shape, target.shape |
对于导入的数据集data里面的每个数据的形状是(64,),我们可以将其转化为8X8像素的数据,将第一个数据进行可视化展示:
形状转换:
1 | ima = data[0].reshape(8, 8) |
可视化:
1 | plt.imshow(ima) |
数据集的分割:
1 | x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=1) |
定义KNN函数:
1 | def knn_code(loc, k=5, order=2): # k order是超参 |
评估准确率:
1 | res = [] |
完整代码:
1 | #!/usr/bin/env python |