江东的笔记

Be overcome difficulties is victory

0%

感知机对偶形式

对偶形式的学习

感知机:

在这里插入图片描述
符号函数:
在这里插入图片描述
选择误分类点到超平面的总距离作为损失函数:
距离:
在这里插入图片描述
误分类点:
在这里插入图片描述
误分类点距离
在这里插入图片描述
总距离
在这里插入图片描述

感知机对偶形式

在这里插入图片描述
在这里插入图片描述
过程
在这里插入图片描述
例题:
在这里插入图片描述
在这里插入图片描述
动态可视化代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# 以半动画的方式展示感知识机对偶问题的操作的合理性
# 给定初始点, 初始直线
import time

import matplotlib
import numpy as np
from matplotlib.colors import ListedColormap

matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.family']='sans-serif'
matplotlib.rcParams['axes.unicode_minus'] = False
import matplotlib.pyplot as plt


cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

# 可视化展示
point_coordinates = np.array([2., 2.]) # 关键点坐标

line_a, line_b = 7, 9 # 初始线方程
line_c = - line_b * line_a # 注意负号

bottom, up = -5, 10 # 视窗

# 关键点对应的“基线”的方程的分类,以平面展示
xx_plane = np.linspace(bottom - 0.5, up + 0.5, 300)
yy_plane = np.linspace(bottom - 0.5, up + 0.5, 300)
xx_plane, yy_plane = np.meshgrid(xx_plane, yy_plane)
class_plane = np.array([1 if np.dot(xi, point_coordinates) + 1 > 0 else 0 for xi in zip(
xx_plane.ravel(), yy_plane.ravel())]).reshape(xx_plane.shape)

# 优化方程对应的分类,以整数点展示
dots = np.array([np.array([ii, jj]) for ii in range(bottom, up + 1) for jj in range(bottom, up + 1)
if ii != point_coordinates[0] or jj != point_coordinates[1]])
class_dots = np.array([1 if np.dot(xi, [line_a, line_b]) + line_c > 0 else 0 for xi in dots])


# 找到直线在视窗内的两个顶点 直线与视窗四线的交点的中间的两个
def window_cross(line_a, line_b, line_c):
if line_a == 0 and line_b ==0:
return (0, 0), (0, 0)
elif line_a == 0:
return (bottom, -line_c / line_b), (up, -line_c / line_b)
elif line_b == 0:
return (-line_c / line_a, bottom), (-line_c / line_a, up)
else:
c1 = bottom, - 1 / line_b * (line_c + line_a * bottom)
c2 = up, - 1 / line_b * (line_c + line_a * up)
c3 = -1 / line_a * (line_c + line_b * bottom), bottom
c4 = -1 / line_a * (line_c + line_b * up), up
cross_points = [c1, c2, c3, c4]
cross_points.sort()
return cross_points[1:3]


# 初始状态展示
plt.figure(figsize=(8, 6))
plt.title('起始状态 同向')
plt.pcolormesh(xx_plane, yy_plane, class_plane, cmap=cmap_light)
plt.scatter(dots[:, 0], dots[:, 1], c=class_dots, cmap=cmap_bold)
color_point = 'b' if point_coordinates[0] * line_a + point_coordinates[1] * line_b + line_c > 0 else 'r'
plt.scatter(point_coordinates[0], point_coordinates[1], c=color_point, marker="v", s=100)
# plt.grid()
plt.plot([bottom - 0.5, up + 0.5], [0, 0], c='k')
plt.plot([0, 0], [bottom - 0.5, up + 0.5], c='k')
cross = window_cross(line_a, line_b, line_c)
plt.plot([cross[0][0], cross[1][0]], [cross[0][1], cross[1][1]], c='g')


# 迭代过程的动态展示
plt.close('all')
plt.figure(figsize=(8, 6))
plt.ion()

for ii in range(200):
plt.cla()
plt.title(f'epoch={ii+1}: ({line_a}) * x + ({line_b}) * y + ({line_c}) = 0', fontsize=20)
if ii < 8:
time.sleep(0.4)
else:
time.sleep(0.02)
plt.pcolormesh(xx_plane, yy_plane, class_plane, cmap=cmap_light)
plt.scatter(dots[:, 0], dots[:, 1], c=class_dots, cmap=cmap_bold)
color_point = 'b' if point_coordinates[0] * line_a + point_coordinates[1] * line_b + line_c > 0 else 'r'
plt.scatter(point_coordinates[0], point_coordinates[1], c=color_point, marker="v", s=100)
plt.grid()
plt.plot([bottom - 0.5, up + 0.5], [0, 0], c='k')
plt.plot([0, 0], [bottom - 0.5, up + 0.5], c='k')
cross = window_cross(line_a, line_b, line_c)
plt.plot([cross[0][0], cross[1][0]], [cross[0][1], cross[1][1]], c='g')
# 暂停
if ii < 8:
time.sleep(0.3)
else:
time.sleep(0.02)
plt.pause(0.01)
# 迭代更新
line_a, line_b, line_c = line_a+point_coordinates[0], line_b+point_coordinates[0], line_c + 1
class_dots = np.array([1 if np.dot(xi, [line_a, line_b]) + line_c > 0 else 0 for xi in dots])

# 关闭交互模式
plt.ioff()

# 图形显示
plt.show()