  • Theano Study Notes 2: Automatic Differentiation and T.grad

    1. Chain rule for composite functions

    This is covered in any calculus textbook.
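
    As a quick reminder of the textbook rule, here is a small Python sketch for one composite function chosen purely for illustration, f(x) = sin(x**2), whose derivative by the chain rule is cos(x**2) * 2x. The helper central_diff is a hypothetical name for a numerical sanity check; it is not part of Theano.

```python
import math

# Chain rule for a composite function f(g(x)) with f = sin and g(x) = x**2:
#   d/dx sin(x**2) = cos(x**2) * 2x
def composite(x):
    return math.sin(x ** 2)

def composite_grad(x):
    g = x ** 2                # inner function g(x)
    dg = 2 * x                # g'(x)
    return math.cos(g) * dg   # f'(g(x)) * g'(x)

def central_diff(fn, x, h=1e-6):
    # Numerical derivative, used only to check the chain-rule result
    return (fn(x + h) - fn(x - h)) / (2 * h)

x0 = 1.3
print(composite_grad(x0), central_diff(composite, x0))
```

    The two printed numbers should agree to roughly 1e-9; the central difference is used because its truncation error shrinks with h**2.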

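    2. Chain rule for more complex functions

    In the figure below, the left column is the computation sequence and the right column is the corresponding derivative sequence.

    Each step of the computation sequence corresponds one-to-one to a step of the derivative computation, so the derivative can be built up step by step alongside the forward computation.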
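
    A minimal sketch of such a pair of sequences, for a function chosen only for illustration, f(x) = exp(x * sin(x)) (an assumption, not the function in the original figure). Each line computes one intermediate value v_i together with its derivative dv_i = dv_i/dx, which is forward-mode propagation of the chain rule.

```python
import math

def f_and_grad(x):
    """Evaluate f(x) = exp(x * sin(x)) and df/dx step by step.

    Each computation step v_i is paired with its derivative step dv_i.
    """
    v1, dv1 = x, 1.0                              # input: dx/dx = 1
    v2, dv2 = math.sin(v1), math.cos(v1) * dv1    # v2 = sin(v1)
    v3, dv3 = v1 * v2, dv1 * v2 + v1 * dv2        # v3 = v1 * v2 (product rule)
    v4, dv4 = math.exp(v3), math.exp(v3) * dv3    # v4 = exp(v3)
    return v4, dv4

value, grad = f_and_grad(0.7)
# Closed form for comparison: f'(x) = exp(x sin x) * (sin x + x cos x)
expected = math.exp(0.7 * math.sin(0.7)) * (math.sin(0.7) + 0.7 * math.cos(0.7))
print(value, grad, expected)
```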

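    3. Differentiation in Theano

    Theano first compiles the computation into a graph model.

    It then obtains the gradient at every node by backpropagation, traversing the graph from the output back toward the inputs.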
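
    To illustrate the backward pass on a graph, here is a minimal reverse-mode sketch in plain Python (no Theano). The Node class and the add, mul and sin helpers are hypothetical names for this example, not Theano's API. It evaluates z = sin(x) + x*y, the same function used in section 4, then pushes gradients from the output back to the inputs, summing contributions where a node is used more than once.

```python
import math

class Node:
    """A value in the computation graph plus the local derivatives to its parents."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents   # tuple of (parent_node, d self / d parent)
        self.grad = 0.0          # accumulated d output / d self

def add(a, b):
    return Node(a.value + b.value, ((a, 1.0), (b, 1.0)))

def mul(a, b):
    return Node(a.value * b.value, ((a, b.value), (b, a.value)))

def sin(a):
    return Node(math.sin(a.value), ((a, math.cos(a.value)),))

def backward(output):
    # Post-order DFS: every node is appended after all of its parents,
    # so the reversed list visits each node before the nodes it depends on.
    order, seen = [], set()
    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)
    visit(output)

    output.grad = 1.0   # seed: d output / d output = 1
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += local * node.grad   # chain rule, summed over paths

x, y = Node(2.0), Node(3.0)
z = add(sin(x), mul(x, y))      # z = sin(x) + x*y
backward(z)
print(x.grad, y.grad)           # dz/dx = cos(x) + y, dz/dy = x
```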

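    Below are usage examples of T.grad. In the printed gradient, fill((x ** TensorConstant{2}), TensorConstant{1.0}) creates a tensor with the same shape as x**2 and fills it with 1.0.

    The first expression therefore reduces to (1.0 * 2) * x ** (2 - 1), so the first gradient is 2*x, and f(4) returns 8.0.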
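
    The printed expression can be mirrored in plain Python to see why it equals 2*x. In this sketch, fill is a hypothetical pure-Python stand-in (scalars and nested lists only) for Theano's fill, not its real implementation; the fill(...) factor appears to act as the seed gradient dy/dy = 1 from which backpropagation starts.

```python
def fill(model, value):
    """Hypothetical stand-in for Theano's fill: same shape as model, every entry = value."""
    if isinstance(model, list):
        return [fill(item, value) for item in model]
    return value   # a scalar model becomes just `value`

def printed_grad(x):
    # Mirrors the pp output: (fill(x ** 2, 1.0) * 2) * x ** (2 - 1)
    return (fill(x ** 2, 1.0) * 2) * x ** (2 - 1)

print(fill([[1, 2], [3, 4]], 1.0))   # [[1.0, 1.0], [1.0, 1.0]]
print(printed_grad(4.0))             # 8.0, the same value as f(4)
```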

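    For s = sum(1 / (1 + exp(-x))), each entry of the gradient depends only on the matching entry of x, and the second gradient simplifies to $\frac{\partial s}{\partial x} = \frac{e^{-x}}{(1 + e^{-x})^2}$, i.e. exp(-x) / ((1 + exp(-x)) * (1 + exp(-x))). Since $1 - \sigma(x) = \frac{e^{-x}}{1 + e^{-x}}$ for the logistic function $\sigma(x) = \frac{1}{1 + e^{-x}}$, this equals $\sigma(x)\,(1 - \sigma(x))$, which is what the last print statement in the code checks.
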
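    import theano.tensor as T
    from theano import pp, function

    # Example 1: d(x**2)/dx for a scalar x
    x = T.dscalar('x')
    y = x ** 2
    gy = T.grad(y, x)          # build the symbolic gradient graph
    f = function([x], gy)      # compile it into a callable
    print(pp(gy))              # pretty-print the gradient expression
    print(f(4))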
    """
    ((fill((x ** TensorConstant{2}), TensorConstant{1.0}) *   
    TensorConstant{2}) * (x ** (TensorConstant{2} - TensorConstant{1})))
    8.0
    """
    
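    # Example 2: gradient of the summed logistic function over a matrix
    x = T.dmatrix('x')
    s = T.sum(1 / (1 + T.exp(-x)))   # scalar cost: sum of logistic(x) over all entries
    sx = 1 / (1 + T.exp(-x))         # elementwise logistic, used for the check below
    gs = T.grad(s, x)                # T.grad requires a scalar cost
    dlogistic = function([x], gs)
    print(pp(gs))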
    """
    (-(((-(fill((TensorConstant{1} / (TensorConstant{1} + exp((-x)))), fill(Sum{acc_dtype=float64}((TensorConstant{1} / 
    (TensorConstant{1} + exp((-x))))), TensorConstant{1.0})) * TensorConstant{1})) / ((TensorConstant{1} + exp((-x))) * 
    (TensorConstant{1} + exp((-x))))) * exp((-x))))
    """
    a = [[0, 1], [-1, -2]]
    print(dlogistic(a))
    # Compare with the closed form s(x) * (1 - s(x))
    k = function([x], sx)
    print(k(a) * (1 - k(a)))
    """
    [[ 0.25        0.19661193]
     [ 0.19661193  0.10499359]]
    [[ 0.25        0.19661193]
     [ 0.19661193  0.10499359]]
    """

    4. A simple automatic differentiator

    # A tiny symbolic differentiator: expressions are stored as a graph of named nodes
    
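    class Vars(object):
        """Expression graph: each node name 'vN' maps to a tuple (op, arg1[, arg2])."""
    
        def __init__(self):
            self.count = 0
            self.defs = {}     # node name -> definition tuple
            self.lookup = {}   # definition tuple -> node name (for reuse)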
    
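        def add(self, *v):
            # Reuse an existing node with the same definition instead of creating a duplicate
            name = self.lookup.get(v, None)
            if name is None:
                # Algebraic simplifications: a+0, 0+a, a*1, 1*a, a*0 and 0*a
                if v[0] == '+':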
                    if v[1] == 0:
                        return v[2]
                    elif v[2] == 0:
                        return v[1]
                elif v[0] == '*':
                    if v[1] == 1:
                        return v[2]
                    elif v[2] == 1:
                        return v[1]
                    elif v[1] == 0:
                        return 0
                    elif v[2] == 0:
                        return 0
    
                self.count += 1
                name = "v" + str(self.count)
                self.defs[name] = v
                self.lookup[v] = name
            return name
    
        def __getitem__(self, name):
            return self.defs[name]
    
        def __iter__(self):
            return iter(self.defs.items())   # (name, definition) pairs in creation order
    
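        def get_func(self, name):
            """Rebuild a readable infix string for node `name`."""
            if not isinstance(name, str):   # constants 0/1 produced by simplification
                return str(name)
            v = self.defs[name]
            if v[0] in ['+', '*']:
                return self.get_func(v[1]) + v[0] + self.get_func(v[2])
            elif v[0] == 'input':
                return v[1]
            else:   # unary function such as sin or cos
                return v[0] + '(' + self.get_func(v[1]) + ')'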
    
    
    
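    def diff(vars, acc, v, w):
        """Differentiate the output with respect to node w, walking down through node v.

        acc is the derivative of the output with respect to v accumulated along
        the current path, so the walk goes from the output back toward the inputs.
        """
        if v == w:
            return acc   # reached the target variable
    
        v = vars[v]
        if v[0] == 'input':
            return 0  # a different input: contributes nothing
        elif v[0] == "sin":
            # d sin(u)/du = cos(u): multiply the accumulated derivative by cos(u)
            return diff(vars, vars.add("*", acc, vars.add("cos", v[1])), v[1], w)
        elif v[0] == '+':
            # d(a+b)/da = d(a+b)/db = 1: pass acc unchanged to both operands
            gx = diff(vars, acc, v[1], w)
            gy = diff(vars, acc, v[2], w)
            return vars.add("+", gx, gy)  # chain rule: sum contributions of both paths
        elif v[0] == '*':
            # Product rule: d(a*b)/da = b, d(a*b)/db = a
            gx = diff(vars, vars.add("*", v[2], acc), v[1], w)
            gy = diff(vars, vars.add("*", v[1], acc), v[2], w)
            return vars.add("+", gx, gy)  # chain rule: sum contributions of both paths
    
        raise NotImplementedError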
    
    
    def autodiff(vars, v, *wrt):
        return tuple(diff(vars, 1, v, w) for w in wrt)
    # z = (sin x) + (x * y)
    
    vars = Vars()
    x = vars.add("input",'x')
    y = vars.add("input",'y')
    z = vars.add("+", vars.add("*",x,y),vars.add("sin",x))   #sin(x)+x*y
    
    result= autodiff(vars, z, x, y)
    
    for k, v in vars:
        print(k, v)
    for v_ in result:
        print(v_, vars.get_func(v_))
    
    """
    v1 ('input', 'x')
    v2 ('input', 'y')
    v3 ('*', 'v1', 'v2')  x*y
    v4 ('sin', 'v1')  sin(x)
    v5 ('+', 'v3', 'v4') sin(x)+x*y
    The last two nodes were added during differentiation:
    v6 ('cos', 'v1')  cos(x)
    v7 ('+', 'v2', 'v6') y+cos(x)
    Gradient expressions (dz/dx, then dz/dy):
    v7 y+cos(x)
    v1 x
    """
  • Original post: https://www.cnblogs.com/qw12/p/6216377.html