zoukankan      html  css  js  c++  java
  • Theano中的共享变量

    Theano中的共享变量

    定义共享变量的原因在于GPU的使用,如果不定义共享的话,那么当GPU调用这些变量时,遇到一次就要调用一次,这样就会花费大量时间在数据存取上,导致使用GPU代码运行很慢,共享变量的类型必须为floatX

    因为GPU要求在floatX上操作,所以所有的共享变量都要声明为floatX类型,shared variable是一种符号变量(symbolic variable),但是这个symbolic variable又拥有自己的值。

    shared可以存储在显存中,因为这个特性,我们才会有"把神经网络参数放

    shared中"的这种做法。

    shared指向显存中的一块区域,这块区域在运算中是共享的,所以常常在运算中用来存储权值参数。

    #!/usr/bin/env python2
    # -*- coding: utf-8 -*-
    """
    theano de shared value
    """
    
    import numpy as np
    import theano.tensor as T
    import theano
    
    #共享变量中数据类型很重要,定义vector和matrix时需要统一
    #最后一个参数是共享变量的名称
    #初始化共享变量state为0
    state = theano.shared(np.array(0, dtype = np.float64), 'state')
    
    #定义累加值, 名称为inc, 定义数据类型是需要用state.dtype,而不是dtype = np.float64
    #否则会报错
    inc = T.scalar('inc', dtype = state.dtype)
    #定义一个accumulator函数
    #输入为inc, 输出为state
    #累加的过程叫做updates, 作用是state = state+inc
    accumulator = theano.function([inc], state, updates=[(state, state + inc)])

    获取与设置共享变量的值get_value, set_value 这两种只能在 Shared 变量 的时候调用。

    #!/usr/bin/env python2
    # -*- coding: utf-8 -*-
    """
    theano de shared value
    """
    
    import numpy as np
    import theano.tensor as T
    import theano
    
    """
    get_value, set_value 这两种只能在 Shared 变量 的时候调用。
    """
    
    #获取共享变量的值get_value
    print(state.get_value())
    #output: 0.0
    
    accumulator(1)
    print(state.get_value())
    #output: 1.0
    
    accumulator(10)
    print(state.get_value())
    #output: 11.0
    
    #设置共享变量的值set_value
    state.set_value(-1)
    accumulator(3)
    print(state.get_value())
    #output : 2.0

    临时共享变量,有时需要暂时试用shared变量,不需要更新,这时可以定义一个临时变量代替共享变量.

    #!/usr/bin/env python2
    # -*- coding: utf-8 -*-
    """
    theano de shared value
    """
    import numpy as np
    import theano.tensor as T
    import theano
    
    #临时试用共享变量
    #有时需要暂时试用shared变量,不需要更新,这时可以定义一个临时变量代替共享变量
    #函数输入值为[Inc, a], 要用a带入state, 输出是tmp_func函数形式
    #givens表示需要把什么替换成什么, 而state不会改变
    #最后输出结果中, state暂时被替换成a,state值不会变,还是上步的值2, a的值是3
    a = T.scalar(dtype=state.dtype)
    tmp_func = state * 2 + inc
    skip_shared = theano.function([inc, a], tmp_func, givens=[(state, a)])
    print(skip_shared(2, 3))
    #output: 3 *2+2 = 8
  • 相关阅读:
    Beta冲刺(4/4)
    2019 SDN上机第7次作业
    Beta冲刺(3/4)
    Beta冲刺(2/4)
    机器学习第二次作业
    机器学习第一次作业
    软工实践个人总结
    第04组 Beta版本演示
    第04组 Beta冲刺(5/5)
    第04组 Beta冲刺(4/5)
  • 原文地址:https://www.cnblogs.com/xmeo/p/7240648.html
Copyright © 2011-2022 走看看