zoukankan      html  css  js  c++  java
  • 【感知机模型】手写代码训练 / 使用sklearn的Perceptron模块训练

    读取原始数据

    import pandas as pd
    import numpy as np
    
    in_data = pd.read_table('./origin-data/perceptron_15.dat', sep='s+', header=None)
    X_train = np.array(in_data.loc[:,[0,1,2,3]])
    y_train = np.array(in_data[4])
    

    训练感知机模型

    class MyPerceptron:
      def __init__(self):
        self.w = None
        self.b = 0
        self.l_rate = 1
    
      def fit(self, X_train, y_train):
      #用样本点的特征数更新初始w,如x1=(3,3)T,有两个特征,则self.w=[0,0]
        self.w = np.zeros(X_train.shape[1])
        i = 0
        while i < X_train.shape[0]:
          X = X_train[i]
          y = y_train[i]
          # 如果y*(wx+b)≤0 说明是误判点,更新w,b
          if y * (np.dot(self.w, X) + self.b) <= 0:
            self.w += self.l_rate * np.dot(y, X)
            self.b += self.l_rate * y
            i=0 #如果是误判点,从头进行检测
          else:
            i+=1
    
    from sklearn.linear_model import Perceptron
    
    # 使用sklearn中的Perceptron类训练
    perceptron = Perceptron()
    time1 = datetime.datetime.now()
    perceptron.fit(X_train, y_train)
    time2 = datetime.datetime.now()
    print("共用时:", (time2-time1).microseconds, "微秒")
    print(perceptron.coef_)
    print(perceptron.intercept_)
    

    共用时: 4769 微秒
    [[ 2.9686576 -1.513057 2.211151 4.227677 ]]
    [-3.]

    # 使用自己写的MyPerceptron类训练
    perceptron = MyPerceptron()
    time1 = datetime.datetime.now()
    perceptron.fit(X_train, y_train)
    time2 = datetime.datetime.now()
    print("共用时:", (time2-time1).microseconds, "微秒")
    print(perceptron.w)
    print(perceptron.b)
    

    共用时: 12479 微秒
    [ 3.6161856 -2.013502 3.123158 5.49830856]
    -4

  • 相关阅读:
    文件光标移动
    python的版本的差别 "2","3"
    java通过jdbc操作Excel
    qt通过odbc操作Excel
    qt读取oracle表数据
    virtual box安装oracle_rac_10g
    oracle rac +standby
    rac不完全恢复
    rac完全恢复学习
    oracle rac搭建(三)--安装中的问题
  • 原文地址:https://www.cnblogs.com/yanqiang/p/12038266.html
Copyright © 2011-2022 走看看