zoukankan      html  css  js  c++  java
  • pytorch学习:准备自己的图片数据

    图片数据一般有两种情况:

    1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。

    2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。

    针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:

    一、所有图片放在一个文件夹内

    这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。

    先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:

    import torch
    import torchvision
    import matplotlib.pyplot as plt
    from skimage import io
    mnist_test= torchvision.datasets.MNIST(
        './mnist', train=False, download=True
    )
    print('test set:', len(mnist_test))
    
    f=open('mnist_test.txt','w')
    for i,(img,label) in enumerate(mnist_test):
        img_path="./mnist_test/"+str(i)+".jpg"
        io.imsave(img_path,img)
        f.write(img_path+' '+str(label)+'
    ')
    f.close()

    经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:

    前期工作就装备好了,接着就进入正题了:

    from torchvision import transforms, utils
    from torch.utils.data import Dataset, DataLoader
    import matplotlib.pyplot as plt
    from PIL import Image
    
    
    def default_loader(path):
        return Image.open(path).convert('RGB')
    
    
    class MyDataset(Dataset):
        def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
            fh = open(txt, 'r')
            imgs = []
            for line in fh:
                line = line.strip('
    ')
                line = line.rstrip()
                words = line.split()
                imgs.append((words[0],int(words[1])))
            self.imgs = imgs
            self.transform = transform
            self.target_transform = target_transform
            self.loader = loader
    
        def __getitem__(self, index):
            fn, label = self.imgs[index]
            img = self.loader(fn)
            if self.transform is not None:
                img = self.transform(img)
            return img,label
    
        def __len__(self):
            return len(self.imgs)
    
    train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
    data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
    print(len(data_loader))
    
    
    def show_batch(imgs):
        grid = utils.make_grid(imgs)
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.title('Batch from dataloader')
    
    
    for i, (batch_x, batch_y) in enumerate(data_loader):
        if(i<4):
            print(i, batch_x.size(),batch_y.size())
            show_batch(batch_x)
            plt.axis('off')
            plt.show()

    自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。

    二、不同类别的图片放在不同的文件夹内

    同样先准备数据,这里以flowers数据集为例,下载:

    http://download.tensorflow.org/example_images/flower_photos.tgz

    花总共有五类,分别放在5个文件夹下。大致如下图:

    我的路径是d:/flowers/.

    数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder

    import torch
    import torchvision
    from torchvision import transforms, utils
    import matplotlib.pyplot as plt
    
    img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
                                                transform=transforms.Compose([
                                                    transforms.Scale(256),
                                                    transforms.CenterCrop(224),
                                                    transforms.ToTensor()])
                                                )
    
    print(len(img_data))
    data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
    print(len(data_loader))
    
    
    def show_batch(imgs):
        grid = utils.make_grid(imgs,nrow=5)
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.title('Batch from dataloader')
    
    
    for i, (batch_x, batch_y) in enumerate(data_loader):
        if(i<4):
            print(i, batch_x.size(), batch_y.size())
    
            show_batch(batch_x)
            plt.axis('off')
            plt.show()

    就是这样。

  • 相关阅读:
    基于三角形问题通过边界值分析和等价类划分进行黑盒测试
    小程序学习记录【数组操作相关(持续更新)】(1)
    Android实现九宫拼图过程记录
    高维数据Lasso思路
    CannyLab/tsne-cuda with cuda-10.0
    xgboost 多gpu支持 编译
    GDAL2.2.4 C#中的编译及使用
    SqlServer性能优化,查看CPU、内存占用大的会话及SQL语句
    WinForm任务栏最小化
    datatable与实体类之间相互转化的几种方法
  • 原文地址:https://www.cnblogs.com/denny402/p/7512516.html
Copyright © 2011-2022 走看看