zoukankan      html  css  js  c++  java
  • pytorch中多个loss回传的参数影响示例

    写了一段代码如下:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class Test(nn.Module):
        def __init__(self):
            super(Test, self).__init__()
            self.fc1 = nn.Linear(5, 4)
            self.fc2 = nn.Linear(4, 3)
            self.fc3 = nn.Linear(4, 3)
    
        def forward(self, x):
            mid = self.fc1(x)
            out1 = self.fc2(mid)
            out2 = self.fc3(mid)
            return out1, out2
    
    
    x = torch.randn((3, 5))
    y = torch.torch.randint(3, (3,), dtype=torch.int64)
    model = Test()
    model.train()
    optim = torch.optim.RMSprop(model.parameters(), lr=0.001)
    
    print(model.fc2.weight)
    print(model.fc3.weight)
    for i in range(5):
        out1, out2 = model(x)
        loss1 = F.cross_entropy(out1, y)
        loss2 = F.cross_entropy(out2, y)
        loss = loss1 + loss2
        optim.zero_grad()
        loss.backward()
        optim.step()
    print("-------------after-----------")
    print(model.fc2.weight)
    print(model.fc3.weight)

    在loss.backward()处分别更换为loss1.backward()和loss2.backward(),观察fc2和fc3层的参数变化。

    得出的结论为:loss2只影响fc3的参数,loss1只影响fc2的参数。

    (粗略分析,抛砖引玉)

  • 相关阅读:
    浅尝《Windows核心编程》之 等待函数
    linux 下 解压rar的过程
    一些多线程编程的例子(转)
    js数组操作《转》
    缩略图片处理<收藏>
    .net 框架
    详解NeatUpload上传控件的使用
    NHibernate工具
    xml xpath语法《转》
    C#事务技术
  • 原文地址:https://www.cnblogs.com/peony-jing/p/14462289.html
Copyright © 2011-2022 走看看