zoukankan      html  css  js  c++  java
  • 【PyTorch】使用中注意事项

    参考博客:

    https://blog.csdn.net/u011276025/article/details/73826562/

    1. 把Label要转成LongTensor格式

    self.y = torch.LongTensor(y)

    完整使用代码如下:

     1 class ImgDataset(Dataset):
     2     def __init__(self, x, y=None, transform=None):
     3         self.x = x
     4         # label is required to be a LongTensor
     5         self.y = y
     6         if y is not None:
     7             self.y = torch.LongTensor(y)
     8         self.transform = transform
     9     def __len__(self):
    10         return len(self.x)
    11     def __getitem__(self, index):
    12         X = self.x[index]
    13         if self.transform is not None:
    14             X = self.transform(X)
    15         if self.y is not None:
    16             Y = self.y[index]
    17             return X, Y
    18         else:
    19             return X
    View Code

    需要保证target类型为torch.cuda.LongTensor,需要在数据读取的迭代其中把target的类型转换为int64位的:target = target.astype(np.int64),这样,输出的target类型为torch.cuda.LongTensor。(或者在使用前使用Tensor.type(torch.LongTensor)进行转换)。

    *LongTensor其实就是int64,有符号整型

    2. 做预测时,没有y值,从dataloader中传入给model的直接是data,而不再是data[0]了

    model_best.eval()
    prediction = []
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            #print(data[0].size())
            # 特别要注意的是,这里直接传入data,因为已经没有y值了,所以无需data[0]。
            # 如果传了data[0]反而导致没有传入整个batch,计算错误
            test_pred = model_best(data.cuda())
            test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)
            for y in test_label:
                prediction.append(y)

    3. 训练时,要设成model.train(),这样optimizer就可以更新model的参数

      验证时,要设成model.val(),以此来固定model的参数。例如 去掉dropout、bn参数不变 等。

    4. ModuleNotFoundError:No module named ”Classifier“

    训练完保存模型后,再load模型去做预测时,仍然需要原来训练时的Classifier,即整个网络结构。。

    有点匪夷所思呀。。那存模型和存参数比有啥区别呢?存了个寂寞?

    还需要查一下存模型和存参数的区别

    5.pytorch中nn.crossEntropyLoss 自带softmax,无需将输出经softmax层再计算交叉熵损失

     其源码实现时,将 input 经过 softmax 激活函数之后,再计算其与 target 的交叉熵损失

    未完待续。。。

  • 相关阅读:
    cad.net 仿lisp函数专篇
    操作篇 cad一个小技巧,通过块中块插入含有字段块,保证更新
    cad.net 外部参照功能和相对路径转换
    cad.net 动态块名 .IsDynamicBlock出错 eInvalidObjectId错误.
    cad.net 委托的学习
    cad.net 关于保存文件Database.SaveAs()出现"eFileAccessErr"错误的解决方法
    测试篇 c# winFrom Close报错 System.ObjectDisposedException:“无法访问已释放的对象。
    测试篇 c#枚举类型怎么用?
    cad.net 2008使用WPF(摘录山人)
    日志篇 随着win10更新...
  • 原文地址:https://www.cnblogs.com/YeZzz/p/13067470.html
Copyright © 2011-2022 走看看