zoukankan      html  css  js  c++  java
  • 显存计算与优化

    参考链接

    如何计算模型以及中间变量的显存占用大小:

    https://oldpan.me/archives/how-to-calculate-gpu-memory

    如何在Pytorch中精细化利用显存:

    https://oldpan.me/archives/how-to-use-memory-pytorch

    torchsummary库打印信息:
    https://blog.csdn.net/andyL_05/article/details/109266862

    计算方式

    模型权重及中间变量显存占用计算:

    # Estimate memory taken by a model's weights and its intermediate outputs.
    # model: the model to inspect
    # input: a Tensor shaped like the real input the model will receive
    # type_size: bytes per element; defaults to 4 (float32)

    def modelsize(model, input, type_size=4):
        """Print (and return) estimated parameter and activation memory.

        The activation estimate pushes a copy of ``input`` through the model's
        leaf submodules one after another, in ``model.modules()`` order, and
        sums the number of elements of every output.

        NOTE: this chaining is only meaningful for purely sequential models
        (e.g. ``nn.Sequential`` stacks, possibly nested). Branches, skip
        connections and functional ops inside a custom ``forward`` are not
        modelled. A model that has no submodules is treated as its own single
        layer.

        Args:
            model: the ``nn.Module`` to measure.
            input: example input tensor; it is copied and never modified.
            type_size: bytes per element (4 for float32).

        Returns:
            ``(param_mb, activation_mb)``: parameter memory and forward-only
            activation memory, in MB (1 MB = 1000 * 1000 bytes).
        """
        import torch  # function-scope import: needed for no_grad and nn.ReLU

        param_count = sum(p.numel() for p in model.parameters())
        param_mb = param_count * type_size / 1000 / 1000
        print('Model {} : params: {:4f}M'.format(model._get_name(), param_mb))

        # detach() first: calling requires_grad_(False) on a clone of a tensor
        # that requires grad raises, because the clone is not a leaf.
        input_ = input.detach().clone()

        # Only leaf modules do real work; containers such as nn.Sequential just
        # call their children, so applying both would run those layers twice.
        leaves = [m for m in model.modules() if next(m.children(), None) is None]

        total_nums = 0
        # no_grad: the probe passes must not build an autograd graph (the
        # parameters require grad, so outputs would otherwise keep one alive).
        with torch.no_grad():
            for m in leaves:
                if isinstance(m, torch.nn.ReLU) and m.inplace:
                    # an in-place ReLU reuses its input buffer: no new output
                    continue
                input_ = m(input_)
                total_nums += input_.numel()

        activation_mb = total_nums * type_size / 1000 / 1000
        print('Model {} : intermedite variables: {:3f} M (without backward)'
              .format(model._get_name(), activation_mb))
        # rough rule of thumb: backward is assumed to need about as much again
        print('Model {} : intermedite variables: {:3f} M (with backward)'
              .format(model._get_name(), total_nums * type_size*2 / 1000 / 1000))
        return param_mb, activation_mb
    

    显存占用追踪工具(以下代码保存为 gpu_mem_track.py,后文的追踪检测示例通过 from gpu_mem_track import MemTracker 导入):

    https://github.com/Oldpan/Pytorch-Memory-Utils

    import gc
    import datetime
    import pynvml
    
    import torch
    import numpy as np
    
    
    class MemTracker(object):
        """
        Track PyTorch GPU memory usage and append a report to a log file.

        Each call to ``track()`` appends the total device memory reported by
        NVML and, when ``detail`` is enabled, the CUDA tensors that appeared
        ("+") or disappeared ("-") since the previous call.

        Arguments:
            frame: the caller's frame (e.g. ``inspect.currentframe()``); its
                current line number is recorded in every report
            detail(bool, default True): whether to list per-tensor changes
            path(str): prefix prepended to the log-file name (e.g. a directory
                ending in a path separator)
            verbose(bool, default False): whether to print the trivial
                exceptions raised while scanning objects for tensors
            device(int): GPU index passed to NVML, default is 0
        """

        def __init__(self, frame, detail=True, path='', verbose=False, device=0):
            self.frame = frame
            self.print_detail = detail
            # {(type, shape, count, MB per tensor)} recorded by the last track()
            self.last_tensor_sizes = set()
            # NOTE(review): ':' in the timestamp is not a legal filename
            # character on Windows.
            self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt'
            self.verbose = verbose
            # True until the header line has been written to the log
            self.begin = True
            self.device = device

            self.func_name = frame.f_code.co_name
            # .get(): interactive sessions / notebooks have no __file__
            self.filename = frame.f_globals.get("__file__") or ""
            if self.filename.endswith((".pyc", ".pyo")):
                self.filename = self.filename[:-1]
            self.module_name = self.frame.f_globals["__name__"]
            self.curr_line = self.frame.f_lineno

        def get_tensors(self):
            """Yield every CUDA tensor reachable through the garbage collector.

            Objects that are not tensors but expose a tensor ``.data``
            attribute contribute that inner tensor. Exceptions raised while
            probing an object are skipped (printed when ``verbose`` is set).
            """
            for obj in gc.get_objects():
                try:
                    if torch.is_tensor(obj):
                        tensor = obj
                    elif hasattr(obj, 'data') and torch.is_tensor(obj.data):
                        # report the wrapped tensor: it has the .size()/.numel()
                        # API that _summarize() relies on
                        tensor = obj.data
                    else:
                        continue
                    is_cuda = tensor.is_cuda
                except Exception as e:
                    if self.verbose:
                        print('A trivial exception occured: {}'.format(e))
                    continue
                if is_cuda:
                    yield tensor

        @staticmethod
        def _summarize(tensors):
            """Group tensors into ``{(type, shape, count, MB per tensor)}``.

            Memory uses each tensor's real element size, so non-float32
            tensors (float16, int64, ...) are reported correctly.
            """
            from collections import Counter  # local: keeps module imports unchanged

            counts = Counter(
                (type(x), tuple(x.size()), x.numel() * x.element_size() / 1000 ** 2)
                for x in tensors)
            return {(t, s, n, m) for (t, s, m), n in counts.items()}

        @staticmethod
        def _format_change(sign, t, s, n, m):
            """Format one "+"/"-" log line for ``n`` tensors of ``m`` MB each."""
            # .6g keeps the exponent of tiny sizes (str()[:6] used to chop it)
            return f'{sign} | {n} * Size:{str(s):<20} | Memory: {m * n:.6g} M | {str(t):<20}\n'

        def track(self):
            """
            Append the current GPU memory usage to the log file.
            """
            pynvml.nvmlInit()
            try:
                handle = pynvml.nvmlDeviceGetHandleByIndex(self.device)
                meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
            finally:
                # always release NVML, even when the query fails
                pynvml.nvmlShutdown()

            self.curr_line = self.frame.f_lineno
            where_str = self.module_name + ' ' + self.func_name + ':' + ' line ' + str(self.curr_line)
            used_mb = meminfo.used / 1000 ** 2

            with open(self.gpu_profile_fn, 'a+', encoding='utf-8') as f:
                if self.begin:
                    f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
                            f" Total Used Memory:{used_mb:<7.1f}Mb\n\n")
                    self.begin = False

                if self.print_detail:
                    # one gc scan per call; both diffs use the same snapshot
                    new_tensor_sizes = self._summarize(self.get_tensors())
                    for t, s, n, m in new_tensor_sizes - self.last_tensor_sizes:
                        f.write(self._format_change('+', t, s, n, m))
                    for t, s, n, m in self.last_tensor_sizes - new_tensor_sizes:
                        f.write(self._format_change('-', t, s, n, m))
                    self.last_tensor_sizes = new_tensor_sizes

                f.write(f"\nAt {where_str:<50}"
                        f"Total Used Memory:{used_mb:<7.1f}Mb\n\n")
    

    追踪检测

    import torch
    import inspect
    
    from torchvision import models
    from gpu_mem_track import MemTracker  # import the GPU-memory tracking code
    
    device = torch.device('cuda:0')
    
    frame = inspect.currentframe()       # this module's frame; MemTracker logs its current line on each track()
    gpu_tracker = MemTracker(frame)      # create the memory-tracker object
    
    gpu_tracker.track()                  # first snapshot (before the model is loaded)
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of `weights=`; confirm against the installed version.
    cnn = models.vgg19(pretrained=True).to(device)  # load VGG19 and move its weights to GPU memory
    gpu_tracker.track()                  # second snapshot: the diff shows the model's tensors
    
  • 相关阅读:
    绳关节(b2RopeJoint)
    公式推导 圆面积公式 圆周长公式
    ALTER TABLE causes auto_increment resulting key 'PRIMARY'
    MySQL通过Binlog恢复删除的表
    Barracuda VS antelope
    mysqldump 参数说明
    MySQL复制
    MySQL Server-id的作用
    MySQL 简洁连接数据库方式
    Kill 所有MySQL进程
  • 原文地址:https://www.cnblogs.com/xiaxuexiaoab/p/14467430.html
Copyright © 2011-2022 走看看