刚入门深度学习,试着跑了一下resnet18和resnet50都没有问题,但是在运行inception v3的时候遇到一个问题怎么都解决不了
错误输出如下:
Traceback (most recent call last): File "transfer_learning_tutorial.py", line 277, in <module> num_epochs=25) File "transfer_learning_tutorial.py", line 179, in train_model _, preds = torch.max(outputs, 1) TypeError: max() received an invalid combination of arguments - got (tuple, int), but expected one of: * (Tensor input) * (Tensor input, Tensor other, Tensor out) * (Tensor input, int dim, bool keepdim, tuple of Tensors out)
在github上找到答案
inception源码连接第125行,在train模式下并且aux_logits打开的情况下,返回x, aux。
所以解决方法有以下两个:
方法1.创建inception模型的时候,关闭aux_logits。设置关键字参数aux_logits=False
方法2.接收返回的aux。
output, aux = model(input_var)#注意只有在训练模式下接收两个参数
out=model(input_var)#在求值模式下,仍然只返回一个参数
if phase=='train': outputs,aux= model(inputs) else: outputs=model(inputs)