-
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
-
num_works设置过高出错(多线程错误,使用gpu就没事了)
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 3 23:30:39 2020
@author: Administrator
"""
import torch # 导入模块
import torch.utils.data as Data
BATCH_SIZE = 8 # 每一批的数据量
x=torch.linspace(1,10,10) # 定义X为 1 到 10 等距离大小的数
y=torch.linspace(10,1,10)
# 转换成torch能识别的Dataset
# 这个可以自定义DataSet:https://www.cnblogs.com/douzujun/p/13429912.html
torch_dataset = Data.TensorDataset(x, y) # 将数据放入 torch_dataset
loader=Data.DataLoader(
dataset=torch_dataset, # 将数据放入loader
batch_size=BATCH_SIZE, # 每个数据段大小为 BATCH_SIZE=5
shuffle=True , # 是否打乱数据的排布
num_workers=0 # 使用多进程加载的进程数,0代表不使用多进程
)
for epoch in range(3):
for step, (batch_x,batch_y) in enumerate(loader):
print('epoch',epoch,'|step:',step," | batch_x",batch_x.numpy(),
'|batch_y:',batch_y.numpy())
epoch 0 |step: 0 | batch_x [ 7. 3. 1. 8. 10. 9. 5. 4.] |batch_y: [ 4. 8. 10. 3. 1. 2. 6. 7.]
epoch 0 |step: 1 | batch_x [2. 6.] |batch_y: [9. 5.]
epoch 1 |step: 0 | batch_x [ 6. 7. 5. 4. 1. 10. 2. 9.] |batch_y: [ 5. 4. 6. 7. 10. 1. 9. 2.]
epoch 1 |step: 1 | batch_x [3. 8.] |batch_y: [8. 3.]
epoch 2 |step: 0 | batch_x [ 4. 5. 7. 1. 6. 9. 10. 3.] |batch_y: [ 7. 6. 4. 10. 5. 2. 1. 8.]
epoch 2 |step: 1 | batch_x [8. 2.] |batch_y: [3. 9.]
DataLoader的函数定义如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False)
-
dataset:加载的数据集(Dataset对象)
-
batch_size:batch size
-
shuffle::是否将数据打乱
-
sampler: 样本抽样,后续会详细介绍
-
num_workers:使用多进程加载的进程数,0代表不使用多进程
-
collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
-
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
-
drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃