zoukankan      html  css  js  c++  java
  • pytorch inception v3 KeyError: <class 'tuple'>解决方法

    刚入门深度学习,试着跑了一下resnet18和resnet50都没有问题,但是在运行inception v3的时候遇到一个问题怎么都解决不了

    错误输出如下:

    Traceback (most recent call last):
      File "transfer_learning_tutorial.py", line 277, in <module>
        num_epochs=25)
      File "transfer_learning_tutorial.py", line 179, in train_model
        _, preds = torch.max(outputs, 1)
    TypeError: max() received an invalid combination of arguments - got (tuple, int), but expected one of:
     * (Tensor input)
     * (Tensor input, Tensor other, Tensor out)
     * (Tensor input, int dim, bool keepdim, tuple of Tensors out)

    在github上找到答案

    inception源码连接第125行,在train模式下并且aux_logits打开的情况下,返回x, aux。

    所以解决方法有以下两个:

    方法1.创建inception模型的时候,关闭aux_logits。设置关键字参数aux_logits=False

    方法2.接收返回的aux。

    output, aux = model(input_var)#注意只有在训练模式下接收两个参数
    out=model(input_var)#在求值模式下,仍然只返回一个参数
                        if phase=='train':
                            outputs,aux= model(inputs)
                        else:
                            outputs=model(inputs)
     
  • 相关阅读:
    关于java.lang.reflect.InvocationTargetException
    Java并发编程(三)后台线程(Daemon Thread)
    Lab 7-2
    Lab 7-1
    Lab 6-3
    Lab 6-2
    Lab 6-1
    Lab 5-1
    Lab 3-4
    Lab 3-3
  • 原文地址:https://www.cnblogs.com/MalcolmMeng/p/9604029.html
Copyright © 2011-2022 走看看