zoukankan      html  css  js  c++  java
  • 单层感知器

    单层感知器是神经网络的入门常识,基本的单层感知器可以解决线性分类问题。这里我们通过实例体验感知器是如何运作的。本次实例参照教材《MATLAB神经网络原理与实例精解》。

    单层感知器的基本结构

    单层感知器基本结构

    如图,单层感知器可以有多个输入,它们通过与权值相乘,再相加(即加权求和)后,经过一定的偏置,再由激活函数处理,最后输出得到预测结果。这里面存在两种变化:线性变化与非线性变化。其中,加权求和属于线性变化,激活函数做的是非线性变化。通过上述两种变化 ,可以把输入的数据空间扭曲,使得只需要一个超平面就可以将其分开(线性可分),从而达到分类的目的。

    单层感知器的工作原理

    单层感知器工作原理

    与其他的优化算法一样,感知器做的工作就是不断的调整权值,使得输入的数据空间扭曲到适当的程度,然后再利用超平面一刀切开,达到二分类的效果。所有的算法都会有一个迭代终止指标,对于单层感知器来说,当输出的预测值与期望值之间的误差达到一定的精度要求,或者迭代次数超过一定的次数时(计算机也不可以无限的运行下去),算法结束。

    单层感知器解决坐标的二分类问题

    我们给出6个点的坐标,并给每个点的坐标设置分类,标签为0(第一类)和1(第二类)。利用单层感知器,找到一个超平面(就是一根直线)将两类坐标分开(即两类坐标分别处在直线的两边)。

    六坐标二分类问题

    代码实现

    Python代码

    import numpy as np
    import matplotlib.pyplot as plt
    
    # 参数初始化
    n = 0.2  # 学习率
    w = np.array([0, 0, 0])  # 权值
    p = np.array([[-9,  1, -12, -4,  0, 5],
                  [15, -8,   4,  5, 11, 9]])   # 坐标
    d = np.array([0, 1, 0, 0, 0, 1])  # 坐标分类标签
    P = np.vstack((np.ones((1, 6)), p))  # 输入矩阵
    MAX = 20    # 最大迭代次数
    ee = []    # 误差
    i = 0  # 记录迭代次数
    
    
    # 定义激活函数
    def hardlim(a):
        for i in range(len(a)):
            if a[i] >= 0:
                a[i] = 1
            else:
                a[i] = 0
        return a
    
    
    # 定义平均绝对误差
    def mae(a):
        return sum(abs(a))/len(a)
    
    
    while 1:
        v = np.matmul(w, P)
        y = hardlim(v)  # 实际输出
        # 更新
        e = (d - y)
        ee.append(mae(e))
        if ee[i] < 0.001:
            print('we have got it:')
            print(w)
            break
        w = w + n*np.matmul(d-y,P.T)
        i = i + 1
        if i >= MAX:
            print('MAX times loop')
            print(w)
            print(ee[i])
            break
    
    # 画图
    plt.figure()
    plt.rcParams['font.sans-serif'] = ['Simhei']
    plt.rcParams['axes.unicode_minus'] = False
    plt.subplot(211)
    plt.xlim(-13, 6)
    plt.ylim(-10, 16)
    plt.plot([-9, -12, -4, 0], [15, 4, 5, 11], 'o', label='第一类')
    plt.plot([1, 5], [-8, 9], '*', label='第二类')
    plt.legend(loc='lower right')
    plt.title('6个坐标点的二分类')
    x = np.arange(-13, 6, 0.2)
    y = x * (-w[1]/w[2]) - w[0]/w[2]
    plt.plot(x, y)
    
    plt.subplot(212)
    x = np.arange(0,len(ee))
    plt.plot(x, ee, 'o-')
    plt.title('mae的值(迭代次数:%.0f)'%len(ee))
    plt.subplots_adjust(wspace =0, hspace =0.5)
    plt.show()
    

    输出画面

    Python输出画面

    Matlab代码

    % perception_hand.m
    %% 清理
    clear,clc
    close all
    
    %%
    n=0.2;                  % 学习率
    w=[0,0,0]; 
    P=[ -9,  1, -12, -4,   0, 5;...
       15,  -8,   4,  5,  11, 9];
    d=[0,1,0,0,0,1];        % 期望输出
    
    P=[ones(1,6);P];
    MAX=20;                 % 最大迭代次数为20次
    %% 训练
    i=0;
    while 1
        v=w*P; 
        y=hardlim(v);       % 实际输出
        %更新
        e=(d-y);
        ee(i+1)=mae(e);
        if (ee(i+1)<0.001)   % 判断
            disp('we have got it:');
            disp(w);
            break;
        end
        % 更新权值和偏置
        w=w+n*(d-y)*P';
        
        if (i>=MAX)         % 达到最大迭代次数,退出
            disp('MAX times loop');
            disp(w);
            disp(ee(i+1));
           break; 
        end
        i= i+1;
    end
    
    
    %% 显示
    figure;
    subplot(2,1,1);         % 显示待分类的点和分类结果
    plot([-9 ,  -12  -4    0],[15, 4   5   11],'o');
    hold on;
    plot([1,5],[-8,9],'*');
    axis([-13,6,-10,16]);
    legend('第一类','第二类');
    title('6个坐标点的二分类');
    x=-13:.2:6;
    y=x*(-w(2)/w(3))-w(1)/w(3);
    plot(x,y);
    hold off;
    
    subplot(2,1,2);         % 显示mae值的变化
    x=0:i;
    plot(x,ee,'o-');
    s=sprintf('mae的值(迭代次数:%d)', i+1);
    title(s);
    

    输出画面

    Matlab输出画面

  • 相关阅读:
    UVa 1349 (二分图最小权完美匹配) Optimal Bus Route Design
    UVa 1658 (拆点法 最小费用流) Admiral
    UVa 11082 (网络流建模) Matrix Decompressing
    UVa 753 (二分图最大匹配) A Plug for UNIX
    UVa 1451 (数形结合 单调栈) Average
    UVa 1471 (LIS变形) Defense Lines
    UVa 11572 (滑动窗口) Unique Snowflakes
    UVa 1606 (极角排序) Amphiphilic Carbon Molecules
    UVa 11054 Wine trading in Gergovia
    UVa 140 (枚举排列) Bandwidth
  • 原文地址:https://www.cnblogs.com/gshang/p/10959749.html
Copyright © 2011-2022 走看看