zoukankan      html  css  js  c++  java
  • pytorch (ch5)

    第一:读取图片数据集
    # -*- coding: utf-8 -*-
    import torch as t
    from torch.utils import data
    import os
    from PIL import Image
    import numpy as np

    class DogCat(data.Dataset):
        """Dog-vs-cat image dataset that returns each image as a raw tensor.

        Every file directly inside ``root`` is one sample. The label comes from
        the file name: 1 if it contains ``'dog'``, otherwise 0.

        Args:
            root: Directory containing the image files.
        """

        def __init__(self, root):
            # Only the paths are stored here; pixels are read on demand in
            # __getitem__. Sorting makes the index -> file mapping reproducible,
            # because os.listdir returns entries in arbitrary order.
            self.imgs = [
                os.path.join(root, name) for name in sorted(os.listdir(root))
            ]

        def __getitem__(self, index):
            """Return ``(tensor, label)`` for the image at ``index``.

            The tensor wraps the decoded pixel array, so its shape and dtype follow
            the image's PIL mode (presumably H x W x C uint8 for RGB JPEGs --
            confirm for your data). Label: dog -> 1, cat -> 0.
            """
            img_path = self.imgs[index]
            # Inspect only the file name; os.path.basename works with the
            # platform's separator, unlike a hard-coded split("/").
            label = 1 if 'dog' in os.path.basename(img_path) else 0
            pil_img = Image.open(img_path)
            # np.array (not np.asarray) yields a writable copy, so
            # torch.from_numpy does not warn about a non-writable array.
            tensor = t.from_numpy(np.array(pil_img))
            return tensor, label

        def __len__(self):
            """Return the number of image files found in ``root``."""
            # Fixed: previously referenced the non-existent ``self.image``.
            return len(self.imgs)

    # Build the dataset over the training directory and inspect its samples.
    dataset = DogCat('data/train')
    # Indexing is equivalent to calling dataset.__getitem__(0).
    img, label = dataset[0]
    # With no __iter__ defined, iteration calls __getitem__(0), (1), ...
    # until the underlying list indexing raises IndexError.
    for img, label in dataset:
        print(img.size(), img.float().mean(), label)



    第二:用 transforms 改变图片尺寸并做预处理(缩放、裁剪、转 Tensor、标准化)
    #-*- coding: utf-8 -*-
    import os
    from PIL import Image
    from torch.utils import data
    import numpy as np
    from torchvision import transforms as T


    # Preprocessing pipeline applied, in order, to each PIL image.
    transforms = T.Compose([
        # Scale the image so its shorter side is 224 px, keeping the aspect ratio.
        T.Resize(size=224),
        # Cut out the central 224 x 224 region.
        T.CenterCrop(size=224),
        # Convert the PIL image to a tensor; for 8-bit images values land in [0, 1].
        T.ToTensor(),
        # Per channel compute (x - mean) / std. With mean = std = 0.5 this maps
        # [0, 1] to [-1, 1]. Note that the second argument is a standard
        # deviation, not a variance.
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    class DogCat(data.Dataset):
        """Dog-vs-cat image dataset with an optional preprocessing pipeline.

        Args:
            root: Directory whose files are the samples.
            transforms: Optional callable applied to each opened PIL image
                (for example a ``torchvision.transforms.Compose``). When it is
                ``None``, the opened PIL image is returned unchanged.
        """

        def __init__(self, root, transforms=None):
            # Store paths only; images are opened lazily in __getitem__.
            # Sorting gives a reproducible index -> file mapping, since
            # os.listdir order is arbitrary.
            self.imgs = [
                os.path.join(root, name) for name in sorted(os.listdir(root))
            ]
            self.transforms = transforms

        def __getitem__(self, index):
            """Return ``(sample, label)``; label is 1 for 'dog' files, else 0."""
            img_path = self.imgs[index]
            # Inspect only the file name, independent of the OS path separator.
            label = 1 if 'dog' in os.path.basename(img_path) else 0
            sample = Image.open(img_path)
            if self.transforms is not None:
                sample = self.transforms(sample)
            return sample, label

        def __len__(self):
            """Return the number of files found in ``root``."""
            return len(self.imgs)
    # Build the dataset with the preprocessing pipeline attached.
    dataset = DogCat('data/train', transforms=transforms)
    # Indexing is equivalent to calling dataset.__getitem__(0).
    img, label = dataset[0]
    # Walk every sample and report its size together with its label.
    for img, label in dataset:
        print(img.size(), label)






    # Read images with torchvision's ImageFolder.
    from torchvision.datasets import ImageFolder

    # ImageFolder treats each sub-directory of the root as one class.
    dataset = ImageFolder(root='data/')
    # Mapping from class (folder) name to integer label.
    print(dataset.class_to_idx)
    # The discovered samples as (path, class_index) pairs.
    print(dataset.imgs)
     
  • 相关阅读:
    hh
    SDUT 3923 打字
    最短路
    阶乘后面0的个数(51Nod 1003)
    大数加法
    Biorhythms(中国剩余定理)
    usaco-5.1-theme-passed
    usaco-5.1-starry-passed
    usaco-5.1-fc-passed
    usaco-4.4-frameup-passed
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/10309024.html
Copyright © 2011-2022 走看看