zoukankan      html  css  js  c++  java
  • optimizer.state_dict()、optimizer.param_groups

    net = t.nn.Linear(2, 3)
    optimizer = t.optim.SGD(net.parameters(), lr=0.2)
    for key, value in optimizer.state_dict().items():
    print(key, value)
    for i, param_group in enumerate(optimizer.param_groups):
    print(i+1)
    print(param_group)

    1、optimizer.state_dict()

    """

    state {}
    param_groups [{'lr': 0.2, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140327302981024, 140327686399752]}]

    """

    是一个字典,包括优化器的状态(state)以及一些超参数信息(param_groups)

    2、optimizer.param_groups

    """

    1
    {'params': [Parameter containing:
    tensor([[-0.2604, 0.0777],
    [-0.6420, 0.5030],
    [-0.3879, -0.5129]], requires_grad=True), Parameter containing:
    tensor([ 0.6245, 0.4680, -0.3667], requires_grad=True)], 'lr': 0.2, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}

    """

    是param_groups是一个数组,数组内部包含n个字典

    总结:state_dict()包括param_groups

  • 相关阅读:
    Lambda表达式
    多态之美
    集合那点事
    程序员艺术家
    MySQL:如何导入导出数据表和如何清空有外建关联的数据表
    Ubuntu修改桌面为Desktop
    shutil.rmtree()
    SCP命令
    kickstart
    数据哈希加盐
  • 原文地址:https://www.cnblogs.com/liujianing/p/13428387.html
Copyright © 2011-2022 走看看