zoukankan      html  css  js  c++  java
  • Perceptron实践

    一、题目:使用自己写的感知机Perceptron实现对鸢尾花数据的分类。

      数据来源:

    from sklearn.datasets import load_iris
    dataSet = load_iris()

    二、注意事项:

      1)该数据中一共有三类,而感知机是个二分类器,因此,可以将该数据两两分成三次。

      2)第二类数据与第三类数据很相近,所以可以对数据进行升维处理,再做平方处理。

        例如:一种升维方式: 若x=(x1,x2,...xn) ,x_new=(x1*x1,x1*x2,...x1*xn,x2*x1,x2*x2,...x2*xn,....xn*xn)

            平方处理:  若x=(x1,x2,...xn),x_new=(x1*x1,x2*x2,x3*x3,...,xn*xn)

    三、代码实现

     1 from sklearn.datasets import load_iris
     2 import numpy as np
     3 dataSet = load_iris()
     4 data=dataSet['data'] #数据预处理
     5 labelNum=dataSet['target']
     6 label=dataSet['target_names']
     7 x=[];
     8 for i in range( 0,len(data) ):  #数据升维
     9     y=[]
    10     for j in range(0,len(data[i])):
    11         for k in range(0,len(data[i])):
    12             y.append(data[i][j]*data[i][k])
    13     x.append(y)
    14 data=np.array(x);
    15 data=data**2           
    16 def perceptron(data,label,eta=0.05,times=1000):
    17     b=0;
    18     w=np.zeros(len(data[0]))
    19     while(times>0):
    20         for i in range(0,len(label)):
    21             if( label[i]*(sum(data[i]*w)+b ) <= 0 ):
    22                 w=w+eta*data[i]*label[i]
    23                 b=b+eta*label[i];
    24         times-=1;
    25     return w,b
    26 def multi_classifier(data,labelNum,label):   
    27     w={}            
    28     b={} 
    29     #多个感知机
    30     for i in range(0,len(label)):
    31         labelDuplicate=labelNum.copy();
    32         #标签分类
    33         for j in range(0,len(labelDuplicate)):
    34             if(labelDuplicate[j]!=i):
    35                 labelDuplicate[j]=-1;
    36             else:
    37                 labelDuplicate[j]=1;
    38         if(i==1):
    39             w[ label[i] ],b[ label[i] ]=perceptron(data,labelDuplicate,times=1000);
    40         else:
    41             w[ label[i] ],b[ label[i] ]=perceptron(data,labelDuplicate,times=100);
    42     return w,b
    43 def classify(data,w,b,label,labelNum):
    44     sumLabel={};    #每类总数
    45     corLabel={};  #正确数量
    46     for j in label: #初始化
    47         sumLabel[j]=0
    48         corLabel[j]=0
    49         
    50     for i in range(0,len(data)):
    51         print('',i+1,'个样本')
    52         for j in label:
    53             if( sum(w[j]*data[i])+b[j]>=0  ):
    54                 if(j==label[labelNum[i]]):
    55                     corLabel[j]+=1;
    56                 sumLabel[  label[ labelNum[i]]  ]+=1
    57                 print(j)
    58                 break;
    59     for j in label:
    60         print(' {0} 类正确率:{1}'.format(j, corLabel[j]/ sumLabel[j]));
    61 w,b=multi_classifier(data,labelNum,label);
    62 classify(data,w,b,label,labelNum);

     四、运行结果

  • 相关阅读:
    Mybatis学习--spring和Mybatis整合
    MyBatis学习--查询缓存
    MyBatis学习--延迟加载
    MyBatis学习--高级映射
    Mybatis学习--Mapper.xml映射文件
    java文件上传和下载
    【计算机视觉】Object Proposal之BING理解
    【计算机视觉】Object Proposal之BING++
    【计算机视觉】Object Proposal之BING++
    【计算机视觉】Objectness算法(一)---总体理解,整理及总结
  • 原文地址:https://www.cnblogs.com/z-bear/p/9765458.html
Copyright © 2011-2022 走看看