zoukankan      html  css  js  c++  java
  • Pytorch基础(5)——批数据训练

    一、知识点:

    • 相关包:torch.utils.data

    import torch
    import torch.utils.data as Data
    • 包装数据类:TensorDataset

    【包装数据和目标张量的数据集,通过沿着第一个维度索引两个张量来】

    class torch.utils.data.TensorDataset(data_tensor, target_tensor)
    #data_tensor (Tensor) - 包含样本数据
    #target_tensor (Tensor) - 包含样本目标(标签)

     

    • 加载数据类:DataLoader

    【数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。】

    class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
    #num_workers (int, optional) – 用多少个子进程加载数据
    #drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

    二、利用torch.utils.data进行批数据训练:

    导入包:

    import torch
    import torch.utils.data as Data

    设置参数并创建数据:

    Batch_size = 5
    
    x = torch.linspace(1,10,10)
    y = torch.linspace(10,1,10)

    将数据包装到TensorDataset中:

    torch_dataset = Data.TensorDataset(x , y)

    加载数据:

    loader = Data.DataLoader(
        dataset = torch_dataset,
        batch_size = Batch_size,
        shuffle=True,
        num_workers = 2,  #采用两个进程来提取
    )

    epoch 3次,每次epoch的训练步数steps = 2【batch_size = 5,总数据量为10】:

    若最后不够一个batch_size,就只拿剩下的。

    for epoch in range(3):
        for step , (batch_x,batch_y) in enumerate(loader):
            #training……
            print('epoch:',epoch,
                  '| step:',step,
                  '| batch_x:',batch_x.numpy(),
                  '| batch_y:',batch_y.numpy()    
                 )

    结果:

  • 相关阅读:
    Mybatis中#{}和${}传参的区别
    笔记摘抄 —— shiro学习篇
    使用Spring的Testcase的单元测试的写法
    【转】FreeMarker学习笔记
    破解Pycharm,IDEA,PhpStrom等系列产品的,有关JetbrainsCrack的使用方法
    Python的字符串
    python的变量
    python开头注释
    h5-动画小案例-滚动展示
    h5-钟表动画案例
  • 原文地址:https://www.cnblogs.com/Lee-yl/p/10139766.html
Copyright © 2011-2022 走看看