zoukankan      html  css  js  c++  java
  • 《深度学习框架PyTorch:入门与实践》的Loss函数构建代码运行问题

    在学习陈云的教程《深度学习框架PyTorch:入门与实践》的损失函数构建时代码如下:

    可我运行如下代码:

    output = net(input)
    target = Variable(t.arange(0,10))  
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    loss
    

    运行结果:

    RuntimeError                              Traceback (most recent call last)
    <ipython-input-37-e5c73861a53b> in <module>()
          2 target = Variable(t.arange(0,10))
          3 criterion = nn.MSELoss()
    ----> 4 loss = criterion(output, target)
          5 loss
    
    RuntimeError: Expected object of type torch.FloatTensor but found type torch.LongTensor for argument #2 'target'
    

    根据stackoverflo的问题Pytorch: Convert FloatTensor into DoubleTensorPyTorch(总)——PyTorch遇到令人迷人的BUG与记录,用torch.from_numpy(Y).float()这样的形式修改下target的类型。

    #torch.from_numpy(Y).float()
    output = net(input)
    y = np.arange(0,10).reshape(1,10)
    target = Variable(t.from_numpy(y).float())  
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    loss
    

    运行结果:

    tensor(28.5897, grad_fn=<MseLossBackward>)
    

    同样的,后面优化器Optim代码中target也是出现这样的错误:

    import torch.optim as optim
    #新建一个优化器,指定要调整的参数和学习率
    optimizer = optim.SGD(net.parameters(),lr=0.01)
    
    #在训练过程中
    #先梯度清零(与net.zero_grad()效果一样)
    optimizer.zero_grad()
    
    #计算损失
    output = net(input)
    #把target改为Variable(t.from_numpy(y).float())就不会出错了
    loss = criterion(output, target)
    
    #反向传播
    loss.backward()
    
    #更新参数
    optimizer.step()
    

    运行结果:

    修改targetVariable(t.from_numpy(y).float())后成功运行:

    import torch.optim as optim
    #新建一个优化器,指定要调整的参数和学习率
    optimizer = optim.SGD(net.parameters(),lr=0.01)
    
    #在训练过程中
    #先梯度清零(与net.zero_grad()效果一样)
    optimizer.zero_grad()
    
    #计算损失
    output = net(input)
    #把target改为Variable(t.from_numpy(y).float())就不会出错了
    loss = criterion(output, Variable(t.from_numpy(y).float()))
    
    #反向传播
    loss.backward()
    
    #更新参数
    optimizer.step()
    
  • 相关阅读:
    vue打包不显示或图片不显示配置
    Vue::is特性用法
    毕业实习报告
    前端vscode常用快捷键总结
    1. 变量常量
    信息收集之CMS指纹识别
    4. EIGRP的复合度量值
    3. EIGRP报文,三张表,邻居建立
    信息收集之目录扫描
    2. EIGRP路由单播邻居和被动接口
  • 原文地址:https://www.cnblogs.com/HongjianChen/p/9445011.html
Copyright © 2011-2022 走看看