  • Augmenting an image dataset with DCGAN

    1.Dependencies

    2.DCGAN

    Steps:

    • Place the image dataset in /Anime_GAN/DCGAN/faces

    • Run the following commands:

      $ cd Anime_GAN/DCGAN/
      $ python main.py --help # show the default parameters; adjust them as needed

    Once main.py has run, it produces a corresponding image (saved to: /Anime_GAN/DCGAN/saved/img/xx.png)

    • Run SegmentePictures.py to cut the image into tiles

      $ cd DCGAN/saved
      $ python SegmentePictures.py
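
    # encoding:utf-8
    import math
    import argparse
    import os

    from PIL import Image


    def fill_image(image):
        """
        Pad the image to a square.
        :param image: PIL image to pad
        :return: new square PIL image with the original centered on a white background
        """
        width, height = image.size
        # Use the larger of width and height as the side length of the new image
        new_image_length = max(width, height)
        # Create the new image (white background)
        new_image = Image.new(image.mode, (new_image_length, new_image_length), color='white')
        # Paste the original onto the new image, centered
        if width > height:  # wider than tall: pad along the vertical dimension
            # The (x, y) tuple is the top-left position of the pasted image on the background
            new_image.paste(image, (0, (new_image_length - height) // 2))
        else:
            new_image.paste(image, ((new_image_length - width) // 2, 0))

        return new_image


    def cut_image(image, cut_num):
        """
        Cut the square image into a grid of equal tiles.
        :param image: square PIL image
        :param cut_num: total number of tiles; should be a perfect square
        :return: list of PIL images, row by row
        """
        flag_value = int(math.sqrt(cut_num))  # tiles per row and per column
        width, height = image.size
        item_width = width // flag_value  # side length of one tile in pixels
        box_list = []
        for i in range(flag_value):
            for j in range(flag_value):
                # (left, upper, right, lower) of the tile in row i, column j
                box = (j * item_width, i * item_width, (j + 1) * item_width, (i + 1) * item_width)
                box_list.append(box)
        image_list = [image.crop(box) for box in box_list]

        return image_list


    def save_images(image_list):
        """
        Save the tiles as ./img_add/1.png, ./img_add/2.png, ...
        :param image_list: list of PIL images
        :return:
        """
        os.makedirs('./img_add', exist_ok=True)  # create the output directory if it is missing
        for index, image in enumerate(image_list, start=1):
            image.save('./img_add/' + str(index) + '.png', 'PNG')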
    
    def main():
        parse = argparse.ArgumentParser()

        parse.add_argument("--lr", type=float, default=0.0001,
                           help="learning rate of generator and discriminator")
        parse.add_argument("--beta1", type=float, default=0.5,
                           help="adam optimizer parameter")
        parse.add_argument("--batch_size", type=int, default=81,
                           help="number of samples in every train or test iteration")
        parse.add_argument("--epochs", type=int, default=0,
                           help="number of training epochs")
        parse.add_argument("--loaders", type=int, default=4,
                           help="number of parallel data loading processes")
        parse.add_argument("--size_per_dataset", type=int, default=30000,
                           help="number of training data")

        args = parse.parse_args()

        # Image path; args.epochs is an int, so convert it before concatenating
        file_path = "./img/" + str(args.epochs) + ".png"
        image = Image.open(file_path)
        image = fill_image(image)
        image_list = cut_image(image, args.batch_size)
        save_images(image_list)

    if __name__ == '__main__':
        main()

    Note: the batch_size value must be the same in both of the following commands

    $ python main.py --batch_size=xx 
    
    $ python SegmentePictures.py --batch_size=xx 
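
    The reason the two values must agree can be seen from how cut_image derives its grid: it takes the integer square root of batch_size as the number of tiles per side. The sketch below reproduces only that box arithmetic, without PIL. The function name grid_boxes and the 576-pixel image width are illustrative assumptions, not part of the original script.

```python
import math


def grid_boxes(image_width, cut_num):
    """Return crop boxes (left, upper, right, lower) the way cut_image computes them."""
    per_side = int(math.sqrt(cut_num))   # tiles per row and per column
    tile = image_width // per_side       # side length of one tile in pixels
    return [
        (col * tile, row * tile, (col + 1) * tile, (row + 1) * tile)
        for row in range(per_side)
        for col in range(per_side)
    ]


# Assumed example: a 576-pixel square image holding 81 samples -> 9 x 9 tiles of 64 pixels
boxes = grid_boxes(576, 81)
print(len(boxes), boxes[0], boxes[-1])   # 81 (0, 0, 64, 64) (512, 512, 576, 576)

# A mismatched value cuts a different grid out of the same image
print(len(grid_boxes(576, 64)))          # 64 (8 x 8 tiles of 72 pixels)
```

    If batch_size is not a perfect square, the square root is truncated, so fewer than batch_size tiles are produced (for example, 80 gives an 8 x 8 grid of 64 tiles).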

    3.Problems encountered

    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 370 and 667 in dimension 2 at /pytorch/aten/src/TH/generic/THTensor.cpp:711

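    • Error analysis: the images are loaded with a DataLoader. Some of them have 3 channels (color images) while others may have a single channel (black-and-white images); because their sizes differ in dim 1, they cannot be concatenated into a batch. Try adding img = img.convert('RGB') to the dataset's __getitem__. Note that the message quoted above reports the mismatch in dimension 2 (370 vs. 667), which suggests the images also differ in height, so resizing them to a common size may be needed as well.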

    • Convert all images to the same number of channels

      import os

      from PIL import Image


      def GetAllFiles(root_dir):
          """Recursively collect the paths of all files under root_dir."""
          files_ = []
          for name in os.listdir(root_dir):
              path = os.path.join(root_dir, name)
              if os.path.isdir(path):
                  files_.extend(GetAllFiles(path))
              elif os.path.isfile(path):
                  files_.append(path)
          return files_


      def ConvertRGB(file_path, save_dir="./save_img"):
          """
          Convert every image under file_path to RGB and save it as a JPEG
          :return:
          """
          os.makedirs(save_dir, exist_ok=True)  # create the output directory if it is missing
          for idx, item in enumerate(GetAllFiles(file_path)):
              with Image.open(item) as img:
                  rgb = img.convert('RGB')
              # Save with PIL directly so the pixel dimensions are preserved
              rgb.save(os.path.join(save_dir, str(idx) + ".jpg"))


      if __name__ == "__main__":
          file_path = "your path"
          ConvertRGB(file_path)
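
    As an alternative to rewriting the files on disk, the conversion suggested in the error analysis can be done when each sample is loaded. The following is a minimal sketch, assuming a map-style dataset (an object with __len__ and __getitem__, which is the kind of object a PyTorch DataLoader is assumed to wrap here). The class name FaceFolder, the 64 x 64 target size, and the transform argument are hypothetical and do not come from the Anime_GAN code.

```python
import os

from PIL import Image


class FaceFolder:
    """Hypothetical map-style dataset that returns same-sized RGB images."""

    def __init__(self, root, size=(64, 64), transform=None):
        self.paths = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        self.size = size            # assumed target (width, height)
        self.transform = transform  # optional callable, e.g. a to-tensor transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        with Image.open(self.paths[index]) as img:
            img = img.convert("RGB")     # every sample gets 3 channels
            img = img.resize(self.size)  # every sample gets the same width and height
        if self.transform is not None:
            img = self.transform(img)
        return img
```

    With both the channel count and the spatial size fixed per sample, every item has the same shape once converted to a tensor, which is the condition the RuntimeError above complains about.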

       

    Reference: https://github.com/FangYang970206/Anime_GAN/blob/master/README.md

  • Original article: https://www.cnblogs.com/shierlou-123/p/11236594.html