zoukankan      html  css  js  c++  java
  • 《深度学习框架PyTorch:入门与实践》的Loss函数构建代码运行问题

    在学习陈云的教程《深度学习框架PyTorch:入门与实践》的损失函数构建时代码如下:

    可我运行如下代码:

    output = net(input)
    target = Variable(t.arange(0,10))  
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    loss
    

    运行结果:

    RuntimeError                              Traceback (most recent call last)
    <ipython-input-37-e5c73861a53b> in <module>()
          2 target = Variable(t.arange(0,10))
          3 criterion = nn.MSELoss()
    ----> 4 loss = criterion(output, target)
          5 loss
    
    RuntimeError: Expected object of type torch.FloatTensor but found type torch.LongTensor for argument #2 'target'
    

    根据stackoverflo的问题Pytorch: Convert FloatTensor into DoubleTensorPyTorch(总)——PyTorch遇到令人迷人的BUG与记录,用torch.from_numpy(Y).float()这样的形式修改下target的类型。

    #torch.from_numpy(Y).float()
    output = net(input)
    y = np.arange(0,10).reshape(1,10)
    target = Variable(t.from_numpy(y).float())  
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    loss
    

    运行结果:

    tensor(28.5897, grad_fn=<MseLossBackward>)
    

    同样的,后面优化器Optim代码中target也是出现这样的错误:

    import torch.optim as optim
    #新建一个优化器,指定要调整的参数和学习率
    optimizer = optim.SGD(net.parameters(),lr=0.01)
    
    #在训练过程中
    #先梯度清零(与net.zero_grad()效果一样)
    optimizer.zero_grad()
    
    #计算损失
    output = net(input)
    #把target改为Variable(t.from_numpy(y).float())就不会出错了
    loss = criterion(output, target)
    
    #反向传播
    loss.backward()
    
    #更新参数
    optimizer.step()
    

    运行结果:

    修改targetVariable(t.from_numpy(y).float())后成功运行:

    import torch.optim as optim
    #新建一个优化器,指定要调整的参数和学习率
    optimizer = optim.SGD(net.parameters(),lr=0.01)
    
    #在训练过程中
    #先梯度清零(与net.zero_grad()效果一样)
    optimizer.zero_grad()
    
    #计算损失
    output = net(input)
    #把target改为Variable(t.from_numpy(y).float())就不会出错了
    loss = criterion(output, Variable(t.from_numpy(y).float()))
    
    #反向传播
    loss.backward()
    
    #更新参数
    optimizer.step()
    
  • 相关阅读:
    Linux Mint---shutter截图软件
    Linux Mint---fcitx中文,日语输入法
    Linux Mint---安装docky
    Linux Mint---开启桌面三维特效
    Linux Mint---ATI显卡驱动安装篇
    Linux Mint---更新软件源
    Spring Cloud 微服务服务间调用session共享问题
    Jooq批量插入 batch
    idea安装SonarLint语法检测插件
    JVM到底是什么?
  • 原文地址:https://www.cnblogs.com/HongjianChen/p/9445011.html
Copyright © 2011-2022 走看看