  • (10) Multi-process data loading in PyTorch: setting the DataLoader num_workers parameter

    I. Overview

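    For small datasets (fewer than about 20,000 samples), leave num_workers at its default (0, which loads data in the main process); starting extra workers tends to make loading slower rather than faster.
    For larger datasets, multiple workers are worth using. A num_workers value close to the number of CPU threads (give or take 1) usually works best, and the code below can be used to find the best value for a given machine. (Note for Windows users: code that uses multiple workers must be placed under if __name__ == '__main__': or it will raise an error.)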
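
    The rule of thumb above can be turned into a short list of values to try. The following is a minimal sketch and not part of the original post: the helper name candidate_num_workers is hypothetical, and it assumes that os.cpu_count() reports the number of logical CPUs (threads) on the machine.

```python
import os


def candidate_num_workers(cpu_threads=None):
    """Return num_workers values worth benchmarking around the CPU thread count.

    Hypothetical helper: 0 (the DataLoader default, single-process loading)
    plus the thread count and its neighbours (threads - 1, threads + 1).
    """
    if cpu_threads is None:
        # os.cpu_count() can return None when the count is undetermined.
        cpu_threads = os.cpu_count() or 1
    around = {cpu_threads - 1, cpu_threads, cpu_threads + 1}
    # Drop negative values and always keep 0 as the baseline.
    return sorted({0} | {n for n in around if n >= 0})


if __name__ == "__main__":
    print(candidate_num_workers())
```

    On a 12-thread machine this yields 0, 11, 12 and 13 as the values to time, which is a much shorter sweep than trying every value from 0 upward.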

    II. Code

    import time
    import torch.utils.data as d
    import torchvision
    import torchvision.transforms as transforms


    # The __main__ guard is required on Windows when num_workers > 0,
    # because worker processes re-import this module on start-up.
    if __name__ == '__main__':
        BATCH_SIZE = 100
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (0.5,))])
        # download=True fetches MNIST only if it is not already under 'mnist'.
        train_set = torchvision.datasets.MNIST('mnist', download=True, train=True, transform=transform)

        # Time one full pass over the data for each num_workers value (0 to 19).
        for num_workers in range(20):
            train_loader = d.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
            start = time.time()
            for epoch in range(1):
                for step, (batch_x, batch_y) in enumerate(train_loader):
                    pass  # a real training step would go here
            end = time.time()
            print('num_workers is {} and it took {} seconds'.format(num_workers, end - start))
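
    The timing loop above can be wrapped in a small helper so that the fastest setting is picked automatically. This is a sketch under stated assumptions: time_one_pass, benchmark and make_loader are hypothetical names, and make_loader(n) stands for whatever builds a loader with n workers, for example lambda n: d.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=n). The demo uses plain lists in place of a DataLoader so that it runs without PyTorch.

```python
import time


def time_one_pass(loader):
    """Iterate over the loader once and return the elapsed seconds."""
    start = time.perf_counter()
    for _ in loader:
        pass
    return time.perf_counter() - start


def benchmark(make_loader, worker_counts):
    """Time one pass per worker count; return (best_count, timings dict)."""
    timings = {}
    for n in worker_counts:
        timings[n] = time_one_pass(make_loader(n))
    # The best setting is the one with the smallest elapsed time.
    best = min(timings, key=timings.get)
    return best, timings


if __name__ == "__main__":
    # Stand-in loader (assumption): a list of 1000 dummy batches for every n.
    best, timings = benchmark(lambda n: [0] * 1000, range(4))
    for n, seconds in timings.items():
        print('num_workers is {} and it took {:.6f} seconds'.format(n, seconds))
    print('fastest:', best)
```

    A single pass per setting is noisy, so when two settings are close it may be worth repeating the measurement before settling on a value.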

    III. Checking the thread count

    1. Number of physical CPUs (one output line per CPU)

    grep 'physical id' /proc/cpuinfo | sort -u

    2. Number of cores

    grep 'core id' /proc/cpuinfo | sort -u | wc -l

    3. Number of threads

    grep 'processor' /proc/cpuinfo | sort -u | wc -l

    4. Example

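    In the author's example, the output of these commands shows that the server has 1 CPU with 6 cores and 2 threads per core, for a total of 12 threads.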
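
    The same three counts can also be derived in Python by parsing the text of /proc/cpuinfo. This is a minimal sketch, assuming the Linux x86 field names used by the grep commands above (processor, physical id, core id); summarize_cpuinfo is a hypothetical helper, and the sample text in the demo is made up for illustration.

```python
def summarize_cpuinfo(text):
    """Return (physical CPUs, cores, threads) from /proc/cpuinfo text."""
    sockets = set()   # distinct 'physical id' values
    cores = set()     # distinct (physical id, core id) pairs, across all CPUs
    threads = 0       # number of 'processor' entries
    physical_id = None
    for line in text.splitlines():
        key, sep, value = line.partition(':')
        if not sep:
            continue  # blank separator lines between processor entries
        key, value = key.strip(), value.strip()
        if key == 'processor':
            threads += 1
            physical_id = None  # a new processor entry starts here
        elif key == 'physical id':
            physical_id = value
            sockets.add(value)
        elif key == 'core id':
            cores.add((physical_id, value))
    return len(sockets), len(cores), threads


if __name__ == "__main__":
    # Made-up sample: one CPU, two cores, two threads per core.
    sample = "\n\n".join(
        "processor\t: {}\nphysical id\t: 0\ncore id\t\t: {}".format(p, p % 2)
        for p in range(4)
    )
    print(summarize_cpuinfo(sample))  # prints (1, 2, 4)
```

    On a Linux machine the helper can be fed the real file contents, for example summarize_cpuinfo(open('/proc/cpuinfo').read()). Unlike the grep for core id, it counts cores across all physical CPUs rather than distinct core ids only.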

  • Original article: https://www.cnblogs.com/zhangxianrong/p/15079209.html