zoukankan      html  css  js  c++  java
  • [转]统计学习方法—chapter2—感知机算法实现

    描述：李航《统计学习方法》第二章感知机算法的 Python 实现，包括原始形式和对偶形式两个版本。

    原始形式（对每个误分类点直接更新权重 w 和偏置 b）：

     1 # _*_ encoding:utf-8 _*_
     2 
     3 import numpy as np
     4 import matplotlib.pyplot as plt
     5 
     6 
     7 def createdata():
     8     """创建数据集和相应类标记"""
     9     samples = np.array([[3, 3], [4, 3], [1, 1]])
    10     labels = np.array([1, 1, -1])
    11     return samples, labels
    12 
    13 
    14 
    15 class Perceptron:
    16     """感知机模型"""
    17 
    18     def __init__(self, x, y, a=1):
    19         self.x = x
    20         self.y = y
    21         self.w = np.zeros((x.shape[1], 1))
    22         self.b = 0
    23         self.a = 1  #学习率
    24         self.numsamples = self.x.shape[0]
    25         self.numfeatures = self.x.shape[1]
    26 
    27     def sign(self, w, b, x):
    28         """计算某样本点的f(x)"""
    29         y = np.dot(x, w) + b
    30         return int(y)
    31 
    32     def update(self, label_i, data_i):
    33         """更新w和b"""
    34         tmp = label_i * self.a * data_i
    35         tmp = tmp.reshape(self.w.shape)
    36         self.w = tmp + self.w
    37         self.b = self.b + label_i * self.a
    38 
    39     def train(self):
    40         """训练感知机模型"""
    41         isfind = False
    42         while not isfind:
    43             count = 0
    44             for i in range(self.numsamples):
    45                 tmp = self.sign(self.w, self.b, self.x[i, :])
    46                 if tmp * self.y[i] <= 0:
    47                     print('误分类点为: ', self.x[i, :], '此时的w和b为: ', self.w, self.b)
    48                     count += 1
    49                     self.update(self.y[i], self.x[i, :])
    50             if count == 0:
    51                 print('最终训练得到的w和b为: ', self.w, self.b)
    52                 isfind = True
    53         return self.w, self.b
    54 
    55 
    56 
    57 class Picture:
    58     """数据可视化"""
    59 
    60     def __init__(self, data, w, b):
    61         """初始化参数"""
    62         self.b = b
    63         self.w = w
    64         plt.figure(1)
    65         plt.title('Perceptron Learning Algorithm', size= 14)
    66         plt.xlabel('x0-axis', size=14)
    67         plt.ylabel('x1-axis', size=14)
    68 
    69         xData = np.linspace(0, 5, 100)
    70         yData = self.expression(xData)
    71         plt.plot(xData, yData, color='r', label='sample data')
    72 
    73         plt.scatter(data[0][0], data[0][1], c='r', s=50)
    74         plt.scatter(data[1][0], data[1][1], c='g', s=50)
    75         plt.scatter(data[2][0], data[2][1], s=50, c='b', marker='x')
    76  
    77         plt.savefig('original.png', dpi=75)
    78 
    79     def expression(self, x):
    80         """计算超平面上对应的纵坐标"""
    81         y = (-self.b - self.w[0] * x) / self.w[1]
    82         return y
    83 
    84     def show(self):
    85         """画图"""
    86         plt.show()
    87 
    88 
    89 if __name__ == '__main__':
    90     samples, labels = createdata()
    91     myperceptron = Perceptron(samples, labels)
    92     weights, bias = myperceptron.train()
    93     picture = Picture(samples, weights, bias)
    94     picture.show()

    对偶形式（对每个误分类点更新其系数 α 和偏置 b，样本内积预先存入 Gram 矩阵；代码中 α 存放在 self.w 中，最终由 cal_w 还原出 w）：

      1 # _*_ encoding:utf-8 _*_
      2 
      3 import numpy as np
      4 import matplotlib.pyplot as plt
      5 
      6 def createdata():
      7     """创建数据集和相应的类标记"""
      8     samples = np.array([[3, 3], [4, 3], [1, 1]])
      9     labels = np.array([1, 1, -1])
     10     return samples, labels
     11 
     12 
     13 class Perceptron:
     14     """感知机模型"""
     15 
     16     def __init__(self, x, y, a=1):
     17         """初始化数据集,标记,学习率,参数等"""
     18         self.x = x
     19         self.y = y
     20         self.w = np.zeros((1, x.shape[0]))
     21         self.b = 0
     22         self.a = a
     23         self.numsamples = self.x.shape[0]
     24         self.numfeatures = self.x.shape[1]
     25         self.gmatrix = self.gMatrix()
     26 
     27     def gMatrix(self):
     28         """计算Gram矩阵"""
     29         gmatrix = np.zeros((self.numsamples, self.numsamples))
     30         for i in range(self.numsamples):
     31             for j in range(self.numsamples):
     32                 gmatrix[i][j] = np.dot(self.x[i, :], self.x[j, :])
     33         return gmatrix
     34 
     35     def sign(self, i):
     36         """计算f(x)"""
     37         y = np.dot(self.w*self.y, self.gmatrix[:, i]) + self.b
     38         return int(y)
     39 
     40     def update(self, i):
     41         """更新w和b"""
     42         self.w[:, i] = self.w[:, i] + self.a
     43         self.b = self.b + self.a * self.y[i]
     44 
     45     def cal_w(self):
     46         """计算最终的w"""
     47         w = np.dot(self.w*self.y, self.x)
     48         return w
     49 
     50     def train(self):
     51         """感知机模型训练"""
     52         isfind = False
     53         while not isfind:
     54             count = 0
     55             for i in range(self.numsamples):
     56                 if self.y[i]*self.sign(i) <= 0:
     57                     count += 1
     58                     print('误分类点为: ', self.x[i, :], '此时w和b分别为: ', self.cal_w(), ', ', self.b)
     59                     self.update(i)
     60             if count == 0:
     61                 print('最终的w和b为: ', self. cal_w(), ', ', self.b)
     62                 isfind = True
     63         weights = self.cal_w()
     64         return weights, self.b
     65 
     66 
     67 class Picture:
     68     """数据可视化"""
     69 
     70     def __init__(self, data, w, b):
     71         """"初始化画图参数"""
     72         self.w = w
     73         self.b = b
     74         plt.figure(1)
     75         plt.title('Perceptron Learning Algorithm of Duality', size=20)
     76         plt.xlabel('X0-axis', size=14)
     77         plt.ylabel('X1-axis', size=14)
     78 
     79         xdata = np.linspace(1, 5, 100)
     80         ydata = self.expression(xdata)
     81         plt.plot(xdata, ydata, c='r')
     82 
     83         plt.scatter(data[0][0], data[0][1], s=50)
     84         plt.scatter(data[1][0], data[1][1], s=50)
     85         plt.scatter(data[2][0], data[2][1], s=50, marker='x')
     86         plt.savefig('test.png', dpi=95)
     87 
     88     def expression(self, xdata):
     89         """计算超平面上的纵坐标"""
     90         y = (-self.b - self.w[:, 0]*xdata) / self.w[:, 1]
     91         return y
     92 
     93     def show(self):
     94         """画图"""
     95         plt.show()
     96 
     97 
     98 if __name__ == '__main__':
     99     samples, labels = createdata()
    100     perceptron = Perceptron(x=samples, y=labels)
    101     weights, b = perceptron.train()
    102     picture = Picture(samples, weights, b)
    103     picture.show()

    参考自:https://blog.csdn.net/u010626937/article/details/72896144

  • 相关阅读:
    智器SmartQ T7实体店试用体验
    BI笔记之SSAS库Process的几种方案
    PowerTip of the Day from powershell.com上周汇总(八)
    PowerTip of the Day2010071420100716 summary
    PowerTip of the Day from powershell.com上周汇总(十)
    PowerTip of the Day from powershell.com上周汇总(六)
    重新整理Cellset转Datatable
    自动加密web.config配置节批处理
    与DotNet数据对象结合的自定义数据对象设计 (二) 数据集合与DataTable
    在VS2003中以ClassLibrary工程的方式管理Web工程.
  • 原文地址:https://www.cnblogs.com/OoycyoO/p/9538055.html
Copyright © 2011-2022 走看看