zoukankan      html  css  js  c++  java
  • 【Pytorch】dataloader使用教程

    # -*- coding: utf-8 -*-
    """
    Created on Mon Aug  3 23:30:39 2020
    
    @author: Administrator
    """
    
    import torch                       # 导入模块
    import torch.utils.data as Data
    
    BATCH_SIZE = 8                     # 每一批的数据量
    
    x=torch.linspace(1,10,10)          # 定义X为 1 到 10 等距离大小的数
    y=torch.linspace(10,1,10)
    
    # 转换成torch能识别的Dataset
    # 这个可以自定义DataSet:https://www.cnblogs.com/douzujun/p/13429912.html
    torch_dataset = Data.TensorDataset(x, y) # 将数据放入 torch_dataset
    
    loader=Data.DataLoader(
            dataset=torch_dataset,           # 将数据放入loader
            batch_size=BATCH_SIZE,           # 每个数据段大小为  BATCH_SIZE=5
            shuffle=True ,                   # 是否打乱数据的排布
            num_workers=0                    # 使用多进程加载的进程数,0代表不使用多进程
            )
    
    for epoch in range(3):
        
        for step, (batch_x,batch_y) in enumerate(loader):
            
            print('epoch',epoch,'|step:',step," | batch_x",batch_x.numpy(),
    
                  '|batch_y:',batch_y.numpy())
    
    epoch 0 |step: 0  | batch_x [ 7.  3.  1.  8. 10.  9.  5.  4.] |batch_y: [ 4.  8. 10.  3.  1.  2.  6.  7.]
    epoch 0 |step: 1  | batch_x [2. 6.] |batch_y: [9. 5.]
    epoch 1 |step: 0  | batch_x [ 6.  7.  5.  4.  1. 10.  2.  9.] |batch_y: [ 5.  4.  6.  7. 10.  1.  9.  2.]
    epoch 1 |step: 1  | batch_x [3. 8.] |batch_y: [8. 3.]
    epoch 2 |step: 0  | batch_x [ 4.  5.  7.  1.  6.  9. 10.  3.] |batch_y: [ 7.  6.  4. 10.  5.  2.  1.  8.]
    epoch 2 |step: 1  | batch_x [8. 2.] |batch_y: [3. 9.]
    

    DataLoader的函数定义如下:

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
               num_workers=0, collate_fn=default_collate, pin_memory=False,
               drop_last=False)
    
    • dataset:加载的数据集(Dataset对象)

    • batch_size:batch size

    • shuffle::是否将数据打乱

    • sampler: 样本抽样,后续会详细介绍

    • num_workers:使用多进程加载的进程数,0代表不使用多进程

    • collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可

    • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些

    • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

  • 相关阅读:
    koa mog
    sdl
    基于WindowImplBase 更简单 以及 可变大小的,才是标准的
    df
    ffplay vc
    开源1bo
    react学习前一部分
    0514 react路由
    nodejs 调用进程
    Ubuntu Linux, 不要弄什么 wine,龙井 或者什么等 QQ 了。
  • 原文地址:https://www.cnblogs.com/douzujun/p/13427930.html
Copyright © 2011-2022 走看看