zoukankan      html  css  js  c++  java
  • 利用torch.utils.data.Dataset自定义数据加载类

    import torch as t
    from torch.utils import data
    import os
    from PIL import Image
    import numpy as np

    import torchvision.transforms as T

    transforms = T.Compose([

      T.Resize(224),

      T.CenterCrop(224),

      T.ToTensor(),

      T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

    ])

    # 继承Dataset类要重写__getitem__()和__len__()
    class CatDog(data.Dataset):
      def __init__(self, root, transforms=None):

        # 临时变量不用加self
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]

        self.transforms = transforms

      def __getitem__(self, index):
        label = 1 if dog else 0

        data = Image.open(self.imgs[index])
        if self.transform:

          data = self.transform(data)
        return data, label

      def __len__(self):
        return len(self.imgs)

  • 相关阅读:
    webpack的安装与配置
    npm初始化
    gitignore的配置
    git本地已有文件夹和远程仓库对应
    git 配置
    开发环境和开发工具
    git 码云使用教程
    递归
    LeetCode 392. 判断子序列
    MongoDB基本操作
  • 原文地址:https://www.cnblogs.com/liujianing/p/12320539.html
Copyright © 2011-2022 走看看