  • PyTorch batch training (batch_train)

    import torch
    import torch.utils.data as Data

    torch.manual_seed(1)    # reproducible

    BATCH_SIZE = 5
    # BATCH_SIZE = 8

    x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
    y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)

    torch_dataset = Data.TensorDataset(x, y)
    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               # random shuffle for training
        num_workers=2,              # subprocesses for loading data
    )


    def show_batch():
        for epoch in range(3):   # train entire dataset 3 times
            for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
                # train your data...
                print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                      batch_x.numpy(), '| batch y: ', batch_y.numpy())


    if __name__ == '__main__':
        show_batch()
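    The commented-out BATCH_SIZE = 8 line hints at a detail worth seeing directly: when the batch size does not evenly divide the dataset (here, 10 samples), DataLoader emits a short final batch unless you pass drop_last=True. A minimal sketch of both behaviors (num_workers=0 here only to keep the demo single-process):

    ```python
    import torch
    import torch.utils.data as Data

    x = torch.linspace(1, 10, 10)
    y = torch.linspace(10, 1, 10)
    dataset = Data.TensorDataset(x, y)

    # 8 does not divide 10, so the last batch holds the remaining 2 samples
    loader = Data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
    sizes = [len(batch_x) for batch_x, batch_y in loader]
    print(sizes)   # [8, 2]

    # drop_last=True silently discards the incomplete final batch
    loader2 = Data.DataLoader(dataset, batch_size=8, shuffle=True,
                              num_workers=0, drop_last=True)
    sizes2 = [len(batch_x) for batch_x, batch_y in loader2]
    print(sizes2)  # [8]
    ```

    Note that shuffle=True only permutes which samples land in which batch; the batch sizes themselves are deterministic.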
  • Original source: https://www.cnblogs.com/dhName/p/11742985.html