zoukankan      html  css  js  c++  java
  • Numpy梯度下降反向传播代码实现

    代码

    # -*- coding: utf-8 -*-
    import numpy as np
    
    # N是批量大小; D_in是输入维度;
    # 49/5000 H是隐藏的维度; D_out是输出维度。
    N, D_in, H, D_out = 64, 1000, 100, 10
    
    # 创建随机输入和输出数据
    x = np.random.randn(N, D_in)
    y = np.random.randn(N, D_out)
    
    # 随机初始化权重
    w1 = np.random.randn(D_in, H)
    w2 = np.random.randn(H, D_out)
    
    learning_rate = 1e-6
    for t in range(500):
        # 前向传递:计算预测值y
        h = x.dot(w1)
        h_relu = np.maximum(h, 0)
        y_pred = h_relu.dot(w2)
    
        # 计算和打印损失loss
        loss = np.square(y_pred - y).sum()
        print(t, loss)
    
        # 反向传播,计算w1和w2对loss的梯度
        grad_y_pred = 2.0 * (y_pred - y)
        grad_w2 = h_relu.T.dot(grad_y_pred)
        grad_h_relu = grad_y_pred.dot(w2.T)
        grad_h = grad_h_relu.copy()
        grad_h[h < 0] = 0
        grad_w1 = x.T.dot(grad_h)
    
        # 更新权重
        w1 -= learning_rate * grad_w1
        w2 -= learning_rate * grad_w2
    

    这段代码是我随便找的,包含一个隐藏层,很简单,就以这个作为举例。

    反向传播

          先看下正向传播:

    $$h = xw^{1}$$
    $$h\_relu = ReLU(h)$$
    $$y\_pred=h\_relu · w^{2}$$
    $$loss=(y\_pred-y)^2$$
          当我们反向传播时,需要从Output Layer层开始,利用链式求导法则,一步一步求导计算。
          E.g. 计算loss对$w^2$的偏导过程如下:
    $$frac{partial loss}{partial w^2} = frac{partial loss}{partial y\_pred}frac{partial y\_pred}{partial w^2}=2(y\_pred-y)·h\_relu$$
           然而,虽然推导出来了,但是用代码实现时可能又会遇到困难,不知道谁在前谁在后,而且往往还需要转置。最好的解决办法其实就是看维度,需要记住的是,向量对标量求导的结果的维度和向量的维度是一致的。
           故在上式中,$frac{partial loss}{partial w^2}$的维度是$(100,10)$,$frac{partial loss}{partial y\_pred}$的维度是$(64,10)$,$frac{partial y\_pred}{partial w^2}$的维度是$(64,100)$。这两者相乘后的维度得为$(100, 10)$,那就只有将后者转置后相乘,即$(64,100)^T·(64,10)$。写成代码就正好是:
     grad_w2 = h_relu.T.dot(grad_y_pred)
    

          其余的推导皆是如此。可以看到手动实现反向传播是十分麻烦的,层数一多根本不可能自己一个一个去算,所以后面需要用到自动求导。

     
     
     
    参考:
  • 相关阅读:
    exiting pxe rom 无法启动
    nginx 动静分离
    tomcat apr 部署
    zabbix_agentd.conf配置文件详解
    Zabbix点滴记录
    zabbix监控haproxy
    Zabbix使用Omsa来监控Dell服务器的硬件状态
    Zabbix监控Zookeeper健康状况
    Redis 多数据库
    Zabbix实现自动发现端口并监控
  • 原文地址:https://www.cnblogs.com/zyb993963526/p/13741577.html
Copyright © 2011-2022 走看看