zoukankan      html  css  js  c++  java
  • pytorch的优化器optimizer使用方法

    基本定义:torch.optim是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

    构建优化器:构建优化器可选择optim自定义的方法,一般也是调用其中的,如下可构建:

    optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

    optimizer = optim.Adam([var1, var2], lr = 0.0001)   # [var1, var2] 可理解为优化的变量 lr = 0.0001为梯度下降的学习率,其它未设置表示默认

    深度理解:若你构建了网络模型model,若你想给不同模块设置不同学习率,将可以采取以下方法:

    optimizer = optim.SGD([ {'params': model.base.parameters()}, {'params': model.classifier.parameters(), 'lr': 1e-3} ], lr=1e-2, momentum=0.9)

    以上optim.SGD()中的列表就是构建每个参数的学习率,若没有设置,则默认使用最外如:model.base.parameters()参数使用lr=1e-2  momentum=0.9

    添加个人需要的变量:

    若你想添加个人变量保存optimizer中,可使用:

    for b in optimizer.param_groups:
    b.setdefault('init_lr', 0.02)
    此时类似optimizer = optim.SGD([ {'params': model.base.parameters()}, {'params': model.classifier.parameters(), 'lr': 1e-3} ],init_lr=0.02, lr=1e-2, momentum=0.9)

    若你想更改学习率,可使用:

    for b in optimizer.param_groups:
    b.setdefault('init_lr', 0.00005)
    此时类似将optimizer = optim.SGD([ model.base.parameters(), lr=0.02, momentum=0.9) 变成
    optimizer = optim.SGD([model.base.parameters(), lr=0.00005, momentum=0.9)

    注:可理解optimezer已经保存了模型model需要使用的学习率参数。

    查看优化器参数:

    optimizer.param_groups[0]: 长度为6的字典,包括[‘amsgrad’, ‘params’, ‘lr’, ‘betas’, ‘weight_decay’, ‘eps’]这6个参数;

    optimizer.param_groups[1]: 好像是表示优化器的状态的一个字典;

    模型训练优化器一般方法使用:

    大多数optimizer所支持的简化版本。一旦梯度被如backward()之类的函数计算好后,我们就可以调用这个函数

    for input, target in dataset:

      optimizer.zero_grad()

      output = model(input)

      loss = loss_fn(output, target)

      loss.backward()

      optimizer.step()

    深入用法:

    optimizer.step(closure)

    一些优化算法例如Conjugate Gradient和LBFGS需要重复多次计算函数,因此你需要传入一个闭包去允许它们重新计算你的模型。这个闭包应当清空梯度, 计算损失,然后返回。

    for input, target in dataset:

      def closure():

        optimizer.zero_grad()

        output = model(input)

        loss = loss_fn(output, target)

        loss.backward()

      return loss optimizer.step(closure)

    处理算法通用的辅助的code,如读取txt文件,读取xml文件,将xml文件转换成txt文件,读取json文件等
  • 相关阅读:
    LeetCode(111) Minimum Depth of Binary Tree
    LeetCode(108) Convert Sorted Array to Binary Search Tree
    LeetCode(106) Construct Binary Tree from Inorder and Postorder Traversal
    LeetCode(105) Construct Binary Tree from Preorder and Inorder Traversal
    LeetCode(99) Recover Binary Search Tree
    【Android】通过经纬度查询城市信息
    【Android】自定义View
    【OpenStack Cinder】Cinder安装时遇到的一些坑
    【积淀】半夜突然有点想法
    【Android】 HttpClient 发送REST请求
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/14794675.html
Copyright © 2011-2022 走看看