【Source Code Walkthrough】CycleGAN (Part 3): Data Loading

Source code: https://github.com/aitorzip/PyTorch-CycleGAN

Data loading here is fairly simple. CycleGAN has no need for paired data: the two datasets from different domains simply live in two folders, A and B, and all that is needed is a Dataset interface (the expected directory layout is sketched below).
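
Judging from the glob patterns in ImageDataset further down, the layout the code expects is roughly the following (only the train/test/A/B names are fixed; everything else is illustrative):

    dataroot/
    ├── train/
    │   ├── A/    # domain-A training images (*.jpg, *.png, ...)
    │   └── B/    # domain-B training images
    └── test/
        ├── A/
        └── B/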

    from PIL import Image
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    # `opt` holds the command-line options parsed with argparse in train.py
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Dataset loader: resize to ~1.12x (286 for the default size of 256),
    # random-crop back to opt.size, random flip, then normalize to roughly [-1, 1]
    transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC),
                    transforms.RandomCrop(opt.size),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
    dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),
                            batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)

In the code above, we first create the two replay buffers (more on these later), then define the image transforms and hand everything to an ImageDataset (a subclass of torch.utils.data.Dataset); batches can then be pulled from the DataLoader. Below is the ImageDataset code:

    import glob
    import os
    import random

    from PIL import Image
    from torch.utils.data import Dataset
    import torchvision.transforms as transforms

    class ImageDataset(Dataset):
        def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
            self.transform = transforms.Compose(transforms_)
            self.unaligned = unaligned

            # Sorted for a deterministic file order; layout is root/{train,test}/{A,B}
            self.files_A = sorted(glob.glob(os.path.join(root, '%s/A' % mode) + '/*.*'))
            self.files_B = sorted(glob.glob(os.path.join(root, '%s/B' % mode) + '/*.*'))

        def __getitem__(self, index):
            item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))

            if self.unaligned:
                # Unpaired setting: draw a random image from domain B
                item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]))
            else:
                # Aligned setting: walk both folders in the same index order
                item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]))

            return {'A': item_A, 'B': item_B}

        def __len__(self):
            # One epoch covers the larger of the two domains once
            return max(len(self.files_A), len(self.files_B))

The class implements the standard __init__, __getitem__ and __len__ interface. I was initially unsure why the file lists are sorted and what the alignment flag is for: sorting simply makes the file order deterministic; aligned mode (unaligned=False) reads A and B in lockstep by index, which only makes sense when the two folders contain corresponding pairs; unaligned mode draws a random B image for each A image, which is the unpaired setting CycleGAN targets. A minimal sketch of consuming the loader follows.
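
This sketch pulls one batch from the dataloader defined earlier; the shape comments assume 3-channel images and the default opt.size of 256.

    # Pull one batch from the loader defined above
    for i, batch in enumerate(dataloader):
        real_A = batch['A']   # (batchSize, 3, opt.size, opt.size), values roughly in [-1, 1]
        real_B = batch['B']   # same shape; an unpaired, randomly drawn batch from domain B
        print(real_A.shape, real_B.shape)
        break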

Finally, the buffer. The CycleGAN paper motivates it like this: "Second, to reduce model oscillation [15], we follow Shrivastava et al.'s strategy [46] and update the discriminators using a history of generated images rather than the ones produced by the latest generators. We keep an image buffer that stores the 50 previously created images."

In other words, for the sake of training stability, the discriminators are updated with historically generated fake samples rather than only the ones just produced; the underlying rationale comes from the cited paper. Let's look at the code:

    import random

    import torch
    from torch.autograd import Variable

    class ReplayBuffer():
        def __init__(self, max_size=50):
            assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
            self.max_size = max_size
            self.data = []

        def push_and_pop(self, data):
            to_return = []
            for element in data.data:
                element = torch.unsqueeze(element, 0)  # restore the batch dimension
                if len(self.data) < self.max_size:
                    # Buffer not yet full: store the new image and return it as-is
                    self.data.append(element)
                    to_return.append(element)
                else:
                    if random.uniform(0, 1) > 0.5:
                        # With probability 0.5 return a stored (historical) image
                        # and replace it in the buffer with the new one
                        i = random.randint(0, self.max_size - 1)
                        to_return.append(self.data[i].clone())
                        self.data[i] = element
                    else:
                        # Otherwise return the new image directly
                        to_return.append(element)
            return Variable(torch.cat(to_return))

This defines a buffer object with a storage list data whose size is capped at 50. As I understand it, the workflow is: while the list is not yet full, every call simply stores the incoming fake images and returns them unchanged; once the list is full, each incoming image triggers a coin flip between 1. drawing a random image from the list, returning it, and replacing it in the buffer with the current image, or 2. returning the current image directly. A small demo follows.
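
Here is a small self-contained exercise of push_and_pop with random tensors, assuming the ReplayBuffer class above is in scope; the trailing comment shows, roughly, how the repo's train script feeds the result to a discriminator.

    import torch

    buffer = ReplayBuffer(max_size=50)

    for step in range(100):
        fake = torch.randn(1, 3, 256, 256)   # stand-in for a generated image
        out = buffer.push_and_pop(fake)
        # First 50 steps: `out` is always the fresh image; afterwards it is a
        # 50/50 mix of fresh and historical images
        assert out.shape == fake.shape

    print(len(buffer.data))  # 50 -- never grows beyond max_size

    # In the repo's train.py the pattern is roughly:
    #   fake_B = fake_B_buffer.push_and_pop(fake_B)
    #   pred_fake = netD_B(fake_B.detach())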

As for why this is a sensible thing to do, you have to consult the referenced paper (Shrivastava et al. [46]).

     
