zoukankan      html  css  js  c++  java
  • 【转载】 浅谈PyTorch的可重复性问题(如何使实验结果可复现)

    原文地址:

    https://www.zhangshengrong.com/p/9MNlDK09NJ/

     ================================================

    由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

    许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

    全部设置可以分为三部分:

     1. CUDNN

    cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

    from torch.backends import cudnn
    cudnn.benchmark = False      # if benchmark=True, deterministic will be False
    cudnn.deterministic = True

    不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

    2. Pytorch

    torch.manual_seed(seed)      # 为CPU设置随机种子
    torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
    torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

    3. Python & Numpy

    如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

    import random
    import numpy as np
    random.seed(seed)
    np.random.seed(seed)

    最后,关于dataloader:

    注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

    对于不同线程的随机数种子设置,主要通过DataLoader的 worker_init_fn 参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

    GLOBAL_SEED = 1
     
    def set_seed(seed):
      random.seed(seed)
      np.random.seed(seed)
      torch.manual_seed(seed)
      torch.cuda.manual_seed(seed)
      torch.cuda.manual_seed_all(seed)
     
    GLOBAL_WORKER_ID = None
    def worker_init_fn(worker_id):
      global GLOBAL_WORKER_ID
      GLOBAL_WORKER_ID = worker_id
      set_seed(GLOBAL_SEED + worker_id)
     
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

    以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

    =================================================

    本博客是博主个人学习时的一些记录,不保证是为原创,个别文章加入了转载的源地址还有个别文章是汇总网上多份资料所成,在这之中也必有疏漏未加标注者,如有侵权请与博主联系。
  • 相关阅读:
    [Z] Windows 8/10 audio编程
    [Z]The Boost C++ Libraries
    [Z] windows进程在32、64位系统里用户和系统空间的地址范围
    [Z] 关于c++ typename的另一种用法
    [z] 人工智能和图形学、图像处理方面的各种会议的评级
    [Z] 计算机类会议期刊根据引用数排名
    关于windows的service编程
    关于Linux session管理与GUI架构
    搭建框架-ECS.ECommerce
    不调用构造函数而创建一个类型实例
  • 原文地址:https://www.cnblogs.com/devilmaycry812839668/p/14693658.html
Copyright © 2011-2022 走看看