江东的笔记

Be overcome difficulties is victory

0%

Kmeans实战-实现二维的bolo分析

对二维的bolo数据集分析与可视化

首先导入包:

1
2
3
4
5
# 导入包
import matplotlib.pyplot as plt # 画图的包
import numpy as np
import pylab as mpl # import matplotlib as mpl
from sklearn.datasets import make_blobs # 产生数据集

默认设置:

1
2
mpl.rcParams['font.sans-serif'] = ['FangSong']  # 指定默认字体
mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题

进行初始化:

1
2
3
4
5
6
7
n_samples = 1500  # 生成1500个数据集
random_state = 170 # 170这个是随机种子
k = 3 # 超参数
np.random.seed(26) #给numpy设置一个随机种子,保证每次都能产生相同的值
X, y = make_blobs(n_samples=n_samples, random_state=random_state) # 生成数据集,包括1500个样本
ages = np.vstack((X[y == 0][:500], X[y == 1][:500], X[y == 2][:500])) # 将数据进行堆叠,shape为(1500, 2)
y = np.array(([0] * 500 + [1] * 500 + [2] * 500)) #生成0 1 2 各500个

迭代初始化:

1
2
3
4
centers = np.zeros([3, 2])    # 生成0矩阵
centers_random = np.random.choice(range(len(y)), 3) # 迭代起点
centers_new = ages[centers_random] # 随机选取中心
dis_to_cent = np.zeros((k, len(ages))) # 一个二维数据,相当于一个空的容器

实现预测:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
while not (centers_new == centers).all():
centers = centers_new.copy() # 注意python的赋值过程,进行展开讲解,== is 和复制方式
for ii in range(k):
dis_to_cent[ii] = np.linalg.norm(ages - centers[ii], axis=1) # 计算每个数值到中心的距离

clusters = dis_to_cent.argmin(axis=0) # 划分出每个类别

for ii in range(k): # 重新计算中心
cluster = ages[clusters == ii]
centers_new[ii] = ages[clusters == ii].mean(0)

print(centers, centers_new)
print(centers_new)
print('centers_new==centers?', (centers_new == centers).all())