zoukankan      html  css  js  c++  java
  • Pytorch dataset自定义【直播】2019 年县域农业大脑AI挑战赛---数据准备(二),Dataset定义

    在我的torchvision库里介绍的博文(https://www.cnblogs.com/yjphhw/p/9773333.html)里说了对pytorch的dataset的定义方式。

    本文相当于实现一个自定义的数据集,而这正是我们在做自己工程所需要的,我们总是用自己的数据嘛。

    继承 from torch.utils.data import Dataset 类

    然后实现 __len__(self) ,和 __getitem__(self,idx) 两个方法。以及数据增强也可以写入,数据增强想了想还是放到了Dataset里,

    习惯上可能与常用的不同,但是觉得由于每种数据都有自己的增强方法所以,增强方法可以和数据集绑定到一起的。

    接上一节我们通过切割,获取了2217个图像切片。

    这就是我的FarmDataset

    from torch.utils.data import Dataset, DataLoader
    from PIL import Image,ImageEnhance
    from osgeo import gdal
    from torchvision import transforms
    import glob
    import torch as tc 
    import numpy as np
    
    
    class FarmDataset(Dataset):
        def __init__(self,istrain=True,isaug=True):
            self.istrain=istrain
            self.trainxformat='./data/train/data1500/*.bmp'
            self.trainyformat='./data/train/label1500/*.bmp'
            self.testxformat='./data/test/*.png'
            self.fns=glob.glob(self.trainxformat) if istrain else glob.glob(self.testxformat)
            self.length=len(self.fns)
            self.transforms=transforms
            self.isaug=isaug
            
        def __len__(self):
            #total length is 2217 
            return self.length
        def __getitem__(self,idx):
            if self.istrain:
                
                imgxname=self.fns[idx]
                sampleimg = Image.open(imgxname)
                imgyname=imgxname.replace('data1500','label1500')
                targetimg = Image.open(imgyname).convert('L')
                #sampleimg.save('original.bmp')
                
                #data augmentation
                if self.isaug:
                    sampleimg,targetimg=self.imgtrans(sampleimg,targetimg)
                
                #check the result of dataautmentation
                #sampleimg.save('sampletmp.bmp')
                #targetimg.save('targettmp.bmp')
                
                sampleimg=transforms.ToTensor()(sampleimg) 
                #targetimg=transforms.ToTensor()(targetimg).squeeze(0).long() 
                targetimg=np.array(targetimg)
                targetimg=tc.from_numpy(targetimg).long()         #to tensor
                #print(sampleimg.shape,targetimg.shape)
                return sampleimg,targetimg
            else:
                return gdal.Open(self.fns[idx])
        def imgtrans(self,x,y,outsize=1024):
            '''input is a PIL image 
               image dataaugumentation
               return also aPIL image。
            '''
            #rotate should consider y
            degree=np.random.randint(360)
            x=x.rotate(degree,resample=Image.NEAREST,fillcolor=0)
            y=y.rotate(degree,resample=Image.NEAREST,fillcolor=0)  #here should be carefull, in case of label damage
            
            #random do the input image augmentation
            if np.random.random()>0.5:
                #sharpness 
                factor=0.5+np.random.random()
                enhancer=ImageEnhance.Sharpness(x)
                x=enhancer.enhance(factor)
            if np.random.random()>0.5:
                #color augument
                factor=0.5+np.random.random()
                enhancer=ImageEnhance.Color(x)
                x=enhancer.enhance(factor)
            if np.random.random()>0.5:
                #contrast augument
                factor=0.5+np.random.random()
                enhancer=ImageEnhance.Contrast(x)
                x=enhancer.enhance(factor)
            if np.random.random()>0.5:
                #brightness
                factor=0.5+np.random.random()
                enhancer=ImageEnhance.Brightness(x)
                x=enhancer.enhance(factor)
            
            #img flip
            transtypes=[Image.FLIP_LEFT_RIGHT,Image.FLIP_TOP_BOTTOM,
                    Image.ROTATE_90,Image.ROTATE_180,Image.ROTATE_270]
            transtype=transtypes[np.random.randint(len(transtypes))]
            x = x.transpose(transtype)
            y = y.transpose(transtype)
            
            #img resize between 0.8-1.2
            w,h=x.size
            factor=1+np.random.normal()/5
            if factor>1.2: factor=1.2
            if factor<0.8: factor=0.8
            #print(factor,x.size)
            x=x.resize((int(w*factor),int(h*factor)),Image.NEAREST)
            y=y.resize((int(w*factor),int(h*factor)),Image.NEAREST)
            
            #random crop
            w,h=x.size
            stx=np.random.randint(w-outsize)
            sty=np.random.randint(h-outsize)
            #print((stx,sty,outsize,outsize))
            x=x.crop((stx,sty,stx+outsize,sty+outsize)) #stx,sty,width,height
            y=y.crop((stx,sty,stx+outsize,sty+outsize))
            #print(x.size,y.size)
            return x,y   #return outsized pil image
        
    
    if __name__=='__main__':
        d=FarmDataset(istrain=True)
        x,y=d[2216]
        print(x.shape)
        print(y.shape)
    

      

      输入的是个1500x1500的图像,输出的是增强后的1024x1024后的图像。

      其实对于分割问题来看,以后这个就可以作为一个模板,修改修改就可以换到另一个数据集中。

    放几张图片:

    原始图像:

    进行数据增强后可以得到的一系列:

    经过check 发现没有的问题通过测试。

     

  • 相关阅读:
    php 函数ignore_user_abort()
    关于VMAX中存储资源池(SRP)
    VMware Integrated OpenStack (VIO)简介
    云计算服务的三种类型(SaaS、PaaS、IaaS)
    vMware存储:SAN配置基础
    关于不同应用程序存储IO类型的描述
    (转)OpenFire源码学习之十:连接管理(上)
    (转)OpenFire源码学习之九:OF的缓存机制
    (转)OpenFire源码学习之八:MUC用户聊天室
    (转)OpenFire源码学习之七:组(用户群)与花名册(用户好友)
  • 原文地址:https://www.cnblogs.com/yjphhw/p/11077727.html
Copyright © 2011-2022 走看看