zoukankan      html  css  js  c++  java
  • softmax 在计算图中的前向和后向

    一 简介

    题外话:昨晚将矩阵求导复习了一遍,仔细推导了大部分公式,这次复习略有体会,相比第一次学习更加熟悉了,这种东西就应该多看看,常看常新。矩阵求导,它的本质就是多元函数的求导,矩阵只是为了方便书写,是一种整体的视角。矩阵求导,还可以用矩阵加下标表示标量来逐元素求导,是一种微观的视角。

    softmax 函数常用于分类问题,将一个向量的每个分量映射到 0 和 1 之间,表示对应类别的预测概率,所有分量求和为 1。本文将结合 OneFlow 的代码,分析 softmax 在计算图中如何前向传播和后向传播。今天在看代码的时候,一直纠结于矩阵的整体视角,对计算图上的梯度反传过程异常困惑。后来还是用逐元素求导的微观视角,将公式推理出来了,故作此文记录一下。如果不是纠结于整体视角,其实整个过程还是很简单的。

    二 前向传播

    2.1 公式

    softmax 输入是一个 n 维向量,输出也是一个 n 维向量。假设 (y in R^n, x in R^n),下面的公式中 exp 是逐元素操作,分母为对整个向量求和。

    [y = softmax(x) = frac{exp(x)}{sum exp(x)} ag{1} ]

    如果从微观的视角来看问题,每个元素可以通过下面的式子计算。

    [y_i = softmax(x_i) = frac{exp(x_i)}{sum exp(x_i)} ag{2} ]

    2.2 实现

    数值稳定性

    softmax 具体实现的时候,一般需要考虑数值稳定性

    wiki 上的数值稳定性

    In the mathematical subfield of numerical analysis, numerical stability is a generally desirable property of numerical algorithms. The precise definition of stability depends on the context. One is numerical linear algebra and the other is algorithms for solving ordinary and partial differential equations by discrete approximation.

    数值稳定性需要什么样的特性取决于具体的应用场景。在 softmax 中,需要避免发生上溢和下溢,导致 NaN 的出现[1]。对于公式 2,如果分母为 0,那么会变成 NaN,这种情况由下溢引起。如果 exp 的指数太大,会发生上溢,出现 NaN。如何解决这个问题呢?

    一般的处理方法是,让每个分量去减掉向量的最大值。对公式 2 上下同乘以 (exp(-max(pmb{x}))),得到公式 3。

    [y_i = frac{exp(x_i) exp(-max(pmb{x}))}{sum{exp(x_i)} exp(-max(pmb{x}))} = frac{exp(x_i - max(pmb{x}))}{sum{exp(x_i - max(pmb{x}))}} ag{3} ]

    于是就能避免出现上溢和下溢,从而避免 NaN。

    • 上溢:每个分量减去最大值,所以最大分量为 0,从而避免上溢。
    • 下溢:看分母会不会出现 0,因为至少一个分量是 0,所以求和最少为 1,避免了下溢导致 0 的出现。

    OneFlow 的实现

    使用了公式 3 来实现。

    // https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/kernels/softmax_kernel_util.cpp
      static void ComputeProb(DeviceCtx* ctx, const int64_t n, const int64_t w, const T* in, T* prob,
                              void* temp_storage, const size_t temp_storage_bytes) {
        auto Val = NdarrayUtil<DeviceType::kCPU, T>::GetValNdarrayBuilder();
        auto Var = NdarrayUtil<DeviceType::kCPU, T>::GetVarNdarrayBuilder();
        const size_t min_temp_storage_bytes =
            SoftmaxKernelUtil<DeviceType::kCPU, T>::GetComputeProbTempStorageSizeInBytes(n, w);
        CHECK_GE(temp_storage_bytes, min_temp_storage_bytes);
        const size_t reduce_temp_storage_bytes = GetReduceTempStorageSize<T>(n, w);
        T* reduce_storage = reinterpret_cast<T*>(temp_storage);
        auto reduce_storage_var =
            Var({static_cast<int64_t>(reduce_temp_storage_bytes / sizeof(T))}, reduce_storage);
        T* tmp = reinterpret_cast<T*>(reinterpret_cast<unsigned char*>(temp_storage)
                                      + reduce_temp_storage_bytes);
        // max | tmp[i] = Max_j(in[i][j])
        NdarrayUtil<DeviceType::kCPU, T>::ReduceMax(ctx, Var({n, 1}, tmp), Val({n, w}, in),
                                                    reduce_storage_var);
        // sub | prob[i][j] = in[i][j] - tmp[i]
        NdarrayUtil<DeviceType::kCPU, T>::BroadcastSub(ctx, Var({n, w}, prob), Val({n, w}, in),
                                                       Val({n, 1}, tmp));
        // exp | prob[i][j] = exp(prob[i][j])
        NdarrayUtil<DeviceType::kCPU, T>::InplaceExp(ctx, Var({n, w}, prob));
        // sum | tmp[i] = Sum_j(prob[i][j])
        NdarrayUtil<DeviceType::kCPU, T>::ReduceSum(ctx, Var({n, 1}, tmp), Val({n, w}, prob),
                                                    reduce_storage_var);
        // div | prob[i][j] /= tmp[i]
        NdarrayUtil<DeviceType::kCPU, T>::InplaceBroadcastDiv(ctx, Var({n, w}, prob), Val({n, 1}, tmp));
      }
    

    三 后向传播

    3.1 公式

    之前一直尝试从矩阵的角度去求解,导致怎么也求不出来。其实从微观的角度去求解,就不难了。首先,因为是标量,所以可以使用链式法则,不像矩阵求导,不可以随意使用链式法则,需要给出矩阵求导的定义,再结合实际情况确定链式法则的形式。再者,每个 (x_i) 都会参与到每个 (y_j) 的计算当中,所以梯度需要对不同的 (frac{partial y_j}{partial x_i}) 去求和。下面以 (x_1) 为例子。

    [frac{partial l}{partial x_1} = frac{partial l}{partial y_1} frac{partial y_1}{partial x_1} + frac{partial l}{partial y_2} frac{partial y_2}{partial x_1} + frac{partial l}{partial y_3} frac{partial y_3}{partial x_1} + ... ag{4} ]

    需要分别计算 (i)(j) 相等和不相等时候的导数,对公式 2,应用导数的除法法则即可。

    (i = j) 时,有

    [frac{partial y_i}{partial x_i} = frac{exp(x_i) sum{exp(x_i)} - exp(x_i) exp(x_i)}{sum{exp(x_i)}^2} = y_i (1 - y_i) ag{5} ]

    (i e j) 时,有

    [frac{partial y_j}{partial x_i} = frac{0 - exp(x_j) exp(x_i)}{(sum{exp(x_i)}^2)} = - y_i y_j ag{6} ]

    将公式 5 和 6 代入公式 4,可以得到公式 7。

    [frac{partial l}{partial x_1} = dy_1 y_1 (1 - y_1) + dy_2 (- y_1 y_2) + dy_3 (- y_1 y_3) + ... = y_1 (dy_1 - sum{y_i dy_i}) ag{7} ]

    将公式 4 和公式 7 中具体数字 (1) 变成 (i),就可以推导其他分量的导数了。

    3.2 实现

    需要注意到这个实现是 batch 版本的,手动求解一下 Mul, ReduceSum, BroadcastSub, InplaceMul,再画个矩阵就清晰明了了。

    // https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/kernels/softmax_kernel_util.cpp
      static void ComputeDiff(DeviceCtx* ctx, const int64_t n, const int64_t w, const T* dy,
                              const T* out, T* dx, void* temp_storage,
                              const size_t temp_storage_bytes) {
        auto Val = NdarrayUtil<DeviceType::kCPU, T>::GetValNdarrayBuilder();
        auto Var = NdarrayUtil<DeviceType::kCPU, T>::GetVarNdarrayBuilder();
        const size_t min_temp_storage_bytes =
            SoftmaxKernelUtil<DeviceType::kCPU, T>::GetComputeProbTempStorageSizeInBytes(n, w);
        CHECK_GE(temp_storage_bytes, min_temp_storage_bytes);
        const size_t reduce_temp_storage_bytes = GetReduceTempStorageSize<T>(n, w);
        T* reduce_storage = reinterpret_cast<T*>(temp_storage);
        auto reduce_storage_var =
            Var({static_cast<int64_t>(reduce_temp_storage_bytes / sizeof(T))}, reduce_storage);
        T* sum_vec = reinterpret_cast<T*>(reinterpret_cast<unsigned char*>(temp_storage)
                                          + reduce_temp_storage_bytes);
        // it's safe to use dx as tmp
        // dot product | get dot product sum_vec[i] from out[i] * dy[i]
        T* tmp = dx;
        NdarrayUtil<DeviceType::kCPU, T>::Mul(ctx, Var({n * w}, tmp), Val({n * w}, out),
                                              Val({n * w}, dy));
        NdarrayUtil<DeviceType::kCPU, T>::ReduceSum(ctx, Var({n, 1}, sum_vec), Val({n, w}, tmp),
                                                    reduce_storage_var);
        // sub | dx[i][j] = dy[i][j] - sum_vec[i]
        NdarrayUtil<DeviceType::kCPU, T>::BroadcastSub(ctx, Var({n, w}, dx), Val({n, w}, dy),
                                                       Val({n, 1}, sum_vec));
        // elementwise multiplication | dx[i][j] *= out[i][j]
        NdarrayUtil<DeviceType::kCPU, T>::InplaceMul(ctx, Var({n * w}, dx), Val({n * w}, out));
      }
    

    四 计算图上的反向传播

    计算图就不多做介绍了,这里顺便记录一下计算图上反向传播的思考。想明白这个之后,对于如何实现一个新的算子特别有帮助。后面会列举几个常见的例子。梯度反传的时候,节点只需要关心下一个节点的值和传回来的梯度,不需要关注邻接之外的节点。每个节点上的梯度,和节点的向量同型。因为昨天学习了矩阵求导,导致我非常关心整体的视角,尝试从矩阵的角度去计算整个梯度。实际上,从微观的角度去看会更好。

    4.1 例子

    假设 (x in R^n, y in R^m),反向传播回来的梯度为 (dy in R^m),是已知的。

    • (y = Wx, W in R^{m imes n}),则 (dx = W^T dy)
    • (y = x + b),则 (dx = dy)
    • (y = dropout(x)),则 (dx = mask * dy),其中 dropout 用 (mask) 来选择保留和丢弃哪些值。如果一个值保留了,那么反向的梯度也保留。如果丢弃了,那么反向梯度也丢弃。
    • (y = sigmoid(x)),则 (dx_i = dy_i y_i (1 - y_i)),这里从微观的视角来求得 (dx) 的公式。

    对于 sigmoid,OneFlow 中实现如下,可以看到上面的公式是如何在具体实现中呈现的。

    // https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/kernel/kernel_util.cpp
    KU_FLOATING_METHOD Sigmoid(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
      T half = static_cast<T>(0.5);
      for (int64_t i = 0; i != n; ++i) { y[i] = half * std::tanh(half * x[i]) + half; }
    }
    KU_FLOATING_METHOD SigmoidBackward(DeviceCtx* ctx, const int64_t n, const T* x, const T* y,
                                       const T* dy, T* dx) {
      for (int64_t i = 0; i != n; ++i) { dx[i] = y[i] * (1 - y[i]) * dy[i]; }
    }
    

    五 结论

    这篇随笔介绍了 softmax 如何在计算图中进行前向和后向计算,结合 OneFlow 代码具体实现,看看如何将抽象的公式转为具体的实现。最重要的是学会切换视角,不要局限于矩阵的整体视角,微观视角也很有帮助。对于计算图上面的计算,一个节点的梯度,只需要关注输出和传回来的梯度。节点上的梯度,建议使用微观的视角来求解。

    参考链接

    [1] https://blog.csdn.net/Shingle_/article/details/81988628

  • 相关阅读:
    疲劳原理
    golang中的 time 常用操作
    access与excel
    数据结构正式篇!初探!!
    数据结构复习之C语言malloc()动态分配内存概述
    C语言字符数组与字符串
    数据结构复习之C语言指针与结构体
    c语言数组
    数据结构
    C语言腾讯课堂(一)
  • 原文地址:https://www.cnblogs.com/zzk0/p/15173022.html
Copyright © 2011-2022 走看看