zoukankan      html  css  js  c++  java
  • 从感知机到支持向量机—学习笔记

    step 1

    用高斯分布生成两类点

     1 class Point3:
     2     def __init__(self):
     3         self.x = random.gauss(50, 10)
     4         self.y = random.gauss(50, 10)
     5 
     6         self.label = -1
     7         self.color = 'r'
     8 
     9 class Point4:
    10     def __init__(self):
    11         self.x = random.gauss(90, 10)
    12         self.y = random.gauss(90, 10)
    13 
    14         self.label = 1
    15         self.color = 'b'

    step 2

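    画一条初始直线:先定义两个点(x1, 0)和(x2, 100),其中x1是[0, 50]中的随机整数,x2是[50, 100]中的随机整数(random.randint包含两个端点)。由这两个点确定一条直线,写成w1*x + w2*y + b = 0的形式,其中w2 = 1,w1 = -(y2 - y1)/(x2 - x1),b = -(w1*x1 + w2*y1)。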

     1 class Line:
     2     def __init__(self):
     3         self.x1 = random.randint(MIN, MAX//2)       # MAX=100 MIN=0 (0, 50)  随机生成一个整数
     4         self.x2 = random.randint(MAX//2, MAX)       # MAX=100 MIN=0 (50, 100)
     5         self.y1 = 0
     6         self.y2 = 100
     7 
     8         self.x = [self.x1, self.x2]
     9         self.y = [self.y1, self.y2]
    10 
    11         self.w1 = -(self.y2 - self.y1) / (self.x2 - self.x1)
    12         self.w2 = 1
    13         self.b = -(self.w1 * self.x1) + self.w2 * self.y1

    step 3
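    判断误分类点
    正确分类1:w1*x+w2*y+b>0且label=1
    正确分类2:w1*x+w2*y+b<0且label=-1
    两种情况可以合并为 label*(w1*x+w2*y+b) > 0。下面的sign方法返回的正是 label*(w1*x+w2*y+b) 这个值(函数间隔),而不只是它的符号;结果小于等于0的点即为误分类点。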

    1  def sign(self, point):
    2         # print(self.w1 * point.x + self.w2 * point.y + self.b)
    3         # print(point.label * (self.w1 * point.x + self.w2 * point.y + self.b))
    4         return point.label * (self.w1 * point.x + self.w2 * point.y + self.b)

    step 4
    有了更新后的w1、w2和b之后,需要重新画出对应的直线。
    为此先找到直线上的两个点:y1=0、y2=100保持不变,只需由w1*x+w2*y+b=0解出x = (-b - w2*y)/w1,分别代入y1和y2即可得到x1和x2(要求w1≠0)。

    1     def update(self):
    2         self.x1 = -self.b / self.w1
    3         self.x2 = (-self.b - self.w2 * self.y2) / self.w1
    4         self.x = [self.x1, self.x2]
    5         self.y = [self.y1, self.y2]

    step 5
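    w1、w2和b的更新规则参考博文《支持向量机》(http://www.carefree0910.com/posts/d455305a/)。对每个满足 label*(w1*x+w2*y+b) < 1 的点(即落在间隔内或被误分类的点,而不只是误分类点),执行更新:w ← (1-step)*w + step*C*label*(x, y),b ← b + step*C*label;当一整轮遍历中没有点触发更新时停止。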

     1 def preceptron_base_dis(all_points):
     2     line = Line()
     3     plt.plot(line.x, line.y, 'g--', linewidth=1)
     4     Flag = True
     5     while True:
     6         Flag = True
     7         for point in all_points:
     8             if line.sign(point) < 1:    # 只有误分类点才更新
     9                 line.w1 = (1 - step) * line.w1 + step * C * point.label * point.x
    10                 line.w2 = (1 - step) * line.w2 + step * C * point.label * point.y
    11                 line.b = line.b + step * C * point.label
    12                 Flag = False
    13         if Flag:
    14             break
    15         line.update()
    16         #plt.plot(l.x, l.y, 'y--', linewidth=1)
    17     plt.plot(line.x, line.y, '.-', linewidth=1)
    18     plt.show()

    全部代码汇总

      1 import matplotlib.pyplot as plt
      2 import numpy
      3 import random
      4 import sys
      5 
      6 MAX=100
      7 MIN=0
      8 POINT_NUM=20
      9 step=0.01
     10 C = 0.1
     11 
     12 class Point:
     13     def __init__(self):
     14         self.x = random.uniform(MIN, MAX)
     15         self.y = random.uniform(MIN, MAX)
     16 
     17         if self.x > self.y:
     18             self.label = 1
     19             self.color = 'b'
     20         else:
     21             self.label = -1
     22             self.color = 'r'
     23 class Point2:
     24     def __init__(self):
     25         self.x = random.randint(MIN, MAX)
     26         if self.x > MAX // 2:
     27             self.y = random.randint(0, MAX // 4)
     28         else:
     29             self.y = random.randint(MAX * 2 // 4, MAX)
     30 
     31         if self.x > self.y:
     32             self.label = 1
     33             self.color = 'b'
     34         else:
     35             self.label = -1
     36             self.color = 'r'
     37 
     38 class Point3:
     39     def __init__(self):
     40         self.x = random.gauss(50, 10)
     41         self.y = random.gauss(50, 10)
     42 
     43         self.label = -1
     44         self.color = 'r'
     45 
     46 class Point4:
     47     def __init__(self):
     48         self.x = random.gauss(90, 10)
     49         self.y = random.gauss(90, 10)
     50 
     51         self.label = 1
     52         self.color = 'b'
     53 class Line:
     54     def __init__(self):
     55         self.x1 = random.randint(MIN, MAX//2)       # MAX=100 MIN=0 (0, 50)  随机生成一个整数
     56         self.x2 = random.randint(MAX//2, MAX)       # MAX=100 MIN=0 (50, 100)
     57         self.y1 = 0
     58         self.y2 = 100
     59 
     60         self.x = [self.x1, self.x2]
     61         self.y = [self.y1, self.y2]
     62 
     63         self.w1 = -(self.y2 - self.y1) / (self.x2 - self.x1)
     64         self.w2 = 1
     65         self.b = -(self.w1 * self.x1) + self.w2 * self.y1
     66 
     67     def sign(self, point):
     68         # print(self.w1 * point.x + self.w2 * point.y + self.b)
     69         # print(point.label * (self.w1 * point.x + self.w2 * point.y + self.b))
     70         return point.label * (self.w1 * point.x + self.w2 * point.y + self.b)
     71 
     72     def update(self):
     73         self.x1 = -self.b / self.w1
     74         self.x2 = (-self.b - self.w2 * self.y2) / self.w1
     75         self.x = [self.x1, self.x2]
     76         self.y = [self.y1, self.y2]
     77 
     78 
     79 def initialPoint():
     80     plt.figure()
     81     all_point = []
     82     for idx in range(POINT_NUM):
     83         p = Point3()
     84         plt.plot(p.x, p.y, p.color + 'o', label="point")
     85         all_point.append(p)
     86 
     87     for idx in range(POINT_NUM):
     88         p = Point4()
     89         plt.plot(p.x, p.y, p.color + 'o', label="point")
     90         all_point.append(p)
     91     return all_point
     92 
     93 def preceptron_base_dis(all_points):
     94     line = Line()
     95     plt.plot(line.x, line.y, 'g--', linewidth=1)
     96     Flag = True
     97     while True:
     98         Flag = True
     99         for point in all_points:
    100             if line.sign(point) < 1:    # 只有误分类点才更新
    101                 line.w1 = (1 - step) * line.w1 + step * C * point.label * point.x
    102                 line.w2 = (1 - step) * line.w2 + step * C * point.label * point.y
    103                 line.b = line.b + step * C * point.label
    104                 Flag = False
    105         if Flag:
    106             break
    107         line.update()
    108         #plt.plot(l.x, l.y, 'y--', linewidth=1)
    109     plt.plot(line.x, line.y, '.-', linewidth=1)
    110     plt.show()
    111 
    112 def preceptron(all_points):
    113     line = Line()
    114     plt.plot(line.x, line.y, 'g--', linewidth=1)
    115     Flag = True
    116     while True:
    117         Flag = True
    118         for point in all_points:
    119             if line.sign(point) <= 0:
    120                 line.w1 += step * point.label * point.x
    121                 line.w2 += step * point.label * point.y
    122                 line.b += step * point.label
    123                 Flag = False
    124         if Flag:
    125             break
    126         line.update()
    127         #plt.plot(line.x, line.y, 'y--', linewidth=1)
    128     plt.plot(line.x, line.y, 'o-', linewidth=1)
    129     plt.show()
    130 
    131 all_points = initialPoint()
    132 preceptron_base_dis(all_points) 
  • 相关阅读:
    java将pdf转成base64字符串及将base64字符串反转pdf
    input校验不能以0开头的数字
    js校验密码,不能为空的8-20位非纯数字或字母的密码
    tomcat正常关闭,端口号占用解决 StandardServer.await: create[8005]:
    Eclipse中项目报Target runtime com.genuitec.runtime.generic.jee60 is not defined异常的解决
    Access restriction: The type Base64 is not accessible due to restriction on
    [操作系统] 线程和进程的简单解释
    ssh登录一段时间后断开的解决方案
    [SAMtools] 常用指令总结
    [C] 有关内存问题
  • 原文地址:https://www.cnblogs.com/Joyce-song94/p/7594806.html
Copyright © 2011-2022 走看看