zoukankan      html  css  js  c++  java
  • (十)pytorch多线程训练,DataLoader的num_works参数设置

    一、概述

    数据集较小时(小于2W)建议num_works不用管默认就行,因为用了反而比没用慢。
    当数据集较大时建议采用,num_works一般设置为(CPU线程数+-1)为最佳,可以用以下代码找出最佳num_works(注意windows用户如果要使用多核多线程必须把训练放在if __name__ == '__main__':下才不会报错)

    二、代码

    import time
    import torch.utils.data as d
    import torchvision
    import torchvision.transforms as transforms
     
     
    if __name__ == '__main__':
        BATCH_SIZE = 100
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (0.5,))])
        train_set = torchvision.datasets.MNIST('mnist', download=False, train=True, transform=transform)
        
        # data loaders
        train_loader = d.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
        
        for num_workers in range(20):
            train_loader = d.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
            # training ...
            start = time.time()
            for epoch in range(1):
                for step, (batch_x, batch_y) in enumerate(train_loader):
                    pass
            end = time.time()
            print('num_workers is {} and it took {} seconds'.format(num_workers, end - start))

     三、查看线程数

    1、cpu个数

    grep 'physical id' /proc/cpuinfo | sort -u

    2、核心数

    grep 'core id' /proc/cpuinfo | sort -u | wc -l

    3、线程数

    grep 'processor' /proc/cpuinfo | sort -u | wc -l

    4、例子

    命令执行结果如图所示,根据结果得知,此服务器有1个cpu,6个核心,每个核心2线程,共12线程。

  • 相关阅读:
    spring自动装配的歧义性
    spring装配bean
    spring面向切面编程理解
    spring入门实现打印Hello Spring!
    spring依赖注入的理解
    java中数组和集合的区别
    java中final关键字的作用
    什么是视图?
    什么是事务?
    sql多表查询的总结
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/15079209.html
Copyright © 2011-2022 走看看