zoukankan      html  css  js  c++  java
  • Pytorch基础(5)——批数据训练

    一、知识点:

    • 相关包:torch.utils.data

    import torch
    import torch.utils.data as Data
    • 包装数据类:TensorDataset

    【包装数据和目标张量的数据集,通过沿着第一个维度索引两个张量来】

    class torch.utils.data.TensorDataset(data_tensor, target_tensor)
    #data_tensor (Tensor) - 包含样本数据
    #target_tensor (Tensor) - 包含样本目标(标签)

     

    • 加载数据类:DataLoader

    【数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。】

    class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
    #num_workers (int, optional) – 用多少个子进程加载数据
    #drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

    二、利用torch.utils.data进行批数据训练:

    导入包:

    import torch
    import torch.utils.data as Data

    设置参数并创建数据:

    Batch_size = 5
    
    x = torch.linspace(1,10,10)
    y = torch.linspace(10,1,10)

    将数据包装到TensorDataset中:

    torch_dataset = Data.TensorDataset(x , y)

    加载数据:

    loader = Data.DataLoader(
        dataset = torch_dataset,
        batch_size = Batch_size,
        shuffle=True,
        num_workers = 2,  #采用两个进程来提取
    )

    epoch 3次,每次epoch的训练步数steps = 2【batch_size = 5,总数据量为10】:

    若最后不够一个batch_size,就只拿剩下的。

    for epoch in range(3):
        for step , (batch_x,batch_y) in enumerate(loader):
            #training……
            print('epoch:',epoch,
                  '| step:',step,
                  '| batch_x:',batch_x.numpy(),
                  '| batch_y:',batch_y.numpy()    
                 )

    结果:

  • 相关阅读:
    【题解】P1999 高维正方体
    【题解】 P1850 [NOIP2016 提高组] 换教室(又是一道debug的DP,debug经验++)
    【题解】P1439 【模板】最长公共子序列
    【笔记】还是发上来作为学习过的记录吧,凌乱,勿进
    为什么我不会做数位DP
    【题解】HUD3652 B-number && 数位DP学习笔记
    【题解】LIS(longest increasing subsequence)最长上升子序列
    lingo重点部分快速上手
    koa2转移json文件地址
    Koa2创建项目
  • 原文地址:https://www.cnblogs.com/Lee-yl/p/10139766.html
Copyright © 2011-2022 走看看