  • Checking PyTorch performance bottlenecks

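    I was surprised to find that PyTorch ships with its own bottleneck-checking tool, torch.utils.bottleneck. Usage:
    python -m torch.utils.bottleneck <path to the script to profile> [script args]
    For example, if you normally run python train.py, run python -m torch.utils.bottleneck train.py instead.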
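    To show what the cProfile half of the report corresponds to, here is a minimal standard-library sketch, assuming you only want the Python-level profile. It runs a script as __main__ under cProfile and prints the 15 entries with the largest internal time, which mirrors the "Ordered by: internal time" and "restriction <15>" lines in the report below. profile_script is a hypothetical helper name; this is not how torch.utils.bottleneck is implemented, and it does not reproduce the environment summary or the autograd profiler pass.

```python
import cProfile
import pstats
import runpy
import sys


def profile_script(path, args=(), limit=15):
    """Hypothetical helper: run the script at `path` as __main__ under cProfile
    and print the `limit` entries with the highest internal time (tottime)."""
    saved_argv = sys.argv
    sys.argv = [path, *args]  # the script sees its own argv, as on the command line
    profiler = cProfile.Profile()
    try:
        profiler.enable()
        runpy.run_path(path, run_name="__main__")
    finally:
        profiler.disable()
        sys.argv = saved_argv  # restore the caller's argv even if the script fails
    stats = pstats.Stats(profiler)
    # "tottime" is the "internal time" ordering used in the bottleneck report.
    stats.sort_stats("tottime").print_stats(limit)
    return stats


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage (hypothetical file name): python profile_script.py train.py [args...]
    profile_script(sys.argv[1], sys.argv[2:])
```

    Sorting by tottime rather than cumtime puts the function that does the work itself at the top, which is why a recursive hot spot like dfs shows up first in the report.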
    A test case (saved as model.py, the file profiled below; it uses only NumPy):

    import numpy as np
    
    class RNN(object):
        def __init__(self, input_size, hidden_size, output_size):
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.output_size = output_size
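            self.Wx = np.random.randn(self.input_size, self.hidden_size)  # input -> hidden weights
            self.Wh = np.random.randn(self.hidden_size, self.hidden_size)  # hidden -> hidden (recurrent) weights
            self.W = np.random.randn(self.hidden_size, self.output_size)  # hidden -> output weights
            self.bh = np.zeros((1, self.hidden_size))  # hidden bias
            self.b = np.zeros((1, self.output_size))  # output bias
            self.h = np.zeros((1, self.hidden_size))  # hidden state, carried across forward() calls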
        def forward(self, x):
            self.h = np.tanh(np.dot(x, self.Wx) + np.dot(self.h, self.Wh) + self.bh)
            y = np.dot(self.h, self.W) + self.b
            return y
    
    # Fibonacci sequence: naive exponential-time recursion, the deliberate bottleneck
    def dfs(x):
        if x <= 1:
            return x
        return dfs(x-1) + dfs(x-2)
    
    
    if __name__ == '__main__':
        rnn = RNN(input_size=2, hidden_size=2, output_size=2)
        x = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
        y = rnn.forward(x)
        tmp = dfs(30)  # expensive call that the profiler should flag
        print(tmp)
        print(y)
    

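    Profiling results. bottleneck runs the script twice, once under cProfile and once under the autograd profiler, so the script's output appears twice; the RNN weights come from np.random.randn without a fixed seed, so the two runs print different y values: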

    (deeplearning) ➜  268week python -m torch.utils.bottleneck model.py
    `bottleneck` is a tool that can be used as an initial step for debugging
    bottlenecks in your program.
    
    It summarizes runs of your script with the Python profiler and PyTorch's
    autograd profiler. Because your script will be profiled, please ensure that it
    exits in a finite amount of time.
    
    For more complicated uses of the profilers, please see
    https://docs.python.org/3/library/profile.html and
    https://pytorch.org/docs/master/autograd.html#profiler for more information.
    Running environment analysis...
    Running your script with cProfile
    832040
    [[-0.90337846 -0.35234026]
     [-0.90253927 -0.52576657]
     [-0.88766077 -0.52643475]
     [-0.87231183 -0.52685616]]
    Running your script with the autograd profiler...
    832040
    [[0.95950176 0.46764053]
     [2.72475046 1.42962624]
     [2.88909968 1.66971284]
     [2.90319655 1.70963283]]
    --------------------------------------------------------------------------------
      Environment Summary
    --------------------------------------------------------------------------------
    PyTorch 1.9.0 DEBUG not compiled w/ CUDA
    Running with Python 3.8 and 
    
    `pip3 list` truncated output:
    numpy==1.20.1
    torch==1.9.0
    torch-cluster==1.5.9
    torch-geometric==2.0.2
    torch-scatter==2.0.9
    torch-sparse==0.6.12
    torch-spline-conv==1.2.1
    torch-summary==1.4.5
    torch-tb-profiler==0.3.1
    torchaudio==0.9.0a0+33b2469
    torchfile==0.1.0
    torchtext==0.10.0
    torchvision==0.10.0
    --------------------------------------------------------------------------------
      cProfile output
    --------------------------------------------------------------------------------
             2692826 function calls (278 primitive calls) in 0.416 seconds
    
       Ordered by: internal time
       List reduced from 69 to 15 due to restriction <15>
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    2692537/1    0.414    0.000    0.414    0.414 model.py:20(dfs)
            5    0.001    0.000    0.001    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
            1    0.000    0.000    0.000    0.000 /Users/rogn/opt/anaconda3/envs/deeplearning/lib/python3.8/site-packages/numpy/core/arrayprint.py:890(fillFormat)
            3    0.000    0.000    0.000    0.000 {method 'randn' of 'numpy.random.mtrand.RandomState' objects}
            2    0.000    0.000    0.001    0.000 {built-in method builtins.print}
            1    0.000    0.000    0.416    0.416 model.py:1(<module>)
            1    0.000    0.000    0.001    0.001 model.py:14(forward)
         13/1    0.000    0.000    0.000    0.000 /Users/rogn/opt/anaconda3/envs/deeplearning/lib/python3.8/site-packages/numpy/core/arrayprint.py:745(recurser)
           16    0.000    0.000    0.000    0.000 {built-in method numpy.core._multiarray_umath.dragon4_positional}
            1    0.000    0.000    0.000    0.000 {built-in method builtins.locals}
            1    0.000    0.000    0.001    0.001 /Users/rogn/opt/anaconda3/envs/deeplearning/lib/python3.8/site-packages/numpy/core/arrayprint.py:1500(_array_str_implementation)
            1    0.000    0.000    0.001    0.001 /Users/rogn/opt/anaconda3/envs/deeplearning/lib/python3.8/site-packages/numpy/core/arrayprint.py:409(_get_format_function)
            1    0.000    0.000    0.001    0.001 /Users/rogn/opt/anaconda3/envs/deeplearning/lib/python3.8/site-packages/numpy/core/arrayprint.py:516(array2string)
            1    0.000    0.000    0.000    0.000 model.py:4(__init__)
            1    0.000    0.000    0.000    0.000 /Users/rogn/opt/anaconda3/envs/deeplearning/lib/python3.8/site-packages/numpy/core/arrayprint.py:366(<lambda>)
    
    
    --------------------------------------------------------------------------------
      autograd profiler output (CPU mode)
    --------------------------------------------------------------------------------
            top 15 events sorted by cpu_time_total
    

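    Clearly, the bottleneck of this program is the dfs function at model.py:20: it accounts for 0.414 s of the 0.416 s total and 2,692,537 of the 2,692,826 function calls. The autograd profiler section is empty because the script performs no PyTorch operations.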
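    The ncalls entry 2692537/1 for dfs can be checked by hand. A call dfs(x) with x >= 2 counts itself plus the calls made by dfs(x-1) and dfs(x-2), so C(x) = C(x-1) + C(x-2) + 1 with C(0) = C(1) = 1, which works out to C(30) = 2*F(31) - 1 = 2*1346269 - 1 = 2692537; the "/1" means only the outermost call was primitive (non-recursive). The sketch below verifies that count and shows memoization with functools.lru_cache as one possible way to remove this bottleneck. The fix is an illustration, not part of the original test case, and count_calls and dfs_memo are hypothetical names.

```python
from functools import lru_cache


def dfs(x):
    # Naive recursive Fibonacci, same as in model.py.
    if x <= 1:
        return x
    return dfs(x - 1) + dfs(x - 2)


def count_calls(n):
    # Number of dfs calls triggered by dfs(n):
    # C(n) = C(n-1) + C(n-2) + 1, with C(0) = C(1) = 1.
    if n <= 1:
        return 1
    a, b = 1, 1  # C(0), C(1)
    for _ in range(n - 1):
        a, b = b, a + b + 1
    return b


@lru_cache(maxsize=None)
def dfs_memo(x):
    # Same recurrence, but each distinct x is computed only once,
    # so the work grows linearly in x instead of exponentially.
    if x <= 1:
        return x
    return dfs_memo(x - 1) + dfs_memo(x - 2)


print(count_calls(30))  # 2692537, matching the ncalls column for dfs
print(dfs_memo(30))     # 832040, the same value the script prints
```

    Re-running bottleneck after such a change should move dfs out of the top of the cProfile table, since the exponential number of calls is what dominates the 0.416 s run.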

  • Original article: https://www.cnblogs.com/lfri/p/15586649.html