zoukankan      html  css  js  c++  java
  • pytorch批训练数据构造

    这是对莫凡python的学习笔记。

    1.创建数据

    import torch
    import torch.utils.data as Data
    
    BATCH_SIZE = 8
    x = torch.linspace(1,10,10)
    y = torch.linspace(10,1,10)

    可以看到创建了两个一维数据,x:1~10,y:10~1

    2.构造数据集对象,及数据加载器对象

    torch_dataset = Data.TensorDataset(x,y)
    loader = Data.DataLoader(
                dataset = torch_dataset,
                batch_size = BATCH_SIZE,
                shuffle = False,
                num_workers = 2)

    num_workers应该指的是多线程

    3.输出数据集,这一步主要是看一下batch长什么样子

    for epoch in range(3):
        for step, (batch_x, batch_y) in  enumerate(loader):
            print('Epoch:',epoch,'| Step:', step, '| batch x:',
                     batch_x.numpy(), '| batch y:', batch_y.numpy())

    输出如下

    ('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
    ('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
    ('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
    ('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
    ('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
    ('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))

    可以看到,batch_size等于8,则第二个bacth的数据只有两个。

    将batch_size改为5,输出如下

    ('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.], dtype=float32))
    ('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 6.,  7.,  8.,  9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
    ('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.], dtype=float32))
    ('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 6.,  7.,  8.,  9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
    ('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.], dtype=float32))
    ('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 6.,  7.,  8.,  9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
  • 相关阅读:
    JavaScript高级程序设计
    昨天听了林某的建议,开了自己的博客
    Unity是什么?
    依赖注入
    NHibernate 01 [简述]
    C#Delegate.Invoke、Delegate.BeginInvoke And Control.Invoke、Control.BeginInvoke
    C#调用http请求,HttpWebRequest添加http请求头信息
    JUnit入门笔记
    Spring:利用ApplicationContextAware装配Bean
    Java线程安全synchronize学习
  • 原文地址:https://www.cnblogs.com/wzyuan/p/9459744.html
Copyright © 2011-2022 走看看