  • Training an addition model with a perceptron

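    The perceptron itself is not introduced here. This post simply applies the perceptron idea to train a model that computes y = a + b. Each operand is encoded as a bit_len-bit binary vector, and the model learns one weight per bit position.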
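    Why this can work: once each operand is written in binary (least significant bit first, matching initMap in the code), the sum of two numbers is a fixed weighted sum of their bits, so a single linear unit can represent it exactly. In the notation below, chosen here for illustration, n is bit_len, η is learnRate, and Δw is var_w in fit(). Strictly speaking, because the output is not thresholded, the update is the delta (Widrow-Hoff, LMS) rule applied to a linear unit rather than Rosenblatt's original thresholded perceptron rule.

```latex
a = \sum_{i=0}^{n-1} 2^i a_i, \qquad b = \sum_{i=0}^{n-1} 2^i b_i
\;\Longrightarrow\;
a + b = \sum_{i=0}^{n-1} 2^i (a_i + b_i) = \mathbf{w}^{*\top}(\mathbf{a} + \mathbf{b}),
\qquad \mathbf{w}^* = (1, 2, 4, \dots, 2^{n-1})

\hat{y} = \mathbf{w}^{\top}(\mathbf{a} + \mathbf{b}), \qquad
\Delta\mathbf{w} = \eta\,(y - \hat{y})\,(\mathbf{a} + \mathbf{b})
```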

    # -*- coding: utf-8 -*-
    '@author: xijun.gong'
    import math
    import random

    import numpy as np


    class Perceptron:
        def __init__(self, learnRate, maxIter, bit_len):
            """
            :param learnRate: learning rate (step size of each weight update)
            :param maxIter:   maximum number of training epochs
            :param bit_len:   number of bits used to encode each operand
            """
            self.learnRate = learnRate
            self.weight = None
            self.maxIter = maxIter
            self.bit_len = bit_len
            # lookup table mapping each integer to its bit vector
            self.nummap = None
            self.initMap()

        def initMap(self):
            maxNum = 1 << self.bit_len  # number of values representable with bit_len bits
            # one row per value 0 .. maxNum-1 (zero included), least significant bit first
            self.nummap = np.zeros((maxNum, self.bit_len), dtype=int)
            for _id in range(maxNum):
                for index in range(self.bit_len):
                    self.nummap[_id][index] = 1 & (_id >> index)

        def initWeight(self):
            """Start every bit weight at 1 / bit_len."""
            self.weight = np.ones(self.bit_len) / self.bit_len

        def fit(self, fds, labels):
            """
            :param fds:    training samples, shape (n, 2); each row is an operand pair (a, b)
            :param labels: target sums a + b, shape (n,)
            """
            self.initWeight()
            for epoch in range(self.maxIter):
                print('train as iter is {}'.format(epoch))
                acc_cnt = 0
                for _ind, sample in enumerate(fds):
                    a = self.nummap[int(sample[0])]
                    b = self.nummap[int(sample[1])]
                    # linear output: weighted sum of the bits of a and b
                    label_y = np.sum(self.weight * (a + b))
                    print('the reality:{} , predict {}'.format(labels[_ind], label_y))
                    if math.fabs(labels[_ind] - label_y) <= 0.000001:
                        acc_cnt += 1
                        continue
                    # var_w denotes ∇w, the weight update for this sample
                    var_w = self.learnRate * (labels[_ind] - label_y) * (a + b)
                    self.weight += var_w
                print('accuracy is {}'.format(acc_cnt / len(fds)))
                # stop once every sample is predicted within the tolerance
                if acc_cnt == len(fds):
                    np.save('weight.npy', {'weight': self.weight})
                    return

        def load(self, path='weight.npy'):
            # np.save stored a dict, which loads back as a 0-d object array
            self.weight = np.load(path, allow_pickle=True).item()['weight']
            return self.weight

        def predict(self, fd):
            a = self.nummap[int(fd[0])]
            b = self.nummap[int(fd[1])]
            return np.sum(self.weight * (a + b))

        def predict_prod(self):
            pass  # left unimplemented


    if __name__ == '__main__':
        import time

        perceptron = Perceptron(learnRate=0.01, maxIter=2000, bit_len=5)
        xa = np.arange(31)
        xb = np.zeros(31)
        labels = np.zeros(31)
        for i in range(31):
            # b is random, with an upper bound that depends on the current time
            xb[i] = random.randint(0, int(time.time() + 1) % 31)
            labels[i] = xb[i] + xa[i]
        perceptron.fit(np.array([xa, xb]).T, labels)
        print('predict is {}'.format(perceptron.predict([24, 13])))
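    The lookup table built by initMap can also be produced without the double loop by using NumPy broadcasting. The sketch below builds the same table and checks that the weights 1, 2, 4, 8, 16 reproduce a + b exactly for every pair in range, which is what training should approach. The names bit_table, BIT_LEN and ideal_w are hypothetical and are not part of the original code.

```python
import numpy as np

BIT_LEN = 5  # same bit width as the example above


def bit_table(bit_len):
    """Row n holds the bits of n, least significant bit first (same layout as initMap)."""
    nums = np.arange(1 << bit_len)[:, None]  # column vector 0 .. 2**bit_len - 1
    shifts = np.arange(bit_len)[None, :]     # row vector of bit positions
    return (nums >> shifts) & 1              # broadcasts to shape (2**bit_len, bit_len)


table = bit_table(BIT_LEN)
ideal_w = 2.0 ** np.arange(BIT_LEN)          # hypothetical ideal weights 1, 2, 4, 8, 16

# With these weights, w . (bits(a) + bits(b)) equals a + b for every pair in range.
a_all, b_all = np.meshgrid(np.arange(1 << BIT_LEN), np.arange(1 << BIT_LEN), indexing="ij")
sums = (table[a_all] + table[b_all]) @ ideal_w
print(np.array_equal(sums, a_all + b_all))   # True
print(ideal_w @ (table[24] + table[13]))     # 37.0
```

    Broadcasting a column of values against a row of shift amounts replaces the nested loops with one vectorized expression, and the result has the same shape and bit order as nummap.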

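    Output (from the author's original Python 2 run; because xb is drawn at random, the samples and exact digits differ between runs):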

    train as iter is 277 
    the reality:0.0 , predict 0.0
    the reality:16.0 , predict 16.0000005749
    the reality:16.0 , predict 15.9999994995
    the reality:3.0 , predict 3.00000059084
    the reality:18.0 , predict 17.999999818
    the reality:15.0 , predict 15.0000000195
    the reality:20.0 , predict 19.9999998534
    the reality:22.0 , predict 22.0000009642
    the reality:10.0 , predict 9.99999911021
    the reality:22.0 , predict 21.9999996143
    the reality:23.0 , predict 22.9999990943
    the reality:17.0 , predict 17.0000000549
    the reality:25.0 , predict 24.9999994128
    the reality:18.0 , predict 18.0000008934
    the reality:20.0 , predict 19.9999998534
    the reality:15.0 , predict 15.0000000195
    the reality:27.0 , predict 26.999999038
    the reality:31.0 , predict 30.9999993919
    the reality:25.0 , predict 25.0000003525
    the reality:21.0 , predict 20.9999999986
    the reality:35.0 , predict 34.9999997457
    the reality:29.0 , predict 28.9999993564
    the reality:39.0 , predict 38.9999996894
    the reality:26.0 , predict 26.0000009079
    the reality:31.0 , predict 30.9999993919
    the reality:25.0 , predict 24.9999990026
    the reality:33.0 , predict 32.9999994273
    the reality:32.0 , predict 31.9999999473
    the reality:32.0 , predict 31.9999991549
    the reality:34.0 , predict 34.0000002657
    the reality:33.0 , predict 32.9999994273
    accuary is 1.0
    predict is 36.9999984312
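    The result for [24, 13] is close to 37 because the learned weights approach the powers of two. The sketch below repeats the same delta-rule update as fit() in a quiet, reproducible form and prints the learned weights so this can be checked directly. It assumes a fixed seed (0) and draws b uniformly from 0..30 instead of using the time-based bound; those choices, and all names in it, are hypothetical.

```python
import numpy as np

BIT_LEN = 5        # same bit width as the example above
LEARN_RATE = 0.01  # same learning rate as the example above
MAX_ITER = 2000    # same epoch cap as the example above
TOL = 1e-6         # same "close enough" threshold as fit()

rng = np.random.default_rng(0)  # assumption: fixed seed so runs are reproducible

# Bit table, least significant bit first (same layout as initMap).
table = (np.arange(1 << BIT_LEN)[:, None] >> np.arange(BIT_LEN)) & 1

xa = np.arange(31)
xb = rng.integers(0, 31, size=31)  # assumption: b drawn uniformly from 0..30
X = table[xa] + table[xb]          # feature vector bits(a) + bits(b) per sample
y = (xa + xb).astype(float)        # targets a + b

w = np.ones(BIT_LEN) / BIT_LEN     # same start as initWeight()
converged = False
for epoch in range(MAX_ITER):
    updates = 0
    for x_i, y_i in zip(X, y):
        err = y_i - w @ x_i
        if abs(err) > TOL:
            w += LEARN_RATE * err * x_i  # delta-rule step, as in fit()
            updates += 1
    if updates == 0:               # every sample already within TOL
        converged = True
        break

print(epoch, converged)
print(np.round(w, 4))               # expected to approach 1, 2, 4, 8, 16
print(w @ (table[24] + table[13]))  # expected to be close to 37
```

    Since the target is exactly linear in the bit features, the update has a zero-error solution to converge to, so the weights settle on the binary place values and the model generalizes to pairs such as (24, 13) that need not appear in the training set.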
  • Original post: https://www.cnblogs.com/gongxijun/p/6653490.html