zoukankan      html  css  js  c++  java
  • 【NAS工具箱】Pytorch中的Buffer&Parameter

    Parameter : 模型中的一种可以被反向传播更新的参数。

    第一种:

    • 直接通过成员变量nn.Parameter()进行创建,会自动注册到parameter中。
    def __init__(self):
        super(MyModel, self).__init__()
        self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量
    

    或者:

    • 通过nn.Parameter() 创建普通对象
    • 通过register_parameter()进行注册
    • 可以通过model.parameters()返回
    def __init__(self):
        super(MyModel, self).__init__()
        param = nn.Parameter(torch.randn(3, 3))  # 普通 Parameter 对象
        self.register_parameter("param", param)
    

    Buffer : 模型中不能被反向传播算法更新的参数。

    • 创建tensor
    • 将tensor通过register_buffer进行注册
    • 可以通过model.buffers()返回
    def __init__(self):
        super(MyModel, self).__init__()
        buffer = torch.randn(2, 3)  # tensor
        self.register_buffer('my_buffer', buffer)
        self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量
    

    总结:

    • 模型参数=parameter+buffer; optimizer只能更新parameter,不能更新buffer,buffer只能通过forward进行更新。
    • 模型保存的参数 model.state_dict() 返回一个OrderDict
    代码改变世界
  • 相关阅读:
    [POI2007]山峰和山谷Grz
    [POI2007]驾驶考试egz
    [POI2007]立方体大作战tet
    BZOJ1085 [SCOI2005]骑士精神
    BZOJ1975 [Sdoi2010]魔法猪学院
    codeforces754D Fedor and coupons
    UOJ79 一般图最大匹配
    BZOJ3944 Sum
    BZOJ3434 [Wc2014]时空穿梭
    UOJ58 【WC2013】糖果公园
  • 原文地址:https://www.cnblogs.com/pprp/p/14817025.html
Copyright © 2011-2022 走看看