  • PyTorch: batch_train

    import torch
    import torch.utils.data as Data

    torch.manual_seed(1)    # reproducible

    BATCH_SIZE = 5
    # BATCH_SIZE = 8        # with 10 samples, each epoch then yields batches of 8 and 2

    x = torch.linspace(1, 10, 10)       # x data: 1, 2, ..., 10 (torch tensor)
    y = torch.linspace(10, 1, 10)       # y data: 10, 9, ..., 1 (torch tensor)

    torch_dataset = Data.TensorDataset(x, y)
    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               # reshuffle the data at every epoch
        num_workers=2,              # subprocesses for loading data
    )


    def show_batch():
        for epoch in range(3):   # train entire dataset 3 times
            for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
                # train your data...
                print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                      batch_x.numpy(), '| batch y: ', batch_y.numpy())


    # The main guard is required when num_workers > 0 on platforms that
    # spawn worker processes (e.g. Windows).
    if __name__ == '__main__':
        show_batch()
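As a rough mental model (not PyTorch's actual implementation), a DataLoader with shuffle=True and the default drop_last=False behaves like drawing a fresh random permutation of the sample indices each epoch and slicing it into consecutive chunks of batch_size. The sketch below assumes plain Python lists in place of tensors; the helper name `iterate_minibatches` is hypothetical, and its seeded `random.Random` will not reproduce the exact order produced by `torch.manual_seed(1)`.

```python
import random
from typing import Iterator, List, Optional, Sequence, Tuple


def iterate_minibatches(
    xs: Sequence[float],
    ys: Sequence[float],
    batch_size: int,
    shuffle: bool = True,
    drop_last: bool = False,
    rng: Optional[random.Random] = None,
) -> Iterator[Tuple[List[float], List[float]]]:
    """Yield (batch_x, batch_y) pairs covering the dataset once (one epoch).

    Hypothetical helper that mimics DataLoader's batching behaviour;
    it is an illustration, not PyTorch's code.
    """
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    rng = rng or random.Random()
    indices = list(range(len(xs)))
    if shuffle:
        rng.shuffle(indices)  # a new order on every call, i.e. every epoch
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        if drop_last and len(chunk) < batch_size:
            break  # discard the incomplete final batch
        # The same indices pick from xs and ys, so each (x, y) pair stays aligned.
        yield [xs[i] for i in chunk], [ys[i] for i in chunk]


if __name__ == "__main__":
    x = [float(v) for v in range(1, 11)]      # 1.0 ... 10.0, like torch.linspace(1, 10, 10)
    y = [float(v) for v in range(10, 0, -1)]  # 10.0 ... 1.0, like torch.linspace(10, 1, 10)
    rng = random.Random(1)  # seeded for reproducibility of this sketch only
    for epoch in range(3):
        for step, (bx, by) in enumerate(iterate_minibatches(x, y, 5, rng=rng)):
            print("Epoch:", epoch, "| Step:", step, "| batch x:", bx, "| batch y:", by)
```

With 10 samples and batch_size=8 (the commented-out alternative in the original listing), each epoch yields one batch of 8 and one of 2; passing drop_last=True (DataLoader accepts a parameter of the same name) keeps only the full batch.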
  • Original post: https://www.cnblogs.com/dhName/p/11742985.html