zoukankan      html  css  js  c++  java
  • 《深度学习框架PyTorch:入门与实践》的Loss函数构建代码运行问题

    在学习陈云的教程《深度学习框架PyTorch:入门与实践》的损失函数构建时代码如下:

    可我运行如下代码:

    output = net(input)
    target = Variable(t.arange(0,10))  
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    loss
    

    运行结果:

    RuntimeError                              Traceback (most recent call last)
    <ipython-input-37-e5c73861a53b> in <module>()
          2 target = Variable(t.arange(0,10))
          3 criterion = nn.MSELoss()
    ----> 4 loss = criterion(output, target)
          5 loss
    
    RuntimeError: Expected object of type torch.FloatTensor but found type torch.LongTensor for argument #2 'target'
    

    根据stackoverflo的问题Pytorch: Convert FloatTensor into DoubleTensorPyTorch(总)——PyTorch遇到令人迷人的BUG与记录,用torch.from_numpy(Y).float()这样的形式修改下target的类型。

    #torch.from_numpy(Y).float()
    output = net(input)
    y = np.arange(0,10).reshape(1,10)
    target = Variable(t.from_numpy(y).float())  
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    loss
    

    运行结果:

    tensor(28.5897, grad_fn=<MseLossBackward>)
    

    同样的,后面优化器Optim代码中target也是出现这样的错误:

    import torch.optim as optim
    #新建一个优化器,指定要调整的参数和学习率
    optimizer = optim.SGD(net.parameters(),lr=0.01)
    
    #在训练过程中
    #先梯度清零(与net.zero_grad()效果一样)
    optimizer.zero_grad()
    
    #计算损失
    output = net(input)
    #把target改为Variable(t.from_numpy(y).float())就不会出错了
    loss = criterion(output, target)
    
    #反向传播
    loss.backward()
    
    #更新参数
    optimizer.step()
    

    运行结果:

    修改targetVariable(t.from_numpy(y).float())后成功运行:

    import torch.optim as optim
    #新建一个优化器,指定要调整的参数和学习率
    optimizer = optim.SGD(net.parameters(),lr=0.01)
    
    #在训练过程中
    #先梯度清零(与net.zero_grad()效果一样)
    optimizer.zero_grad()
    
    #计算损失
    output = net(input)
    #把target改为Variable(t.from_numpy(y).float())就不会出错了
    loss = criterion(output, Variable(t.from_numpy(y).float()))
    
    #反向传播
    loss.backward()
    
    #更新参数
    optimizer.step()
    
  • 相关阅读:
    session监听
    Ubuntu上安装MongoDB(转)
    JAVA中的集合(转)
    Iterator的用法(转)
    PHPExcel常用方法汇总(转)
    MongoDB的安装及在PHP中的配置Windows版
    [转载]使用FastReport 3.0及以上版本创建动态报表的几个技巧
    FastReport
    ZeosLib
    [转载]FastReport问题整理
  • 原文地址:https://www.cnblogs.com/HongjianChen/p/9445011.html
Copyright © 2011-2022 走看看