zoukankan      html  css  js  c++  java
  • 神经网络--异或问题

    前言:

    这个博客是为了解决异或问题,原理是利用非线性的量来进行划分,和前面的知识有些类似。

    正文:

    import numpy as np
    import matplotlib.pyplot as plt
    
    #输入数据
    X = np.array([[1,0,0,0,0,0],
                  [1,0,1,0,0,1],
                  [1,1,0,1,0,0],
                  [1,1,1,1,1,1]])
    #标签
    Y = np.array([-1,1,1,-1])
    #权值初始化,1行6列,取值范围-1到1
    W = (np.random.random(6)-0.5)*2
    
    print(W)
    #学习率设置
    lr = 0.11
    #计算迭代次数
    n = 0
    #神经网络输出
    O = 0
    #除以x.shape()防止每次更新权值过大
    #给x转置是为了符合矩阵乘法的规范
    def update():
        global X,Y,W,lr,n
        n+=1
        O = np.dot(X,W.T)
        W_C = lr*((Y-O.T).dot(X))/int(X.shape[0])
        W = W + W_C
    
    for _ in range(1000):
        update()#更新权值
    
    #正样本
    x1 = [0,1]
    y1 = [1,0]
    #负样本
    x2 = [0,1]
    y2 = [0,1]
    
    def calculate(x,root):
        a = W[5]
        b = W[2]+x*W[4]
        c = W[0]+x*W[1]+x*x*W[3]
        if root == 1:
            return (-b+np.sqrt(b*b-4*a*c))/(2*a)
        if root == 2:
            return (-b-np.sqrt(b*b-4*a*c))/(2*a)
    
    xdata = np.linspace(-1,2)
    plt.figure()
    plt.plot(xdata,calculate(xdata,1),'r')
    plt.plot(xdata,calculate(xdata,2),'r')
    plt.scatter(x1,y1,c='b')
    plt.scatter(x2,y2,c='y')
    plt.show()
    

    在这里插入图片描述

    总结:

    这个专门用来解决异或问题,和单层感知器的知识有所不同的是用了不同的激活函数,以及用n来计数,引入了6个输入量,相当于在求解一个二次方程(关于y的二次方程),再利用求根公式来进行画线。

  • 相关阅读:
    给JavaScript新手的24条实用建议
    javascript之HTML(select option)详解
    PHP的正则处理函数总结分析
    多级关联菜单:
    理解json两种结构:数组和对象
    dede标签学习笔记(一)
    Jewel_M PHP定时执行任务的实现
    网站刷新器
    PHP_SELF、 SCRIPT_NAME、 REQUEST_URI区别
    RemoveXSS()
  • 原文地址:https://www.cnblogs.com/lqk0216/p/12854687.html
Copyright © 2011-2022 走看看