zoukankan      html  css  js  c++  java
  • pytorch1.0批训练神经网络

    pytorch1.0批训练神经网络

    import torch
    import torch.utils.data as Data
    # Torch 中提供了一种帮助整理数据结构的工具, 叫做 DataLoader, 能用它来包装自己的数据, 进行批训练.
    torch.manual_seed(1)    # reproducible
    # 批训练的数据个数
    BATCH_SIZE = 5
    BATCH_SIZE = 8
    
    x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
    y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
    # DataLoader 是 torch 用来包装开发者自己的数据的工具.
    # 将自己的 (numpy array 或其他) 数据形式装换成 Tensor, 然后再放进这个包装器中.
    # 使用 DataLoader 的好处就是他们帮你有效地迭代数据
    
    # 先转换成 torch 能识别的 Dataset
    torch_dataset = Data.TensorDataset(x, y)  # torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
    # 把 dataset 放入 DataLoader
    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               # random shuffle for training     # 随机打乱数据--打乱比较好
        num_workers=2,              # subprocesses for loading data   # 多线程来读数据
    )
    
    
    def show_batch():
        for epoch in range(3):   # train entire dataset 3 times   # 训练所有/整套数据 3 次
            for step, (batch_x, batch_y) in enumerate(loader):  # for each training step  # 每一步 loader 释放一小批数据用来学习
                # train your data...  # 假设这里就是训练的代码块...
                print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                      batch_x.numpy(), '| batch y: ', batch_y.numpy())
    
    if __name__ == '__main__':
        show_batch()
    # BATCH_SIZE = 5
    '''
    Epoch: 0 | Step: 0 | batch x: [ 5. 7. 10. 3. 4.] | batch y: [6. 4. 1. 8. 7.]
    Epoch: 0 | Step: 1 | batch x: [2. 1. 8. 9. 6.] | batch y: [ 9. 10. 3. 2. 5.]
    Epoch: 1 | Step: 0 | batch x: [ 4. 6. 7. 10. 8.] | batch y: [7. 5. 4. 1. 3.]
    Epoch: 1 | Step: 1 | batch x: [5. 3. 2. 1. 9.] | batch y: [ 6. 8. 9. 10. 2.]
    Epoch: 2 | Step: 0 | batch x: [ 4. 2. 5. 6. 10.] | batch y: [7. 9. 6. 5. 1.]
    Epoch: 2 | Step: 1 | batch x: [3. 9. 1. 8. 7.] | batch y: [ 8. 2. 10. 3. 4.]
    '''
    # BATCH_SIZE = 8
    '''
    Epoch: 0 | Step: 0 | batch x: [ 5. 7. 10. 3. 4. 2. 1. 8.] | batch y: [ 6. 4. 1. 8. 7. 9. 10. 3.]
    Epoch: 0 | Step: 1 | batch x: [9. 6.] | batch y: [2. 5.]
    Epoch: 1 | Step: 0 | batch x: [ 4. 6. 7. 10. 8. 5. 3. 2.] | batch y: [7. 5. 4. 1. 3. 6. 8. 9.]
    Epoch: 1 | Step: 1 | batch x: [1. 9.] | batch y: [10. 2.]
    Epoch: 2 | Step: 0 | batch x: [ 4. 2. 5. 6. 10. 3. 9. 1.] | batch y: [ 7. 9. 6. 5. 1. 8. 2. 10.]
    Epoch: 2 | Step: 1 | batch x: [8. 7.] | batch y: [3. 4.]
    '''
  • 相关阅读:
    iOS 自动化测试踩坑(二):Appium 架构原理、环境命令、定位方式
    干货 | 掌握 Selenium 元素定位,解决 Web 自动化测试痛点
    代理技术哪家强?接口 Mock 测试首选 Charles
    浅谈MVC缓存
    PetaPoco 快速上手
    解释器模式(26)
    享元模式(25)
    中介者模式(24)
    职责链模式(23)
    命令模式(22)
  • 原文地址:https://www.cnblogs.com/jeshy/p/11200000.html
Copyright © 2011-2022 走看看