zoukankan      html  css  js  c++  java
  • 【PyTorch】使用中注意事项

    参考博客:

    https://blog.csdn.net/u011276025/article/details/73826562/

    1. 把Label要转成LongTensor格式

    self.y = torch.LongTensor(y)

    完整使用代码如下:

     1 class ImgDataset(Dataset):
     2     def __init__(self, x, y=None, transform=None):
     3         self.x = x
     4         # label is required to be a LongTensor
     5         self.y = y
     6         if y is not None:
     7             self.y = torch.LongTensor(y)
     8         self.transform = transform
     9     def __len__(self):
    10         return len(self.x)
    11     def __getitem__(self, index):
    12         X = self.x[index]
    13         if self.transform is not None:
    14             X = self.transform(X)
    15         if self.y is not None:
    16             Y = self.y[index]
    17             return X, Y
    18         else:
    19             return X
    View Code

    需要保证target类型为torch.cuda.LongTensor,需要在数据读取的迭代其中把target的类型转换为int64位的:target = target.astype(np.int64),这样,输出的target类型为torch.cuda.LongTensor。(或者在使用前使用Tensor.type(torch.LongTensor)进行转换)。

    *LongTensor其实就是int64,有符号整型

    2. 做预测时,没有y值,从dataloader中传入给model的直接是data,而不再是data[0]了

    model_best.eval()
    prediction = []
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            #print(data[0].size())
            # 特别要注意的是,这里直接传入data,因为已经没有y值了,所以无需data[0]。
            # 如果传了data[0]反而导致没有传入整个batch,计算错误
            test_pred = model_best(data.cuda())
            test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)
            for y in test_label:
                prediction.append(y)

    3. 训练时,要设成model.train(),这样optimizer就可以更新model的参数

      验证时,要设成model.val(),以此来固定model的参数。例如 去掉dropout、bn参数不变 等。

    4. ModuleNotFoundError:No module named ”Classifier“

    训练完保存模型后,再load模型去做预测时,仍然需要原来训练时的Classifier,即整个网络结构。。

    有点匪夷所思呀。。那存模型和存参数比有啥区别呢?存了个寂寞?

    还需要查一下存模型和存参数的区别

    5.pytorch中nn.crossEntropyLoss 自带softmax,无需将输出经softmax层再计算交叉熵损失

     其源码实现时,将 input 经过 softmax 激活函数之后,再计算其与 target 的交叉熵损失

    未完待续。。。

  • 相关阅读:
    结对第一次—原型设计(文献摘要热词统计)
    第一次作业-准备篇
    201771010135杨蓉庆《面向对象程序设计(java)》第二周学习总结
    杨蓉庆201771010135《面向对象程序设计(java)》第一周学习总结
    2019 SDN阅读作业
    第01组 Alpha冲刺 (2/4)
    2019 SDN上机第3次作业
    第01组 ALPHA冲刺(1/4)
    2019SDN上机第二次作业
    2019 SDN上机第1次作业
  • 原文地址:https://www.cnblogs.com/YeZzz/p/13067470.html
Copyright © 2011-2022 走看看