加载数据集分为两种方式：
第一种：加载 PyTorch（torchvision）官方提供的常用数据集（以 MNIST 为例）
# dataloader = torch.utils.data.DataLoader(
#     datasets.MNIST(
#         "../../data/mnist",
#         train=True,
#         download=True,
#         transform=transforms.Compose(
#             [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
#         ),
#     ),
#     batch_size=opt.batch_size,
#     shuffle=True,
# )
第二种：通过 ImageFolder 加载本地数据集（根目录下每个子文件夹对应一个类别）
# Approach 2: load a local image dataset with torchvision's ImageFolder.
# NOTE(review): this snippet reads both `args` (datasets, batch_size) and
# `opt` (img_size), while the MNIST example in this file uses `opt` only.
# Confirm that both namespaces exist in the calling script, or unify them.
# Despite its name, `train_loader` holds the ImageFolder *dataset*; the
# actual DataLoader is `dataloader`.
train_loader = datasets.ImageFolder(
    args.datasets,  # root directory of the local dataset
    transform=transforms.Compose(
        [
            # When img_size is an int, Resize only scales the shorter side and
            # keeps the aspect ratio. Non-square local images would then yield
            # tensors of different shapes, and the DataLoader's default collate
            # could not stack them into one batch. CenterCrop forces every
            # sample to the same square size (a no-op for square inputs).
            transforms.Resize(opt.img_size),
            transforms.CenterCrop(opt.img_size),
            transforms.Grayscale(1),  # one output channel, matching the 1-element Normalize stats
            transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
            transforms.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5 maps [0, 1] -> [-1, 1]
        ]
    ),
)
dataloader = torch.utils.data.DataLoader(
    train_loader,
    batch_size=args.batch_size,
    shuffle=True,
)