zoukankan      html  css  js  c++  java
  • 命令式和符号式混合编程

    # Imperative vs. symbolic programming (命令式和符号式编程)
    
    def add_str():
        """Return the source text of a tiny ``add(a, b)`` function.

        The text is a string, not live code: it is later concatenated with
        the other ``*_str`` helpers and compiled as one program.
        """
        source = '''
    def add(a,b):
        return a + b
    '''
        return source
    
    
    
    def fancy_func_str():
        """Return the source text of ``fancy_func(a, b, c, d)``.

        The generated function sums its four arguments pairwise through
        ``add``, so ``add`` must already be defined when the text runs.
        """
        text = '''
    def fancy_func(a, b, c, d):
        e = add(a,b)
        f = add(c,d)
        g = add(e,f)
        return g
    '''
        return text
    
    def evoke_str():
        """Return a complete program as text.

        The program is the ``add`` definition, then the ``fancy_func``
        definition, then a statement printing ``fancy_func(1,2,3,4)``.
        """
        call_line = '''
    print(fancy_func(1,2,3,4))
    '''
        return ''.join([add_str(), fancy_func_str(), call_line])
    
    # Assemble the generated program text and compile it into a code object.
    # Uncomment the print to inspect the generated source, and the exec to
    # actually run it (the program prints fancy_func(1,2,3,4)).
    prog = evoke_str()
    # print(prog)
    y = compile(source=prog, filename='', mode='exec')
    # exec(y)
    
    
    from mxnet import nd,autograd,sym
    from mxnet.gluon import nn,loss as gloss
    
    def get_net():
        """Build and initialize a small MLP as an ``nn.HybridSequential``.

        Layers, in order: Dense(256, relu) -> Dense(128, relu) -> Dense(2).

        Returns:
            The initialized network. It is not hybridized; callers decide
            whether to call ``hybridize()``.
        """
        net = nn.HybridSequential()
        layers = (
            nn.Dense(256, activation='relu'),
            nn.Dense(128, activation='relu'),
            nn.Dense(2),
        )
        net.add(*layers)
        net.initialize()
        return net
    
    # Run the MLP imperatively on a single random sample with 512 features.
    net = get_net()
    X = nd.random.normal(shape=(1, 512))
    print(net(X))

    # net.hybridize() compiles and optimizes the computation of the layers
    # chained inside the HybridSequential instance. The network object,
    # its weights and the input are unchanged, so this print is expected
    # to match the previous one.
    net.hybridize()
    print(net(X))
    
    # Comparison: time the same network before and after hybridization (对比)
    import time
    def benchmark(net, x, num_iters=1000):
        """Return the elapsed seconds taken by ``num_iters`` forward passes.

        Args:
            net: callable network (e.g. a Gluon block), invoked as ``net(x)``.
            x: input passed unchanged to every call.
            num_iters: number of forward passes to time. Defaults to 1000,
                the count that was previously hard-coded.

        Returns:
            Elapsed time in seconds, as a float.
        """
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # if the system clock is adjusted and is coarser on some platforms.
        start = time.perf_counter()
        for _ in range(num_iters):
            net(x)
        # Presumably blocks until MXNet's asynchronously queued work has
        # finished, so that work is counted inside the timed window.
        nd.waitall()
        return time.perf_counter() - start
    
    # Time a freshly built network imperatively, then the same network
    # after hybridize(), using the same input X.
    net = get_net()
    print('before hybridizing: %.4f sec' % (benchmark(net, X),))

    net.hybridize()
    print('after hybridizing: %.4f sec' % (benchmark(net, X),))

    # Export the hybridized network under the file prefix 'my_mlp'.
    # NOTE(review): the original comment said only "save parameters".
    # HybridBlock.export is documented to write the model graph as well as
    # the parameters; confirm the exact file names for the MXNet version
    # in use.
    net.export('my_mlp')

    # Feed a symbolic placeholder named 'data' through the network. The
    # printed result is expected to be a Symbol rather than an NDArray.
    x = sym.var('data')
    print(net(x))
    from mxnet.gluon import nn,loss
    from mxnet import nd,autograd
    
    class HybirdNet(nn.HybridBlock):
        """Two-layer HybridBlock demo: Dense(10) with ReLU, then Dense(2).

        ``hybrid_forward`` prints the operator namespace ``F``, its input and
        the hidden activation. This makes it easy to compare what is passed
        in before and after ``hybridize()``.

        NOTE: the class name is a misspelling of "HybridNet". It is kept
        unchanged so existing references keep working.
        """

        def __init__(self, **kwargs):
            super(HybirdNet, self).__init__(**kwargs)
            # Child layers; both are created here and used in hybrid_forward.
            self.hidden = nn.Dense(10)
            self.output = nn.Dense(2)

        def hybrid_forward(self, F, x, *args, **kwargs):
            # F is the operator namespace handed in by Gluon; printing it
            # (and x) shows which backend this call is running under.
            print('F: ', F)
            print('x: ', x)
            pre_activation = self.hidden(x)
            x = F.relu(pre_activation)
            print('hidden: ', x)
            logits = self.output(x)
            return logits
    
    # Build and initialize the demo block.
    net = HybirdNet()
    net.initialize()

    # Run it imperatively on one random sample with 4 features.
    X = nd.random.normal(shape=(1, 4))
    print(X)
    print(net(X))

    # Compile/optimize the block, then run it again on the same input.
    # Compare what hybrid_forward prints for F and x against the call above.
    net.hybridize()
    print(net(X))

  • 相关阅读:
    JMeter 关联
    JMeter MD5加密
    JMeter 时间函数
    JMeter 常用设置
    JMeter 服务器资源监控
    js制作列表滚动(有滚动条)
    js监听事件
    获取窗口大小 并自适应大小变化
    js 标签云
    js 显示数字不断增加
  • 原文地址:https://www.cnblogs.com/TreeDream/p/10228894.html
Copyright © 2011-2022 走看看