下面是定义初始化
#初始化输入的张量 - torch.empty是返回一个包含未初始化数据的张量 self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) self.label = torch.empty(size=(self.opt.batchsize,), dtype=torch.float32, device=self.device)
然后进行动态赋值:
#设置输入的数据 def set_input(self, input): self.input.data.resize_(input[0].size()).copy_(input[0]) #把data的第一项:图片数据赋值给self.input self.label.data.resize_(input[1].size()).copy_(input[1]) #把data的第二项:图片的标签赋值给sele.gt