zoukankan      html  css  js  c++  java
  • Gluon 参数读取

    ndarray：save、load

    from mxnet import nd
    from mxnet.gluon import nn
    
    # Demonstrate nd.save / nd.load round-trips for a single NDArray,
    # a list of NDArrays, and a dict of NDArrays.
    # NOTE: each nd.save call writes a file into the current working directory.

    # A single NDArray: nd.load returns a list holding the saved array.
    x = nd.ones(3)
    nd.save('x', x)
    x2 = nd.load('x')
    print(x2)

    # A list of NDArrays: the loaded result unpacks in the saved order.
    y = nd.zeros(4)
    print([x, y])
    nd.save('xy', [x, y])
    x2, y2 = nd.load('xy')
    print(x2, y2)

    # A dict of NDArrays (string keys -> arrays).
    mydict = {'x': x, 'y': y}
    nd.save('mydict', mydict)
    mydict2 = nd.load('mydict')
    print(mydict2)

    Gluon 模型参数：save_parameters、load_parameters

    from mxnet import nd
    from mxnet.gluon import nn
    
    class MLP(nn.Block):
        """Multi-layer perceptron: a 256-unit ReLU Dense layer followed by a 10-unit Dense layer.

        Extra keyword arguments are forwarded unchanged to ``nn.Block``.
        """

        def __init__(self, **kwargs):
            super(MLP, self).__init__(**kwargs)
            # Sub-layers are stored as attributes; the save/load_parameters
            # calls below rely on them being part of this Block.
            self.hidden = nn.Dense(256, activation='relu')
            self.output = nn.Dense(10)

        def forward(self, x):
            """Apply the hidden layer to ``x``, then the output layer to its result."""
            activated = self.hidden(x)
            return self.output(activated)
    
    # Build and initialize a model, run one forward pass, and persist the
    # input, the output and the parameters so they can be reloaded below.
    # NOTE: writes 'X', 'Y' and 'mlp.params' into the current working directory.
    net = MLP()
    net.initialize()
    X = nd.random.uniform(shape=(2, 20))
    Y = net(X)
    print(Y)
    nd.save('X', X)
    nd.save('Y', Y)

    filename = 'mlp.params'
    net.save_parameters(filename)

    # Recreate the same architecture and restore the saved parameters into it.
    net2 = MLP()
    net2.load_parameters(filename)

    # nd.load on a file saved from a single NDArray returns a list; index [0]
    # recovers the array.
    X_loaded = nd.load('X')
    Y_loaded = nd.load('Y')
    Y2 = net2(X_loaded[0])
    # Element-wise equality: all ones if the parameters round-tripped exactly.
    print(Y_loaded[0] == Y2)

  • 相关阅读:
    hibernate hql
    数据库锁机制
    Spring 事务管理
    spring自动代理
    spring 其它增强类型
    spring
    mybatis动态sql
    SSH注解整合
    ssh整合
    错题解析
  • 原文地址:https://www.cnblogs.com/TreeDream/p/10237373.html
Copyright © 2011-2022 走看看