  • The collate_fn parameter of DataLoader in PyTorch

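    In general, the default collate_fn requires every image in a batch to have the same size, because it stacks the samples into a single tensor. When the images in a batch have different sizes, you can pass a custom collate_fn instead. The images in the batch are then not stacked but kept together in a list, along with their corresponding labels, as in the following example: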

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import torchvision.datasets as datasets
    import matplotlib.pyplot as plt
    
    # a simple custom collate function, just to show the idea:
    # the images stay in a list (no torch.stack), only the labels become one LongTensor
    def my_collate(batch):
        data = [item[0] for item in batch]
        target = [item[1] for item in batch]
        target = torch.tensor(target, dtype=torch.long)
        return [data, target]
    
    
    def show_image_batch(img_list, title=None):
        num = len(img_list)
        fig = plt.figure()
        for i in range(num):
            ax = fig.add_subplot(1, num, i + 1)
            # C x H x W tensor -> H x W x C array, the layout imshow expects
            ax.imshow(img_list[i].permute(1, 2, 0).numpy())
            if title is not None:
                ax.set_title(title[i])
    
        plt.show()
    
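    # Do not use RandomCrop, so the images keep different sizes and the custom collate_fn has to handle them.
    # Resize(224) scales the shorter side to 224 and keeps the aspect ratio
    # (it replaces the deprecated transforms.Scale, which newer torchvision releases no longer provide).
    train_transforms = transforms.Compose([transforms.Resize(size=224),
                                           transforms.ToTensor(),
                                           ])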
    
    # change root to valid dir in your system, see ImageFolder documentation for more info
    train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
                                         transform=train_transforms)
    
    trainset = DataLoader(dataset=train_dataset,
                          batch_size=4,
                          shuffle=True,
                          collate_fn=my_collate, # use custom collate function here
                          pin_memory=True)
    
    trainiter = iter(trainset)
    # use the built-in next(); DataLoader iterators in recent PyTorch have no .next() method
    imgs, labels = next(trainiter)
    
    # print(type(imgs), type(labels))
    show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels.tolist()])
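The behavior described above can be reproduced without an image folder. The following is a minimal sketch, assuming only PyTorch is installed; `VariableSizeImages` and `list_collate` are hypothetical names used for illustration, and the image sizes are arbitrary. With the default collate_fn the DataLoader raises a `RuntimeError` when it tries to stack tensors of different shapes, while a list-based collate_fn in the spirit of `my_collate` returns the images as a list and the labels as one `LongTensor`.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class VariableSizeImages(Dataset):
    """Hypothetical toy dataset: random 3-channel 'images' of different sizes."""

    def __init__(self):
        # (height, width) of each sample; deliberately all different
        self.sizes = [(32, 48), (40, 40), (24, 64), (32, 32)]

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        h, w = self.sizes[idx]
        return torch.rand(3, h, w), idx % 2  # (image tensor, integer label)


def list_collate(batch):
    # Same idea as my_collate: keep the images in a list, turn only the labels into a tensor.
    images = [item[0] for item in batch]
    labels = torch.tensor([item[1] for item in batch], dtype=torch.long)
    return images, labels


dataset = VariableSizeImages()

# Default collate_fn: torch.stack on tensors of different shapes raises RuntimeError.
try:
    next(iter(DataLoader(dataset, batch_size=4)))
    default_failed = False
except RuntimeError as err:
    default_failed = True
    print("default collate_fn failed:", err)

# Custom collate_fn: the batch is a list of tensors plus one label tensor.
images, labels = next(iter(DataLoader(dataset, batch_size=4, collate_fn=list_collate)))
print([tuple(img.shape) for img in images], labels)
```

Because the images keep their individual shapes, code that consumes such a batch has to handle the list element by element rather than as one stacked tensor.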
  • Original article: https://www.cnblogs.com/zf-blog/p/11360557.html