zoukankan      html  css  js  c++  java
  • pytorch 6 batch_train 批训练

    import torch
    import torch.utils.data as Data
    
    torch.manual_seed(1)    # reproducible
    
    # BATCH_SIZE = 5  
    BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
    
    x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
    y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
    
    torch_dataset = Data.TensorDataset(x, y)
    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=False,              # 设置不随机打乱数据 random shuffle for training
        num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
    )
    
    
    def show_batch():
        for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
            for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
                # train your data...
                print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                      batch_x.numpy(), '| batch y: ', batch_y.numpy())
    
    
    if __name__ == '__main__':
        show_batch()
    

    BATCH_SIZE = 8 , 所有数据利用三次

    Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
    Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
    Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
    Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
    Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
    Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
    

    END

  • 相关阅读:
    【redis】--安全
    【redis】-- 数据备份和恢复
    2018.2.8 cf
    寒假零碎的东西 不定时更新补充.......
    hdu 1018
    2018寒假acm训练计划
    UVAlive 7466
    母函数
    简单数学题(水的不能在水的题了)
    随便写写的搜索
  • 原文地址:https://www.cnblogs.com/yangzhaonan/p/10439839.html
Copyright © 2011-2022 走看看