zoukankan      html  css  js  c++  java
  • Pytorch数据读取框架

    训练一个模型需要有一个数据库,一个网络,一个优化函数。数据读取是训练的第一步,以下是pytorch数据输入框架。

    1)实例化一个数据库

    假设我们已经定义了一个FaceLandmarksDataset数据库,此数据库将在以下建立。

    import FaceLandmarksDataset
    face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                        root_dir='data/faces/',
                                        transform=transforms.Compose([ Rescale(256), RandomCrop(224), ToTensor()]) )

    或者使用torchvision.datasets里封装的数据集(MNIST、Fashion-MNIST、KMNIST、EMNIST、COCO、LSUN、ImageFolder、DatasetFolder、Imagenet-12、CIFAR、STL10、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes)

    import torchvision.datasets
    imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')

    2)创建一个数据加载器

    import torch.utils.data.DataLoader
    imagenet_loader = torch.utils.data.DataLoader(imagenet_data,
                                              batch_size=4,  
                                              shuffle=True,
                                              num_workers=4)
    #or
    
    facelandmark_loader = torch.utils.data.DataLoader(face_dataset,
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=4) 

    可见,数据加载器是通用的,只有数据库实例不一样,其它的都参数都一样,参数值可以根据任务需要自己调。

    3)使用数据库

    数据加载器可迭代的,我们可以使用数据库:

    for item in facelandmark_loader:
         images,labels = item
    do_somethi

    当然, 我们也可以直接对数据库实例face_dataset进行下标操作,但这样只能够每次获取一条数据。

    sample = face_dataset[index]
  • 相关阅读:
    Lucene 基础理论
    .NET Micro Framework V4.2 QFE2新版本简介
    FlashPaper
    在django中实现QQ登录
    基于lucene的搜索服务器
    ASP.NET MVC的Razor引擎:RazorViewEngine
    .Net Micro Framework
    关于基于DDD+Event Sourcing设计的模型如何处理模型重构的问题的思考
    泛型
    Log4j源码分析及配置拓展
  • 原文地址:https://www.cnblogs.com/houjun/p/10214017.html
Copyright © 2011-2022 走看看