zoukankan      html  css  js  c++  java
  • PyTorch的动态学习率和optimizer.param_groups[0]

    如何设置PyTorch的动态学习率

    本文主要涉及内容:Optimizer_LRScheduler等源码分析。
    本文依旧基于PyTorch 1.1.0。

    Optimizer

    PyTorch提供了torch.optim.lr_scheduler来帮助用户改变学习率,下边将从Optimizer入手,看一下这个类是如何工作的。

    为什么从Optimizer入手,因为无论是Adam还是SGD,都是继承的这个类。同时,scheduler也是给所有的Optimizer服务的,所以需要用的方法都会定义在这个基类里,直接看一下这个类的属性即可。给出Doc中的代码链接

    首先是初始化方法def __init__(self, params, defaults),这个方法的params参数,就是我们在初始化优化器的时候传入的网络的参数,如Alexnet.parameters(),而后边所有的参数都将合并成dict参数作为这个方法的defaults。
    看一下Alexnet.parameters()中存的都是什么:

    1
    2
    for alex in Alexnet.parameters():
    print(alex.shape)

    可以看到,这里边存的就是整个网络的参数。
    有两种定义optimizer的方法:
    1
    2
    3
    4
    5
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer = optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
    ], lr=1e-2, momentum=0.9)

    如果是第一种定义的方法:在这个初始化方法中,会把这些参数先改造成[{'params': Alexnet.parameters()}]这样的一个长度为1的list。然后对这个list进行加工,添加上defaults中的参数,如果我们使用Alexnet来做一个例子的话,就是下边这个样子:

    1
    2
    3
    optimizer = torch.optim.Adam(Alexnet.parameters(), lr=0.001)
    print([group.keys() for group in optimizer.param_groups])
    # [dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'])]

    如果是第二种定义的方法:因为传入的本身就是dict的形式,所以会继续对他进行加工,添加上后边的参数,我们直接看疗效:
    1
    2
    3
    4
    5
    6
    optimizer = torch.optim.SGD([
    {'params': Alexnet.features.parameters()},
    {'params': Alexnet.classifier.parameters(), 'lr': 1e-3}
    ], lr=1e-2, momentum=0.9)
    print([group.keys() for group in optimizer.param_groups])
    # [dict_keys(['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov']), dict_keys(['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'])]

    这次的list变成了两个元素,而且每个元素的组成和使用Adam也不一样了,这很明显,因为不同的优化器需要的参数不同嘛~(关于不同层的lr不同的设置这里给出官网链接)

    但是两者是相似的,就是每个元素都有params和lr,这就够了。

    _LRScheduler

    所有的动态修改lr的类,都是继承的这个类,所以我们看一下这个类包含什么方法。源码链接

    在初始化方法中def __init__(self, optimizer, last_epoch=-1),包含两个参数,第一个参数就是我们上边提到的optimizer的任何一个子类。第二个参数的意思是当前执行到了哪个epoch。我们不指定它的时候,虽然默认是-1,但是init中会调用一次step并设置为0。

    一定要注意PyTorch的版本!我的windows上用的是1.0.1,服务器用的是1.1.0,就闹了很多问题。就拿这个类来说,在1.0.1中是先setp()再训练,而1.1.0进行了更新,先训练,然后再step()

    当我们调用了初始化后,会给optimizer增加一个字段,看一下:

    1
    2
    3
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    print([group.keys() for group in optimizer.param_groups])
    # [dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad', 'initial_lr'])]

    新增加的initial_lr字段就是原始的lr。

    def step(self, epoch=None)方法中,通常情况下我们不需要指定这个参数epoch,因为每次调用他都会增加1。在这个函数中会调用一个需要重载的方法get_lr(),每次调用都会从这个方法中提取改变后的lr,赋值给optimizer。

    这里其实我一直有个疑问的,就是scheduler的step和optimizer的step是一个什么关系,其实通过源码,看到这里,这俩函数没啥关系!scheduler的step只会修改lr,两者都需要执行!

    下边看一下两个scheduler的get_lr()对比一下。先看一下SetpLR:

    1
    2
    3
    4
    5
    def get_lr(self):
    if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
    return [group['lr'] for group in self.optimizer.param_groups]
    return [group['lr'] * self.gamma
    for group in self.optimizer.param_groups]

    这个会在设置的步长的整倍数的时候将lr*gamma。
    而ExponentialLR则会在每轮结束的时候都进行乘gamma的操作,这个减小也真的是指数倍的。
    1
    2
    3
    4
    5
    def get_lr(self):
    if self.last_epoch == 0:
    return self.base_lrs
    return [group['lr'] * self.gamma
    for group in self.optimizer.param_groups]

    Demo

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    train_loader = Data.DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
    for epoch in range(100):
    for X, y in train_loader:
    ...
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    optimizer.param_groups:是长度为2的list,其中的元素是2个字典;
    optimizer.param_groups[0]:长度为6的字典,包括[‘amsgrad’, ‘params’, ‘lr’, ‘betas’, ‘weight_decay’, ‘eps’]这6个参数
    optimizer.param_groups[1]:表示优化器的状态的一个字典
    ————————————————
    版权声明:本文为CSDN博主「Wanderer001」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    原文链接:https://blog.csdn.net/weixin_36670529/article/details/107531773




    如果这篇文章帮助到了你,你可以请作者喝一杯咖啡

  • 相关阅读:
    Permission denied (publickey). SSH用户名密码登录报错
    git工作流(Gitflow/gitlab代码权限管理)
    Spring多数据源配置(2)[PageHelper插件下应用bug修复]
    Spring多数据源配置
    基于Redis实现分布式锁
    .NetCore Autofac依赖注入获取注册后的实例、全局容器获取
    C++注入记事本升级版,给记事本弄爱心
    C++注入记事本
    WINAPI实现简易扫雷游戏
    .net 公共基础类
  • 原文地址:https://www.cnblogs.com/sddai/p/14653322.html
Copyright © 2011-2022 走看看