zoukankan      html  css  js  c++  java
  • Pytorch中dataloader之enumerate与iter,tqdm

    dataloader本质上是一个可迭代对象,使用iter()访问,不能使用next()访问;

    使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问;

    也可以使用for inputs,labels in enumerate(dataloader)形式访问,但是enumerate和iter的区别是什么呢?暂时不明白。

    补充:

    如下代码形式调用enumerate(dataloader'train')每次都会读出一个batchsize的数据,可根据以下代码做一个测试。下面代码的前提是,数据集中总共包含245张图像,dataloader'train'设置时drop_last=True,其中batch_size=32,经过以下代码后输出的count为224(正好等于32*7),而多出来的245-224=21张图像不够一个batch因此被drop掉了。换句话说,enumerate(dataloader'train')会把dataloader'train'中的数据一个batch一个batch地取出来用于训练。也就是说,使用enumerate进行dataloader中的数据读取用于神经网络的训练是第一种数据读取方法,其基本形式即为for index, item in enumerate(dataloader'train'),其中item中0为数据,1为label.

    count=0
    for index, item in enumerate(dataloader['train']):
        count+=len(item[1].numpy)
    print(count)
    

    第二种读取dataloader中数据的方法是使用默认的iter函数,其基本样式可参照以下代码:

    for epoch in range(opt.begin_epoch,opt.end_epoch):
        iter=myDataLoader['train'].__iter__() #返回值iter是一个基本的迭代器
        batchNum=len(myDataLoader['train']) #返回batch的数量,如上应该等于7
    
        myNet.train() #使得我定义的网络进入训练模式
        
        for i in range(0,batchNum):
            batchData=iter.__next__() #读取一个batch的数据,batchsize=32时实际对应32张图像
            img=batchData[0].to(opt.device) #opt.device=cuda,即转移到GPU运行
            ...
    

    还有一个tqdm后续补充!!!

    tqdm是一个可以显示进度条的模块

    from tqdm import tqdm
    for item in tqdm(range(100)):
        # do something
    

    enumerate()函数是python的内置函数,可以同时遍历lt中元素及其索引,i是索引,item是lt中的元素,如图:

    from tqdm import tqdm
    lt=['a','b','c']
    for i, item in enumerate(lt):
        print(i,item)
    
    #输出结果如下:
    0 a
    1 b
    2 c
    

    tqdm和enumerate()结合

    from tqdm import tqdm
    lt=['a','b','c']
    for i,item in enumerate(tqdm(lt))
        print(i,item)
    

    以上关于tqdm部分内容转载自:https://blog.csdn.net/m0_37586991/article/details/89435193

    原文链接:https://www.daimajiaoliu.com/daima/479d31ca3100408




    如果这篇文章帮助到了你,你可以请作者喝一杯咖啡

  • 相关阅读:
    tomcat使用不同的jdk版本 liunx 装两个jdk
    接下来自己的研究对象
    钉钉小程序开发的所有坑
    java 在web应用中获取本地目录和服务器上的目录不一致的问题
    Python2.7更新pip:UnicodeDecodeError: 'ascii' codec can't decode byte 0xb7 in position 7: ordinal not in range(128)
    vue项目中禁止移动端双击放大,双手拉大放大的方法
    JZ56 删除链表中重复的结点
    JZ55 链表中环的入口结点
    JZ54 字符流中第一个不重复的字符
    JZ53 表示数值的字符串
  • 原文地址:https://www.cnblogs.com/sddai/p/14648966.html
Copyright © 2011-2022 走看看