zoukankan      html  css  js  c++  java
  • Pytorch分步训练(只训练部分参数)

    我现在的问题是,我的模型由两部分组成,bert+gat,bert只需要3~5轮就能收敛,而gat需要几十次,

    我期望的目标是训练5轮过后,就只训练gat,bert不被更新

    总体上有两种思路,一种是将不想被训练的参数修改为requires_grad=False,另一种是只将要训练的参数放到优化器中

    第一种:设置requires_grad=Fasle

    点击查看代码
    import torch
    import torch.nn as nn
    from torch.nn.modules import loss
    from torch_geometric.nn import models
    import ipdb
    
    data = torch.randn(4,10)
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(1,10)
            self.fc2 = nn.Linear(10,1)
            self.relu = nn.ReLU()
            self.sigmoid = nn.Sigmoid()
        def forward(self, x):
            # ipdb.set_trace()
            x = self.relu(self.fc1(x))
            x = self.sigmoid(self.fc2(x))
            return x
    
    def print_gram(model):
        for name, param in model.named_parameters():
            if 'fc1' in name:
                print(name, param.data, param.grad.norm(), param.requires_grad)
    
    
    model = Net()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,betas=(0.5, 0.999))
    criterion = nn.MSELoss()
    
    data = torch.tensor([[1.0],[3.0],[5.0],[7.0]])
    label = torch.tensor([[1.0],[9.0],[25.0],[49.0]])
    for i in range(20):
        a = model(data)
        loss = criterion(a,label)
    
        print(i, loss)
    
        if i < 10:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            optimizer.zero_grad()
            model.fc1.requires_grad_(False)
            loss.backward()
            optimizer.step()    
    
        print_gram(model)
    

    你会发现,fc1的梯度的确为0,且requires_grad==False,但是weight变了!!!

    可见 关于pytorch中使用detach并不能阻止参数更新这档子事儿

    (1)backward之前,grad=None

    (2)backward之后,grad变成具体值

    (3)执行step,weight得到更新

    (4)执行zero_grad,grad=0

    除了逐一设置,也能直接将一个模块设为False,nn.Module.requires_grad_()

    因此,梯度为0,因为有历史信息,权重也能更新!!

    感谢这个问题,让我明白了pytorch为什么不在每个step自动清零grad而是非要让人主动去做zero_grad。

    第二种:只优化要更新的参数

    点击查看代码
    import torch
    import torch.nn as nn
    from torch.nn.modules import loss
    from torch_geometric.nn import models
    import ipdb
    
    data = torch.randn(4,10)
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(1,10)
            self.fc2 = nn.Linear(10,1)
            self.relu = nn.ReLU()
            self.sigmoid = nn.Sigmoid()
        def forward(self, x):
            # ipdb.set_trace()
            x = self.relu(self.fc1(x))
            x = self.sigmoid(self.fc2(x))
            return x
    
    model = Net()
    
    def get_parameters():
        for name, param in model.named_parameters():
            params = []
            if 'fc1' not in name:
               params.append(param)
        
        print(len(params))
        return params
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,betas=(0.5, 0.999))
    optimizer2 = torch.optim.Adam(get_parameters(), lr=0.001,betas=(0.5, 0.999))
    
    def print_gram(model):
        for name, param in model.named_parameters():
            if 'fc1' in name:
                print(name, param.data, param.grad.norm(), param.requires_grad)
    
    
    criterion = nn.MSELoss()
    data = torch.tensor([[1.0],[3.0],[5.0],[7.0]])
    label = torch.tensor([[1.0],[9.0],[25.0],[49.0]])
    
    for i in range(20):
        a = model(data)
        loss = criterion(a,label)
    
        print(i, loss)
    
        if i < 10:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            optimizer2.zero_grad()
            loss.backward()
            optimizer2.step()    
    
        print_gram(model)
    

    虽然fc1会有梯度,但权重不会被更新,感觉还不错

    总结:

    使用requires_grad可以对参数或模块进行设置;使用detach用法简单,例如冻结bert,x = pooled_output.detach()  这句简单,只要把bert的输出加个detach,而且这两种方法只能用于冻结上游参数,不能是中间或后面,而且只有之前没有进行过BP才不会有历史信息,才会完全不影响梯度

    使用优化器方法,只设置优化器还是会进行整个BP,不会更新参数值,但并不会节省算力。参考bert模型的微调,如何固定住BERT预训练模型参数,只训练下游任务的模型参数? - ZJU某小白的回答 - 知乎 https://www.zhihu.com/question/317708730/answer/634068499

    参考链接:

    https://www.cxyzjd.com/article/Answer3664/108493753

    https://blog.csdn.net/jinxin521125/article/details/83621268

    https://pytorch.org/docs/stable/notes/autograd.html

  • 相关阅读:
    事务管理
    QQ邮箱开启SMTP方法如何授权
    基于JavaMail的Java邮件发送:简单邮件发送
    Spring MVC 响应视图(六)
    Spring MVC 数据绑定 (四)
    Spring MVC Spring中的Model (五)
    Spring MVC 拦截器 (十)
    Spring MVC 异常处理 (九)
    Java 骚操作--生成二维码
    Java MD5校验与RSA加密
  • 原文地址:https://www.cnblogs.com/lfri/p/step_train.html
Copyright © 2011-2022 走看看