zoukankan      html  css  js  c++  java
  • pytorch(二)数据准备工作ETL

    1.Extract: 从网络等下载图像数据集
    2.Transform: 图片---->tensor
    3.Loader: tensor装进数据流管道,以便获取到流出batch长度数据。
    ()


    1.torch.utils.data.datasets ---(Extract, Transform)

    抽象类:具有必须待实现(重写)的方法的Python类. 因此我们可以通过扩展这个抽象类的功能,创建子类来构造自定义数据集。

    需要重写-override的函数:

    • len:实现数据集长度功能
    • getitem:实现对数据的位置索引,可以根据索引来访问数据元素
    #这里直接继承了MNIST类,MNIST类也是继承Dataset类实现。
    class FashionMNIST(MNIST):(
        urls = [
            'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
            'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
            'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
            'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
        ]
    
    )
    
    train_set = torchvision.datasets.FashionMNIST(
        root='./data'
        ,train=True
        ,download=True
        ,transform=transforms.Compose([
            transforms.ToTensor()
        ])
    )
    """
    #设置保存路径
    #选择训练集,默认测试集
    #是否下载,下载前方法会检查目录下是否已经,不会重复下载
    #通过torchvison.transforms对数据做变换
    #开代理会很快
    
    """
    

    2.torch.utils.data.dataloader

    train_loader = torch.utils.data.DataLoader(train_set, batch_size = 1000,shuffle = True,num_workers = 0)
    #num_workers默认为0表示用主进程来装载数据
    # 注意:Pytorch multiprocessing does not work on Windows!!!因此在windows系统上此处必须设置为0,即只有一个进程存在。
    #数据成可训练的pipeline最后只需要给三个入参:数据集,批次,是否打乱。
    #batch_size:  data size of  per batch
    """
    batch_size小对训练的影响:
    1.很小比如为1,训练震荡严重,不易收敛
    2.增大,下山路线,开始变正确
    3.继续增大,已经足够准确,不再变化
    4.但是随着batch_size增大,相同的epoch次数下,迭代次数变小,因此需要注意在增大batch_size的同时,增加epoch。不能增大了batch导致迭代次数明显减少,会导致最优化效果变差。
    """
    

    3.利用dataset和dataloader探索数据

    train_set.targets#每个的标签
    
    train_set.targets.bincount()
    #统计每个标签类别的数目
    #用途:检查是否有严重的类别不平衡问题,class imbalanace
    
    sample = next(iter(train_set))
    #iter和next都是python内置函数(既可以用在train_set,也可以用在封装好的dataloader)
    # iter返回迭代器对象, next获取一个迭代器元素
    
    len(sample)
    print(type(sample))
    #从torchvision获取的数据集每个sample样式:(tentor(img),tensor(label))
    
    image.shape
    # torch.Size([1, 28, 28]) 
    
    plt.imshow(image.squeeze(), cmap="gray") # 因为plt.imshow对于单通道显示格式是 H W
    
    display_loader = torch.utils.data.DataLoader(train_set, batch_size=10)
    images, labels = next(iter(display_loader)  # len(next(iter(...))) = 2
    
    print('types:', type(images), type(labels))
    # types: <class 'torch.Tensor'> <class 'torch.Tensor'>
    
    print('shapes:', images.shape, labels.shape)
    # shapes: torch.Size([10, 1, 28, 28]) torch.Size([10])
    
    #两种可视化数据集的方式
    #1.torchvision.utils.make_grid()
    
    grid = torchvision.utils.make_grid(images, nrow=10)
    plt.figure(figsize=(15,15))
    plt.imshow(np.transpose(grid, (1,2,0))) #plt.imshow(grid.permute(1,2,0))
    
    #2.利用DataLoader显示
    
    how_many_to_plot = 20
    
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True)
    
    plt.figure(figsize=(50,50))
    for i, img_batch in enumerate(train_loader, start=1):
    #enumerate(sequence,start=1):枚举出的就是(i,img_batch)形状,如(1,a),(2,b), (3,c)
        image, label = img_batch
        plt.subplot(10,10,i)
        plt.imshow(image.reshape(28,28), cmap='gray')
        plt.axis('off')
        plt.title(train_set.classes[label.item()], fontsize=28)
        if (i >= how_many_to_plot): break
    plt.show()
    
  • 相关阅读:
    PL/SQL Developer 和 instantclient客户端安装配置(图文)
    VirtualBox + Centos 使用NAT + Host-Only 方式联网
    zookeeper的安装
    Socket编程基础篇
    WebSocket教程(二)
    WebSocket教程(一)
    Js判断浏览器类型
    JVM内存模型
    js 正则去除指定的单词
    Java线上应用故障排查之一:高CPU占用
  • 原文地址:https://www.cnblogs.com/Henry-ZHAO/p/13068647.html
Copyright © 2011-2022 走看看