1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
|
import time
import matplotlib import numpy as np from matplotlib.colors import ListedColormap
# Global Matplotlib font setup: the figure titles contain Chinese text, so a
# CJK-capable font (SimHei) is requested, and axes.unicode_minus is disabled so
# negative tick labels use the ASCII hyphen instead of U+2212.
# These are applied before pyplot is imported.
matplotlib.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'font.family': 'sans-serif',
    'axes.unicode_minus': False,
})
import matplotlib.pyplot as plt  # noqa: E402 - deliberately after rcParams setup

# Colormaps for the class shading: pale tones for the background plane and
# saturated tones for the lattice points.
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])
# Training point shown as the triangle marker in the plots.
point_coordinates = np.array([2., 2.])

# Initial decision line: line_a * x + line_b * y + line_c = 0.
# With line_c = -line_b * line_a the line passes through (line_b, 0) and
# (0, line_a), i.e. x-intercept 9 and y-intercept 7 for the values below.
line_a, line_b = 7, 9
line_c = -line_b * line_a

# Integer window [bottom, up] used on both axes.
bottom, up = -5, 10

# Dense 300x300 background grid (padded by 0.5 on every side).  Each cell is
# labelled 1 where point_coordinates . (x, y) + 1 > 0 and 0 otherwise.
# Vectorised: one numpy expression instead of a Python loop over 90,000 cells.
xx_plane, yy_plane = np.meshgrid(
    np.linspace(bottom - 0.5, up + 0.5, 300),
    np.linspace(bottom - 0.5, up + 0.5, 300),
)
class_plane = (
    xx_plane * point_coordinates[0] + yy_plane * point_coordinates[1] + 1 > 0
).astype(int)

# Every integer lattice point in the window except the training point itself,
# as an (N, 2) array of (x, y) rows.
dots = np.array([
    [ii, jj]
    for ii in range(bottom, up + 1)
    for jj in range(bottom, up + 1)
    if ii != point_coordinates[0] or jj != point_coordinates[1]
])
# Label each lattice point 1 on the positive side of the current line, else 0.
class_dots = (dots @ np.array([line_a, line_b]) + line_c > 0).astype(int)
def window_cross(line_a, line_b, line_c, low=None, high=None):
    """Return where the line ``line_a*x + line_b*y + line_c = 0`` meets the window.

    The window is the square ``[low, high] x [low, high]``.  ``low`` and
    ``high`` default to the module-level ``bottom`` and ``up`` so existing
    three-argument calls behave as before.

    Args:
        line_a, line_b, line_c: Line coefficients.
        low, high: Optional window bounds applied to both axes.

    Returns:
        A pair ``(p0, p1)`` of ``(x, y)`` points.  For a horizontal or vertical
        line these are its hits on the window edges.  Otherwise they are the
        two middle intersections with the border lines, ordered by x.  When
        ``line_a == line_b == 0`` (no line) ``((0, 0), (0, 0))`` is returned.
    """
    if low is None:
        low = bottom
    if high is None:
        high = up

    if line_a == 0 and line_b == 0:
        return (0, 0), (0, 0)
    if line_a == 0:
        # Horizontal line y = -c / b.
        y = -line_c / line_b
        return (low, y), (high, y)
    if line_b == 0:
        # Vertical line x = -c / a.
        x = -line_c / line_a
        return (x, low), (x, high)

    # Intersections with the four infinite border lines x = low, x = high,
    # y = low, y = high.  They are collinear, and because the line is neither
    # horizontal nor vertical, sorting by x orders them along the line.  The
    # middle two are therefore where the line enters and leaves the square.
    candidates = sorted([
        (low, -(line_c + line_a * low) / line_b),
        (high, -(line_c + line_a * high) / line_b),
        (-(line_c + line_b * low) / line_a, low),
        (-(line_c + line_b * high) / line_a, high),
    ])
    return candidates[1], candidates[2]
# Initial-state figure: shaded background plane, lattice points coloured by the
# starting line, the training point, the coordinate axes and the line itself.
plt.figure(figsize=(8, 6))
plt.title('起始状态 同向')  # title text: "initial state, same direction"
plt.pcolormesh(xx_plane, yy_plane, class_plane, cmap=cmap_light)
plt.scatter(dots[:, 0], dots[:, 1], c=class_dots, cmap=cmap_bold)

# Training point: blue when it lies on the positive side of the line, else red.
px, py = point_coordinates[0], point_coordinates[1]
if px * line_a + py * line_b + line_c > 0:
    color_point = 'b'
else:
    color_point = 'r'
plt.scatter(px, py, c=color_point, marker="v", s=100)

# x- and y-axes drawn across the padded window.
span = [bottom - 0.5, up + 0.5]
plt.plot(span, [0, 0], c='k')
plt.plot([0, 0], span, c='k')

# Starting decision line, clipped to the window.
cross = window_cross(line_a, line_b, line_c)
(x0, y0), (x1, y1) = cross
plt.plot([x0, x1], [y0, y1], c='g')
# Animation: each epoch redraws the scene, then moves the line by adding the
# augmented training point (x, y, 1) to the coefficients (a, b, c).
# plt.close('all') discards the initial-state figure built just above.
# NOTE(review): in a plain script run that figure is therefore never displayed;
# presumably this was exported from a notebook where the previous cell showed it.
plt.close('all')
plt.figure(figsize=(8, 6))
plt.ion()

for ii in range(200):
    plt.cla()
    plt.title(f'epoch={ii+1}: ({line_a}) * x + ({line_b}) * y + ({line_c}) = 0', fontsize=20)

    # Background plane and lattice points coloured by the current line.
    plt.pcolormesh(xx_plane, yy_plane, class_plane, cmap=cmap_light)
    plt.scatter(dots[:, 0], dots[:, 1], c=class_dots, cmap=cmap_bold)

    # Training point: blue on the positive side of the line, red otherwise.
    color_point = 'b' if point_coordinates[0] * line_a + point_coordinates[1] * line_b + line_c > 0 else 'r'
    plt.scatter(point_coordinates[0], point_coordinates[1], c=color_point, marker="v", s=100)

    plt.grid()
    # Coordinate axes across the padded window.
    plt.plot([bottom - 0.5, up + 0.5], [0, 0], c='k')
    plt.plot([0, 0], [bottom - 0.5, up + 0.5], c='k')

    # Current decision line, clipped to the window.
    cross = window_cross(line_a, line_b, line_c)
    plt.plot([cross[0][0], cross[1][0]], [cross[0][1], cross[1][1]], c='g')

    # Slow for the first 8 epochs, then fast.  One plt.pause replaces the former
    # time.sleep calls (0.4 + 0.3 s, later 0.02 + 0.02 s) plus plt.pause(0.01).
    # time.sleep blocks the GUI event loop and freezes the window; the total
    # delay per frame is unchanged.
    plt.pause(0.71 if ii < 8 else 0.05)

    # Perceptron-style update w <- w + (x, y, 1), applied every epoch whether or
    # not the point is already on the positive side (kept as before).
    # Bug fix: line_b previously added point_coordinates[0] instead of [1].
    line_a, line_b, line_c = (line_a + point_coordinates[0],
                              line_b + point_coordinates[1],
                              line_c + 1)
    # Re-label lattice points against the updated line (vectorised).
    class_dots = (dots @ np.array([line_a, line_b]) + line_c > 0).astype(int)

plt.ioff()
plt.show()
|