zoukankan      html  css  js  c++  java
  • Theano中的Function

    Theano中的Function

    function是theano框架中极其重要的一个函数,另外一个很重要的函数是scan,在学习theano框架中deep learning的教程的时候,几乎所有的实例程序都用到了function和scan, theano function 就和 python 中的 function 类似, 不过因为要被用在多进程并行运算中,所以他的 function 有他自己的一套使用方式.

    function函数里面最典型的4个参数就是inputs,outputs,updates和givens

     

    定义一个简单的激活函数activate function,用的是最常见的sigmoid函数

    #!/usr/bin/env python2
    # -*- coding: utf-8 -*-
    """
    theano之function
    """
    
    import numpy as np
    import theano.tensor as T
    import theano
    
    #首先需要定义一个 tensor T
    x = T.dmatrix('x')
    #声明sigmoid激活函数函数
    s = 1/(1+T.exp(-x))
    #调用 theano 定义的计算函数 logistic
    logistic = theano.function([x], s)
    #输入为一个2行两列的矩阵
    print(logistic([[0, 1], [-2, -3]]))

     

    定义多输入,多输出函数

    #!/usr/bin/env python2
    # -*- coding: utf-8 -*-
    """
    theano之function
    """#多输入,多输出的函数
    #如输入值为两个,输出值为两个
    #输入的值为矩阵A, B
    a, b = T.dmatrices('a', 'b')
    #计算a, b之间的diff, abs_diff, diff_squared
    diff = a-b
    abs_diff = abs(diff)
    diff_squared = diff**2
    #定义多输出函数
    f = theano.function([a, b ], [diff, abs_diff, diff_squared])
    
    x1, x2, x3 = f(
            np.ones((2, 2)),
            np.arange(4).reshape((2,2))
            )
    print(x1, x2, x3)
    """
    output:
        (array([[ 1.,  0.],
           [-1., -2.]]), array([[ 1.,  0.],
           [ 1.,  2.]]), array([[ 1.,  0.],
           [ 1.,  4.]]))
    """

    函数定义参数默认值和名称

    #!/usr/bin/env python2
    # -*- coding: utf-8 -*-
    """
    theano之function
    """
    
    import numpy as np
    import theano.tensor as T
    import theano
    
    #使用 T.dscalars() 里面同时定义三个纯量的容器。 以及输出值z
    x, y, w = T.dscalars('x', 'y', 'w')
    z = (x + y) * w
    #定义theano函数
    f = theano.function(
            [x,
             theano.In(y, value=1),
             theano.In(w, value=2, name='weights')],
            z
            )
    #使用默认值
    print(f(23))
    #使用非默认值
    print(f(23, 1, 4))
    #试用名称赋值
    print(f(23,1, weights = 6))
    
    """
    output:
        48.0
        96.0
        144.0
    """
  • 相关阅读:
    [转]VirtualBox错误 Unable to load R3 module 解决方案
    2014工作感想
    人生的真正价值
    react生成二维码
    判断对象中是否包含某个属性
    使用reduce进行数组对象去重
    filter筛选
    判断区分安卓ios
    scrollIntoView 与 scrollIntoViewIfNeeded API 介绍
    vue中使用@scroll实现分页处理(分页要做节流处理)
  • 原文地址:https://www.cnblogs.com/xmeo/p/7238881.html
Copyright © 2011-2022 走看看