zoukankan      html  css  js  c++  java
  • pytorch中多个loss回传的参数影响示例

    写了一段代码如下:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class Test(nn.Module):
        def __init__(self):
            super(Test, self).__init__()
            self.fc1 = nn.Linear(5, 4)
            self.fc2 = nn.Linear(4, 3)
            self.fc3 = nn.Linear(4, 3)
    
        def forward(self, x):
            mid = self.fc1(x)
            out1 = self.fc2(mid)
            out2 = self.fc3(mid)
            return out1, out2
    
    
    x = torch.randn((3, 5))
    y = torch.torch.randint(3, (3,), dtype=torch.int64)
    model = Test()
    model.train()
    optim = torch.optim.RMSprop(model.parameters(), lr=0.001)
    
    print(model.fc2.weight)
    print(model.fc3.weight)
    for i in range(5):
        out1, out2 = model(x)
        loss1 = F.cross_entropy(out1, y)
        loss2 = F.cross_entropy(out2, y)
        loss = loss1 + loss2
        optim.zero_grad()
        loss.backward()
        optim.step()
    print("-------------after-----------")
    print(model.fc2.weight)
    print(model.fc3.weight)

    在loss.backward()处分别更换为loss1.backward()和loss2.backward(),观察fc2和fc3层的参数变化。

    得出的结论为:loss2只影响fc3的参数,loss1只影响fc2的参数。

    (粗略分析,抛砖引玉)

  • 相关阅读:
    权限系统设计-day02
    权限系统设计-day01
    SSH集成(Struts+Spring+Hibernate)
    Spring-day03
    Spring-day02
    神经网络与深度学习
    深度学习概述
    结构化机器学习项目
    无监督学习方法
    《2018自然语言处理研究报告》整理(附报告)
  • 原文地址:https://www.cnblogs.com/peony-jing/p/14462289.html
Copyright © 2011-2022 走看看