zoukankan      html  css  js  c++  java
  • pytorch(ch5

    读取图片数据集::
    # -*- coding: utf-8 -*-
    import torch as t
    from torch.utils import data
    import os
    from PIL import Image
    import numpy as np

    class DogCat(data.Dataset):
    def __init__(self,root):
    imgs=os.listdir(root)
    #所有图片的绝对路径
    #这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
    self.imgs=[os.path.join(root, img) for img in imgs]

    def __getitem__(self, index):
    img_path=self.imgs[index]
    #dog->1, cat->0
    label=1 if 'dog' in img_path.split("/")[-1] else 0
    pil_img=Image.open(img_path)
    array=np.asarray(pil_img)
    data=t.from_numpy(array)
    return data,label

    def __len__(self):
    return len(self.image)

    dataset=DogCat('data/train')
    img,label=dataset[0]#相当于调用dataset.__getitem__(0)
    for img,label in dataset:
    print(img.size(),img.float().mean(),label)



    第二:改变图片尺寸
    #-*- coding: utf-8 -*-
    import os
    from PIL import Image
    from torch.utils import data
    import numpy as np
    from torchvision import transforms as T


    transforms=T.Compose([
    T.Resize(224), #缩放图片(Image,保持长宽比不变,最短边为224像素
    T.CenterCrop(224), #从图片中间裁剪出224*224的图片
    T.ToTensor(), #将图片Image转换成Tensor,归一化至【0,1
    T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #标准化至【-1,1】,规定均值和方差
    ])

    class DogCat(data.Dataset):
    def __init__(self,root, transforms=None):
    imgs=os.listdir(root)
    self.imgs=[os.path.join(root, img) for img in imgs]
    self.transforms=transforms

    def __getitem__(self, index):
    img_path=self.imgs[index]
    #dog->1, cat->0
    label=1 if 'dog' in img_path.split("/")[-1] else 0
    data=Image.open(img_path)
    if self.transforms:
    data=self.transforms(data)
    return data,label

    def __len__(self):
    return len(self.imgs)
    dataset=DogCat('data/train', transforms=transforms)
    img,label=dataset[0]#相当于调用dataset.__getitem__(0)
    for img,label in dataset:
    print(img.size(),label)






    #使用ImageFolder读取图片
    #-*- coding: utf-8 -*-
    from torchvision.datasets import ImageFolder
    dataset=ImageFolder('data/')
    print(dataset.class_to_idx)
    print(dataset.imgs)
     
  • 相关阅读:
    Win7 中出现图标显示不全或消失的解决方法
    动态控制ToolStrip上ToolStripButton的图标大小
    TS——类型断言
    TS——函数的类型
    TS之对象类型——接口
    TS——联合类型
    Git文件合并
    1-1、作用域深入和面向对象
    webStrom2017.1版本如何添加vue.js插件
    二:搭建一个webpack3.5.5项目:建立项目的webpack配置文件
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/10309024.html
Copyright © 2011-2022 走看看