zoukankan      html  css  js  c++  java
  • Pytorch 重写Dataloader

    是一个官网的例子:torch.nn入门

    一般而言,我们会根据自己的数据需求继承Dataset(from torch.utils.data import Dataset, DataLoader)重写数据读取函数。或者利用TensorDataset更加简洁实现读取数据。

    抑或利用 torchvision里面的ImageFolder也可管理数据。这几种方法已经可以实现数据读取了,而DataLoader的作用是更加全面管理批量数据:

    下面进入正题,MNIST数据利用CNN时需要转换为二维数据,所以需要对初始的线性数据进行转换。一般,可以读取先行数据后在模型中进行view来实现:

    class Lambda(nn.Module):
        def __init__(self, func):
            super().__init__()
            self.func = func
    
        def forward(self, x):
            return self.func(x)
    
    
    def preprocess(x):
        return x.view(-1, 1, 28, 28)
    
    model = nn.Sequential(
        Lambda(preprocess),
        nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.AvgPool2d(4),
        Lambda(lambda x: x.view(x.size(0), -1)),
    )

    文中给出另一种解决方案:重写DateLoader:将数据处理移到生成器里面

    def get_data(train_ds, valid_ds, bs):
        return (
            DataLoader(train_ds, batch_size=bs, shuffle=True),
            DataLoader(valid_ds, batch_size=bs * 2),
        )
    
    def preprocess(x, y):
        return x.view(-1, 1, 28, 28), y
    
    
    class WrappedDataLoader:
        def __init__(self, dl, func):
            self.dl = dl
            self.func = func
    
        def __len__(self):
            return len(self.dl)
    
        def __iter__(self):
            batches = iter(self.dl)
            for b in batches:
                yield (self.func(*b))
    
    train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
    train_dl = WrappedDataLoader(train_dl, preprocess)
    valid_dl = WrappedDataLoader(valid_dl, preprocess)

    模型就可以写成这样:

    model = nn.Sequential(
        nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        Lambda(lambda x: x.view(x.size(0), -1)),
    )
  • 相关阅读:
    Java读书笔记(2)-输入输出
    Java读书笔记(1)-异常处理
    Photoshop自动导出各尺寸Android和Iphone图标,支持新版Android Studio
    【原创】我的研究生活
    [原创]使用Fiddler抓取手机APP流量--360WIFI
    Federa 7 配置yum 源
    开源自己写的刷票器软件(windows和Android)
    更新linux kernel到3.14.10 LTS版后,virt-manager无法识别qemu hypervisor的问题
    Net Core Identity 身份验证:注册、登录和注销 (简单示例)
    Net Core的API文档工具Swagger
  • 原文地址:https://www.cnblogs.com/king-lps/p/12721758.html
Copyright © 2011-2022 走看看