zoukankan      html  css  js  c++  java
  • pytorch训练常见问题

    Cuda out of memory

    1.在训练循环除非必要,不要形成积累历史记录的变量

    total_loss = 0
    for i in range(10000):
        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output)
        loss.backward()
        optimizer.step()
        total_loss += loss #pytorch中任何一个变量加上一个required_grad变量都会变成required_grad变量,这样每次反向传播都会增加内存占用
    #应该直接访问变量的底层数据
    #total_loss += float(loss)
    
    return accuracy
    # return accuracy.data[0] 
    

    2.不要保持一个不必要的张量或变量

    for i in range(5):
        intermediate = f(input[i])
        result += g(intermediate)
        # del intermediate  解决方法
    output = h(result)#这里计算时,intermediate依然存在,因为intermediate的作用域超出了循环部分。
    #对于分配给局部变量的变量或张量,除非超出了变量作用域,否则python不会主动回收这些内存
    return output
    

    3.RNN的BPTT问题---Backpropagation through time
    RNN中反向传播内存占用和RNN输入序列的长度成正比。因此如果喂给RNN一个太长的输入序列,内存会很快耗尽。

    4.不要用太大的线性层---线性层占用内存巨大。

    训练结果的复现---随机数产生器(RNG--Random Number Generators)的固定

    1.给cpu及GPU设置随机种子

    import torch
    torch.manual_seed(0)
    

    2.在cudnn后端上运行时,以下要设置

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    

    3.如果使用了numpy库依赖

    import numpy as np
    np.random.seed(0)
    

    模型序列化及保存--两种方法

    1.仅保存模型参数,然后仅加载模型参数(推荐

    torch.save(the_model.state_dict(), PATH)
    
    the_model = TheModelClass(*args, **kwargs)
    the_model.load_state_dict(torch.load(PATH)) #
    

    2.保存及加载整个模型(包括模型结构及参数)

    torch.save(the_model, PATH)
    
    the_model = torch.load(PATH)
    #这种情况下序列化数据绑定了特定的类及结构,因此在重构或应用到其他项目中时容易导致中断。
    

    XLA device--XLA(Accelerated Linear Algebra)-加速线性代数

    作用:

    1. 合并子图后再计算,而不是像模型构建时那样逐步计算
    2. 提高内存利用率
    3. 减少模型可执行文件大小
    4. 方便支持不同硬件后端

    Variable(已弃用)

    之前的版本不支持直接对tensor设置requir_grad,需要用Variable封装tensor设置自动梯度与否。
    现在剥离掉了Variable,不再需要对tensor封装,tensor直接默认支持required_grad,也可以直接对tensor设置required_grad = False

    Variable(tensor,required_gard = True)
    Variable(tensor)#默认required_gard = False
    

  • 相关阅读:
    how to uninstall devkit
    asp.net中bin目录下的 dll.refresh文件
    查找2个分支的共同父节点
    Three ways to do WCF instance management
    WCF Concurrency (Single, Multiple, and Reentrant) and Throttling
    检查string是否为double
    How to hide TabPage from TabControl
    获取当前系统中的时区
    git svn cygwin_exception
    lodoop打印控制具体解释
  • 原文地址:https://www.cnblogs.com/Henry-ZHAO/p/13201501.html
Copyright © 2011-2022 走看看