zoukankan      html  css  js  c++  java
  • Pytorch checkpoint

    checkpoint一种用时间换空间的策略

    torch.utils.checkpoint.checkpoint(function*args**kwargs)

     

    为模型或模型的一部分设置Checkpoint 。

    检查点用计算换内存(节省内存)。 检查点部分并不保存中间激活值,而是在反向传播时重新计算它们。 它可以应用于模型的任何部分。

    具体而言,在前向传递中,function将以torch.no_grad()的方式运行,即不存储中间激活值。 相反,前向传递将保存输入元组和function参数。 在反向传播时,检索保存的输入和function参数,然后再次对函数进行正向计算,现在跟踪中间激活值,然后使用这些激活值计算梯度。

    (也即,检查点部分在前向计算时不存储中间量,等反向传播需要计算梯度时重新计算这些中间量)

    WARNING

    • 检查点不适用于torch.autograd.grad(),而仅适用于torch.autograd.backward()。
    • 如果反向传播过程中的函数调用与前向传播过程中的函数调用有任何的不同,例如由于某个全局变量,则检查点版本将不相等,并且很遗憾,它无法被检测到。

    Parameters

    function:

    描述模型或模型的一部分在前向传播中运行什么。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户通过(activation, hidden),则函数应正确使用第一个输入作为activation,第二个输入作为hidden。

    reserve_rng_state(bool, optional, default=True)

    在每个检查点期间省略存储和恢复RNG状态。

    args

    包含函数输入的元组(输入)

    Returns

    在*args(输入)上运行function得到的输出

    torch.utils.checkpoint.checkpoint_sequential(functionssegments*inputs**kwargs)

    用于在sequential model中设置检查点的辅助函数。

    sequential model按顺序执行模块/函数列表。因此,我们可以将这种模型划分为不同的段,并在每个段上检查点。除最后一个段外的所有段都将以torch.no_grad()方式运行,即不存储中间激活。将保存每个检查点段的输入部分,以便在反向传播中重新运行该段。

    See checkpoint() on how checkpointing works.

    Parameters

    functions:

    torch.nn.Sequential 或 依次运行的模块或函数(包含模型)的列表。

    segments:

    在模型中创建的块数

    *inputs

    作为函数输入的张量元组

    reserve_rng_state(bool, optional, default=True)

    在每个检查点期间省略存储和恢复RNG状态。

    Returns

    在* input上顺序运行函数得到的输出

    Example

    >>> model = nn.Sequential(...)
    >>> input_var = checkpoint_sequential(model, chunks, input_var)

    在DenseNet中为了解决GPU内存占用大的问题,就采用了这种策略缓解显存占用大的问题。

    下面是denselayer的细节:

     1 class _DenseLayer(nn.Sequential): # bottleneck + conv
     2     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
     3         super(_DenseLayer, self).__init__()
     4         self.add_module("norm1", nn.BatchNorm2d(num_input_features))
     5         self.add_module("relu1", nn.ReLU(inplace=True))
     6         self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate,
     7                                            kernel_size=1, stride=1, bias=False))
     8 
     9         self.add_module("norm2", nn.BatchNorm2d(bn_size*growth_rate))
    10         self.add_module("relu2", nn.ReLU(inplace=True))
    11         self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate,
    12                                            kernel_size=3, stride=1, padding=1, bias=False))
    13 
    14         self.drop_rate = drop_rate
    15         self.memory_efficient = memory_efficient
    16 
    17     def forward(self, *prev_features):
    18         bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
    19         if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
    20             bottleneck_output = cp.checkpoint(bn_function, *prev_features)
    21         else:
    22             bottleneck_output = bn_function(*prev_features)
    23         new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
    24         if self.drop_rate > 0:
    25             new_features = F.dropout(new_features, self.drop_rate, training=self.training)
    26         return new_features
  • 相关阅读:
    【LeetCode & 剑指offer刷题】动态规划与贪婪法题13:Coin Change(系列)
    【LeetCode & 剑指offer刷题】动态规划与贪婪法题12:Jump Game(系列)
    【LeetCode & 剑指offer刷题】动态规划与贪婪法题11:121. Best Time to Buy and Sell Stock(系列)
    【LeetCode & 剑指offer刷题】动态规划与贪婪法题10:Longest Increasing Subsequence
    linux安装rabbitmq
    微服务-springboot-读写分离(多数据源切换)
    微服务-springboot-rabbitmq:实现延时队列
    java-NIO-DatagramChannel(UDP)
    java-NIO-FileChannel(文件IO)
    java-NIO-概念
  • 原文地址:https://www.cnblogs.com/jiangkejie/p/13049684.html
Copyright © 2011-2022 走看看