zoukankan      html  css  js  c++  java
  • pytorch实现批训练

    代码:

    #进行批训练
    import torch
    import torch.utils.data as Data
    
    BATCH_SIZE = 5  #每批5个数据
    
    if __name__ == '__main__':
        x = torch.linspace(1, 10, 10)  #x是从1到10共10个数据
        y = torch.linspace(10, 1, 10)  #y是从10到1共10个数据
    
        #torch_dataset = Data.TensorDataset(data_tensor = x, target_tensor=y)会报错
        torch_dataset = Data.TensorDataset(x,y)
        loader = Data.DataLoader(      #使我们的训练变成一小批一小批的
            dataset = torch_dataset,   #将所有数据放入dataset中
            batch_size= BATCH_SIZE,
            shuffle=True,              #true训练的时候随机打乱数据,false不打乱
            num_workers=2,             #每次训练用两个线程或进程进行提取
        )   
    
        for epoch in range(3):
            for step, (batch_x, batch_y) in enumerate(loader):  #利用enumerate可以同时获得索引(step)和值
                print('Epoch:', epoch, '| Step:', step, '| batch_x:', 
                batch_x.numpy(), '| batch_y:', batch_y.numpy())
    
    
    过程中遇到了问题,问题及解决办法都在https://blog.csdn.net/thunderf/article/details/94733747
  • 相关阅读:
    sizeof与strlen的区别
    面试题46:求1+2+...+n
    opennebula 安装指定参数
    opennebula 开发记录
    virsh 查看hypervisor特性
    opennebula kvm日志
    Cgroup
    opennebula kvm 创建VM oned报错日志
    opennebula kvm 创建虚拟机错误
    golang hello
  • 原文地址:https://www.cnblogs.com/loyolh/p/12299891.html
Copyright © 2011-2022 走看看