zoukankan      html  css  js  c++  java
  • pytorch批训练数据构造

    这是对莫凡python的学习笔记。

    1.创建数据

    import torch
    import torch.utils.data as Data
    
    BATCH_SIZE = 8
    x = torch.linspace(1,10,10)
    y = torch.linspace(10,1,10)

    可以看到创建了两个一维数据,x:1~10,y:10~1

    2.构造数据集对象,及数据加载器对象

    torch_dataset = Data.TensorDataset(x,y)
    loader = Data.DataLoader(
                dataset = torch_dataset,
                batch_size = BATCH_SIZE,
                shuffle = False,
                num_workers = 2)

    num_workers应该指的是多线程

    3.输出数据集,这一步主要是看一下batch长什么样子

    for epoch in range(3):
        for step, (batch_x, batch_y) in  enumerate(loader):
            print('Epoch:',epoch,'| Step:', step, '| batch x:',
                     batch_x.numpy(), '| batch y:', batch_y.numpy())

    输出如下

    ('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
    ('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
    ('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
    ('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
    ('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
    ('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))

    可以看到,batch_size等于8,则第二个bacth的数据只有两个。

    将batch_size改为5,输出如下

    ('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.], dtype=float32))
    ('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 6.,  7.,  8.,  9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
    ('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.], dtype=float32))
    ('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 6.,  7.,  8.,  9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
    ('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.], dtype=float32))
    ('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 6.,  7.,  8.,  9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
  • 相关阅读:
    dd if=/dev/zero of=/dev/null 使用
    Linux 下的dd命令使用详解以及dd if=/dev/zero of=的含义
    windows 以及 linux 查看时间
    Linux下vi命令大全(文件修改)
    python test online
    python ssh登录下载上传脚本
    python telnet 中的数据判断(或者执行cmd后返回的数据 OperatingSystem.Run)
    python 转化串口中的数据 ,并分组判断
    python cmd下关闭exe程序(关闭浏览器驱动)
    robot 网卡连接情况
  • 原文地址:https://www.cnblogs.com/wzyuan/p/9459744.html
Copyright © 2011-2022 走看看