zoukankan      html  css  js  c++  java
  • pytorch查看网络权重参数更新、梯度的小实例

    本文内容来自知乎:浅谈 PyTorch 中的 tensor 及使用

    首先创建一个简单的网络,然后查看网络参数在反向传播中的更新,并查看相应的参数梯度。

    # 创建一个很简单的网络:两个卷积层,一个全连接层
    class Simple(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False)
            self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1, bias=False)
            self.linear = nn.Linear(32*10*10, 20, bias=False)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.linear(x.view(x.size(0), -1))
            return x
    
    model = Simple()
    # 为了方便观察数据变化,把所有网络参数都初始化为 0.1
    for m in model.parameters():
        m.data.fill_(0.1)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    
    model.train()
    # 模拟输入8个 sample,每个的大小是 10x10,
    # 值都初始化为1,让每次输出结果都固定,方便观察
    images = torch.ones(8, 3, 10, 10)
    targets = torch.ones(8, dtype=torch.long)
    
    output = model(images)
    print(output.shape)
    # torch.Size([8, 20])
    
    loss = criterion(output, targets)
    
    print(model.conv1.weight.grad)
    # None
    loss.backward()
    print(model.conv1.weight.grad[0][0][0])
    # tensor([-0.0782, -0.0842, -0.0782])
    # 通过一次反向传播,计算出网络参数的导数,
    # 因为篇幅原因,我们只观察一小部分结果
    
    print(model.conv1.weight[0][0][0])
    # tensor([0.1000, 0.1000, 0.1000], grad_fn=<SelectBackward>)
    # 我们知道网络参数的值一开始都初始化为 0.1 的
    
    optimizer.step()
    print(model.conv1.weight[0][0][0])
    # tensor([0.1782, 0.1842, 0.1782], grad_fn=<SelectBackward>)
    # 回想刚才我们设置 learning rate 为 1,这样,
    # 更新后的结果,正好是 (原始权重 - 求导结果) !
    
    optimizer.zero_grad()
    print(model.conv1.weight.grad[0][0][0])
    # tensor([0., 0., 0.])
    # 每次更新完权重之后,我们记得要把导数清零啊,
    # 不然下次会得到一个和上次计算一起累加的结果。
    # 当然,zero_grad() 的位置,可以放到前边去,
    # 只要保证在计算导数前,参数的导数是清零的就好。
    
  • 相关阅读:
    JVM 综述
    看 Netty 在 Dubbo 中如何应用
    Netty 心跳服务之 IdleStateHandler 源码分析
    Netty 高性能之道
    Netty 解码器抽象父类 ByteToMessageDecoder 源码解析
    Netty 源码剖析之 unSafe.write 方法
    Netty 出站缓冲区 ChannelOutboundBuffer 源码解析(isWritable 属性的重要性)
    Netty 源码剖析之 unSafe.read 方法
    Netty 内存回收之 noCleaner 策略
    Netty 源码阅读的思考------耗时业务到底该如何处理
  • 原文地址:https://www.cnblogs.com/picassooo/p/14153787.html
Copyright © 2011-2022 走看看