zoukankan      html  css  js  c++  java
  • Theano中的Function

    Theano中的Function

    function是theano框架中极其重要的一个函数,另外一个很重要的函数是scan,在学习theano框架中deep learning的教程的时候,几乎所有的实例程序都用到了function和scan, theano function 就和 python 中的 function 类似, 不过因为要被用在多进程并行运算中,所以他的 function 有他自己的一套使用方式.

    function函数里面最典型的4个参数就是inputs,outputs,updates和givens

     

    定义一个简单的激活函数activate function,用的是最常见的sigmoid函数

    #!/usr/bin/env python2
    # -*- coding: utf-8 -*-
    """
    theano之function
    """
    
    import numpy as np
    import theano.tensor as T
    import theano
    
    #首先需要定义一个 tensor T
    x = T.dmatrix('x')
    #声明sigmoid激活函数函数
    s = 1/(1+T.exp(-x))
    #调用 theano 定义的计算函数 logistic
    logistic = theano.function([x], s)
    #输入为一个2行两列的矩阵
    print(logistic([[0, 1], [-2, -3]]))

     

    定义多输入,多输出函数

    #!/usr/bin/env python2
    # -*- coding: utf-8 -*-
    """
    theano之function
    """#多输入,多输出的函数
    #如输入值为两个,输出值为两个
    #输入的值为矩阵A, B
    a, b = T.dmatrices('a', 'b')
    #计算a, b之间的diff, abs_diff, diff_squared
    diff = a-b
    abs_diff = abs(diff)
    diff_squared = diff**2
    #定义多输出函数
    f = theano.function([a, b ], [diff, abs_diff, diff_squared])
    
    x1, x2, x3 = f(
            np.ones((2, 2)),
            np.arange(4).reshape((2,2))
            )
    print(x1, x2, x3)
    """
    output:
        (array([[ 1.,  0.],
           [-1., -2.]]), array([[ 1.,  0.],
           [ 1.,  2.]]), array([[ 1.,  0.],
           [ 1.,  4.]]))
    """

    函数定义参数默认值和名称

    #!/usr/bin/env python2
    # -*- coding: utf-8 -*-
    """
    theano之function
    """
    
    import numpy as np
    import theano.tensor as T
    import theano
    
    #使用 T.dscalars() 里面同时定义三个纯量的容器。 以及输出值z
    x, y, w = T.dscalars('x', 'y', 'w')
    z = (x + y) * w
    #定义theano函数
    f = theano.function(
            [x,
             theano.In(y, value=1),
             theano.In(w, value=2, name='weights')],
            z
            )
    #使用默认值
    print(f(23))
    #使用非默认值
    print(f(23, 1, 4))
    #试用名称赋值
    print(f(23,1, weights = 6))
    
    """
    output:
        48.0
        96.0
        144.0
    """
  • 相关阅读:
    SQL Server 阻止了对组件 'Ole Automation Procedures' 的 过程'sys.sp_OACreate' 的访问
    谷歌浏览器扩展程序manifest.json参数详解
    获取天气api
    UVA 10385 Duathlon
    UVA 10668 Expanding Rods
    UVALIVE 3891 The Teacher's Side of Math
    UVA 11149 Power of Matrix
    UVA 10655 Contemplation! Algebra
    UVA 11210 Chinese Mahjong
    UVA 11384 Help is needed for Dexter
  • 原文地址:https://www.cnblogs.com/xmeo/p/7238881.html
Copyright © 2011-2022 走看看