zoukankan      html  css  js  c++  java
  • PyTorch-批量训练技巧

    来自:https://morvanzhou.github.io/tutorials/machine-learning/torch/3-05-train-on-batch/ 

    import torch
    import torch.utils.data as Data
    
    torch.manual_seed(1)
    
    BATCH_SIZE = 8  # 批训练的数据个数
    
    x = torch.linspace(1, 10, 10)  # x data (torch tensor)
    y = torch.linspace(10, 1, 10)  # y data (torch tensor)
    
    # 先转换成 torch 能识别的 Dataset
    torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
    
    # 把 dataset 放入 DataLoader
    loader = Data.DataLoader(
        dataset=torch_dataset,  # torch TensorDataset format
        batch_size=BATCH_SIZE,  # mini batch size
        shuffle=True,  # 要不要打乱数据 (打乱比较好)
        num_workers=2,  # 多线程来读数据
    )
    
    for epoch in range(3):  # 训练所有!整套!数据 3 次
        for step, (batch_x, batch_y) in enumerate(loader):  # 每一步 loader 释放一小批数据用来学习
            # 假设这里就是你训练的地方...
    
            # 打出来一些数据
            print('批次Epoch: ', epoch, '| Step: ', step, '| 数据batch x: ',
                  batch_x.numpy(), '| y: ', batch_y.numpy())
    

    批次Epoch:  0 | Step:  0 | 数据batch x:  [  6.   7.   2.   3.   1.   9.  10.   4.] | y:  [  5.   4.   9.   8.  10.   2.   1.   7.]
    批次Epoch:  0 | Step:  1 | 数据batch x:  [ 8.  5.] | y:  [ 3.  6.]
    批次Epoch:  1 | Step:  0 | 数据batch x:  [  3.   4.   2.   9.  10.   1.   7.   8.] | y:  [  8.   7.   9.   2.   1.  10.   4.   3.]
    批次Epoch:  1 | Step:  1 | 数据batch x:  [ 5.  6.] | y:  [ 6.  5.]
    批次Epoch:  2 | Step:  0 | 数据batch x:  [  3.   9.   2.   6.   7.  10.   4.   8.] | y:  [ 8.  2.  9.  5.  4.  1.  7.  3.]
    批次Epoch:  2 | Step:  1 | 数据batch x:  [ 1.  5.] | y:  [ 10.   6.]

  • 相关阅读:
    elasticsearch7.1 安装启动报错
    jvm调优
    基于redis实现IP访问频次控制
    docker 搭建redis集群
    Tomcat安全配置与性能优化
    mybaties 的 applicationContext.xml
    SSH阶段常见错误及说明
    hibernate 7种映射关系
    (四)SpringBoot如何定义消息转换器
    java之package与import
  • 原文地址:https://www.cnblogs.com/onenoteone/p/12441711.html
Copyright © 2011-2022 走看看