zoukankan      html  css  js  c++  java
  • 神经网络学习之----线性神经网络,delta学习规则,递归下降法(代码实现)

    线性神经网络解决异或问题:

    import numpy as np
    import matplotlib.pyplot as plt
    
    
    # In[7]:
    
    #输入数据
    X = np.array([[1,0,0,0,0,0],
                  [1,0,1,0,0,1],
                  [1,1,0,1,0,0],
                  [1,1,1,1,1,1]])
    #标签
    Y = np.array([-1,1,1,-1])
    #权值初始化,1行3列,取值范围-1到1
    W = (np.random.random(6)-0.5)*2
    print(W)
    #学习率设置
    lr = 0.11
    #计算迭代次数
    n = 0
    #神经网络输出
    O = 0
    
    def update():
        global X,Y,W,lr,n
        n+=1
        O = np.dot(X,W.T)
        W_C = lr*((Y-O.T).dot(X))/int(X.shape[0])
        W = W + W_C
    
    
    # In[10]:
    
    for _ in range(100000):
        update()#更新权值
        #-0.1,0.1,0.2,-0.2
        #-1,1,1,-1
    
    
    #正样本
    x1 = [0,1]
    y1 = [1,0]
    #负样本
    x2 = [0,1]
    y2 = [0,1]
    
    def calculate(x,root):
        a = W[5]
        b = W[2]+x*W[4]
        c = W[0]+x*W[1]+x*x*W[3]
        if root==1:
            return (-b+np.sqrt(b*b-4*a*c))/(2*a)
        if root==2:
            return (-b-np.sqrt(b*b-4*a*c))/(2*a)
        
    
    xdata = np.linspace(-1,2)
    
    plt.figure()
    
    plt.plot(xdata,calculate(xdata,1),'r')
    plt.plot(xdata,calculate(xdata,2),'r')
    
    plt.plot(x1,y1,'bo')
    plt.plot(x2,y2,'yo')
    plt.show()
    
    print(W)
    
    
    # In[15]:
    
    O = np.dot(X,W.T)
    print(O)
  • 相关阅读:
    Anniversary party(树形DP入门)
    Neither shaken nor stirred(DFS理解+vector存图)
    统计单词数
    洛谷---三连击
    Educational Codeforces Round 68 (Rated for Div. 2)---B
    HDU-1201--18岁生日
    HDU-盐水的故事
    Flower(规律+逆向思维)
    The puzzle
    XOR Clique(按位异或)
  • 原文地址:https://www.cnblogs.com/mengqimoli/p/10348989.html
Copyright © 2011-2022 走看看