zoukankan      html  css  js  c++  java
  • 批量训练网络

    """
    如果整个数据库中图片的数量不是每批数据图片数量的整数倍,体统会将剩余的图片放入最后一批
    """
    import torch
    import torch.utils.data as Data
    
    torch.manual_seed(1)    # reproducible
    
    BATCH_SIZE = 5
    
    x = torch.linspace(1, 10, 10)
    y = torch.linspace(10, 1, 10)
    
    # 定义一个数据库
    torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
    
    # 将数据库中的数据变成多批数据集
    loader = Data.DataLoader(
        dataset=torch_dataset,      # TensorDataset数据集
        batch_size=BATCH_SIZE,      # 每批数据的大小
        shuffle=True,               # 打乱顺序
        num_workers=2,              # 加载数据的进程数
    )
    
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # 这个部分为训练数据
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
  • 相关阅读:
    第三章例3-3
    第三章例3-2
    第二章例2-11
    第二章例2-10
    第二章例2-9
    204
    205
    202
    203
    201
  • 原文地址:https://www.cnblogs.com/czz0508/p/10335699.html
Copyright © 2011-2022 走看看