借助TensorDataset直接将数据包装成dataset类
直接使用 TensorDataset 来将数据包装成Dataset类,再使用dataloader。
import torch from torch import nn from torch.utils.data import Dataset, DataLoader, TensorDataset src = torch.sin(torch.arange(1, 1000, 0.1)) trg = torch.cos(torch.arange(1, 1000, 0.1)) data = TensorDataset(src, trg) data_loader = DataLoader(data, batch_size=5, shuffle=False) for i_batch, batch_data in enumerate(data_loader): print(i_batch) # 打印batch编号 print(batch_data[0].size()) # 打印该batch里面src print(batch_data[1].size()) # 打印该batch里面trg
output:
0 torch.Size([5]) torch.Size([5]) 1 torch.Size([5]) torch.Size([5]) ...
希望后续多看到这种,若有好的资源可以在评论区留言
参考:https://blog.csdn.net/weixin_42468475/article/details/108714940