zoukankan      html  css  js  c++  java
  • 人工智能实战第五次作业_李大

    项目 内容
    课程 人工智能实战2019
    作业要求 人工智能实战第五次作业
    我的课程目标 第一次作业 介绍自己,提出课程项目建议
    此作业帮助 熟悉反向传播,sigmoid激活函数
    我的Github主页 LeeDua

    代码

    import numpy as np
    import matplotlib.pyplot as plt
    
    def Sigmiod(x):
        A = 1/(1+np.exp(-x))
        return A
    
    def ForwardCalculation(w,b,x):
        z = np.dot(w,x) + b
        A = Sigmiod(z)
        return z, A
    
    def BackPropagation(x,y,A,m):
        dZ = A - y
        dB = dZ.sum(axis = 1,keepdims = True)/m
        dW = np.dot(dZ, x.T)/m
        return dW, dB
    
    def UpdateWeights(w, b, dW, dB, eta):
        w = w - eta*dW
        b = b - eta*dB
        return w,b
    
    def CheckLoss(A, Y, m):
        Loss = np.sum(-(np.multiply(Y, np.log(A)) + np.multiply((1 - Y), np.log(1 - A))))/m
        return Loss
    
    def ShowResult(X, Y, w, b, m):
        for i in range(m):
            if Y[i] == 0:
                plt.plot(X[0,i], X[1,i], '^', c='r')
            else:
                plt.plot(X[0,i], X[1,i], 'x', c='g')
        x = np.linspace(-0.1,1.1,100)
        y = - (w[0,0] / w[0,1]) * x - (b[0,0] / w[0,1])
        plt.plot(x,y)
        plt.axis([-0.1,1.1,-0.1,1.1])
        plt.show()
    
    
    def GetSample(GateType):
        X_And = np.array([0, 0, 1, 1, 0, 1, 0, 1]).reshape(2, 4)
        y_And = np.array([0, 0, 0, 1])
    
        X_Or = np.array([0,0,1,1,0,1,0,1]).reshape(2,4)
        y_Or = np.array([0,1,1,1])
        if GateType == "or":
            return X_Or,y_Or
        elif GateType == "and":
            return X_And,y_And
    
    if __name__ == '__main__':
        eta = 0.1
        eps = 1e-4
        max_epoch = 10000
        loss = 1
        X,Y = GetSample("and")
    
        num_features = X.shape[0]
        num_example = X.shape[1]
        w = np.zeros((1,num_features))
        b = np.zeros((1,1))
        for epoch in range(max_epoch):
            Z, A = ForwardCalculation(w,b,X)
            dW, dB = BackPropagation(X,Y,A,num_example)
            w, b = UpdateWeights(w, b, dW, dB, eta)
            loss = CheckLoss(A, Y, num_example)
            print(epoch,loss)
            if loss < eps:
                break
    
        ShowResult(X, Y, w, b, num_example)
    

    结果

    • 与门 iter:9999 Loss:0.017445056023741874

    • 或门 iter:9999 Loss:0.009305581322240441

    • 可以看到10000次迭代后都能很好地达到预期分类效果

  • 相关阅读:
    Python阶段复习
    Python阶段复习
    Python学习笔记
    Python爬虫学习
    Python爬虫学习
    Python学习笔记
    史上最全的Maven Pom文件标签详解
    css3 animation动画技巧
    常用的sass编译库
    compass做雪碧图
  • 原文地址:https://www.cnblogs.com/lixiaoda/p/10669658.html
Copyright © 2011-2022 走看看