原文地址:
https://zhuanlan.zhihu.com/p/23309693
https://zhuanlan.zhihu.com/p/23293860
CTC:前向计算例子
这里我们直接使用warp-ctc中的变量进行分析。我们定义T为RNN输出的结果的维数,这个问题的最终输出维度为alphabet_size。而ground_truth的维数为L。也就是说,RNN输出的结果为alphabet_size*T的结果,我们要将这个结果和1*L这个向量进行对比,求出最终的Loss。
我们要一步一步地揭开这个算法的细节……当然这个算法的实现代码有点晦涩……
我们的第一步要顺着test_cpu.cpp的路线来分析代码。第一步我们就是要解析small_test()中的内容。也就是做前向计算,计算对于RNN结果来说,对应最终的ground_truth——t的label的概率。
这个计算过程可以用动态规划的算法求解。我们可以用一个变量来表示动态规划的中间过程,它就是:
:表示在RNN计算的时间T时刻,这一时刻对应的ground_truth的label为第i个下标的值t[i]的概率。
这样的表示有点抽象,我们用一个实际的例子来讲解:
RNN结果:,这里的每一个变量都对应一个列向量。
ground_truth:
那么表示的结果对应着的概率,当然与此同时,前面的结果也都合理地对应完成。
从上面的结果我们可以看出,如果的结果对应着,那么的结果也必然对应着。所以前面的结果是确定的。然而对于其他的一些情况来说,我们的转换存在着一定的不确定性。
CTC:前向计算具体过程
我们还是按照上面的例子进行计算,我们把刚才的例子搬过来:
RNN结果:,这里的每一个变量都对应一个列向量。
ground_truth:
alphabet:
按照上面介绍的计算方法,第一步我们先做ground_truth的状态扩展,于是我们就把长度从3扩展到了7,现在的ground_truth变成了:
我们的RNN结果长度为4,也就是说我们会从上面的7个ground_truth状态中进行转移,并最终转移到最终状态。理论上利用动态规划的算法,我们需要计算4*7=28个中间结果。好了,下面我们用表示RNN的第T时刻状态为ground_truth中是第i个位置的概率。
那么我们就开始计算了:
T=1时,我们只能选择和blank,所以这一轮我们终结状态只可能落在0和1上。所以第一轮变成了:
T=2时,我们可以继续选择,我们同时也可以选择,还可以选择和之间的blank,所以我们可以进一步关注这三个位置的概率,于是我们将其他的位置的概率设为0。
T=3时,留给我们的时间已经不多了,我们还剩2步,要走完整个旅程,我们只能选择,以及它们之间的空格。于是乎我们关心的位置又发生了变化:
是不是有点看晕了?没关系,因为还剩最后一步了。下面是最后一步,因为最后一步我们必须要到以及它后面的空格了,所以我们的概率最终计算也就变成了:
好吧,最终的结果我们求出来了,实际上这就是通过时间的推移不断迭代求解出来的。关于迭代求解的公式这里就不再赘述了。我们直接来看一张图:
于是乎我们从这个计算过程中发现一些问题:
首先是一个相对简单的问题,我们看到在计算过程中我们发现了大量的连乘。由于每一个数字都是浮点数,那么这样连乘下去,最终数字有可能非常小而导致underflow。所以我们要将这个计算过程转到对数域上。这样我们就将其中的乘法转变成了加法。但是原本就是加法的计算呢?比方说我们现在计算了loga和logb,我们如何计算log(a+b)呢,这里老司机给出了解决方案,我们假设两个数中a>b,那么有
这样我们就利用了loga和logb计算出了log(a+b)来。
另外一个问题就是,我们发现在刚才的计算过程当中,对于每一个时间段,我们实际上并不需要计算每一个ground-truth位置的概率信息,实际上只要计算满足某个条件的某一部分就可以了。所以我们有没有希望在计算前就规划好这条路经,以保证我们只计算最相关的那些值呢?
如何控制计算的数量?
不得不说,这一部分warp-ctc写得实在有点晦涩,当然也可能是我在这方面的理解比较渣。我们这里主要关注两个部分——一个是数据的准备,一个是最终的数据的使用。
在介绍数据准备之前,我们先简单说一下这部分计算的大概思路。我们用两个变量start和end表示我们需要计算的状态的起止点,在每一个时间点,我们要更新start和end这两个变量。然后我们更新start和end之间的概率信息。这里我们先要考虑一个问题,start和end的更新有什么规律?
为了简化思考,我们先假设ground_truth中没有重复的label,我们的大脑瞬间得到了解放。好了,下面我们就要给出代码中的两个变量——
T:表示RNN结果中的维度
S/2:ground_truth的维度(S表示了扩展blank之后的维度)
基本上具备一点常识,我们就可以知道T>=S/2。什么?你觉得有可能出现T<S/2的情况?兄弟,这种见鬼的事情如果发生,你难道要我们把RNN的结果拆开给你用?臣妾不太能做得到啊……
好了,既然接受了上面的事实,那么我们就来举几个例子看看:
我们假设T=3,S/2=3,那么说白了,它们之间的对应关系是一一对应,说白了这就和blank位置没啥关系了。在T=1时,我们要转移到第一个结果,T=2,我们要转移到第二个结果……
如何控制计算的数量?cont.
好,废话少说我们书接上回。不明真相的小朋友先看这个:
下面我们假设T=4,S/2=3,好玩的地方来了。T比S/2多一个,也就是说我们允许冗余出现了,那么我们可能的形式也就变多了。我们可以增加一个blank,我们也可以在没有label位置原地打一轮酱油。选择更多,欢乐更多。
虽然选择变多,但是着并不意味着我们可以选择任意一种状态转移的方式,至少:
- 在T=2时,我们至少要转移到第一个结果
- 在T=3时,我们至少要转移到第二个结果
- 在T=4时,兄弟我们准备下车了
这其实就是对start的限制。源代码中有这样一句话:
int remain = (S / 2) + repeats - (T - t);
这里我们先忽略repeats,那么remain这个变量其实是在计算label数量和剩余时间的差。如果用这样的语言来表达刚才的那个问题,我们语言就变成这个样子:
- 当时间还剩4轮时(包括第4轮),我们在哪都无所谓(实际上是从T=1开始计算的)
- 当时间还剩3轮时(包括第3轮),我们至少要转移到第一个结果(index=1)
- 当时间还剩2轮时(包括第2轮),我们至少要转移到第二个结果(index=3)
- 当时间还剩1轮时(包括第1轮),我们至少要转移到第三个结果(index=5)
好了,这里我们看出其中的含义了。我们再啰嗦一下,看看这些变量随T的变化情况:
- T=1,remain=0,start+=1
- T=2,remain=1,start+=2
- T=3,remain=2,start+=2
现在我们已经十分清楚了,当remain>=0时,start都要向前走,限制我们计算前面状态的概率,因为这些概率已经没有意义了。下面的代码也是这样描述的:
if(remain >= 0)
start += s_inc[remain];
那么这个s_inc是什么东西?它就是我们需要提前准备好的计算量。我们知道经过扩充的label序列中,所有的非空label都处在奇数的index上,而填充的blank都处在偶数的index上(我们是0-based的计算方法,matlab选手请退散……),所以对于上面的问题,当start=0时,下一步我们会从0跳到1,此后我们会从1到3,3到5,跳转的步数都是2,所以基于这个思路,我们就可以把s_inc这个数组生成出来。当然,我们的前提是没有重复。下面我们会说重复的问题的。
我们上面说了这么多,重点把start的变化介绍清楚了。下面我们来看看end。其实end的原理也类似,我们还是用刚才的废话套路来介绍站在end视角的世界:
- 在T=1时,我们最多能到第一个结果
- 在T=2时,我们最多能转移到第二个结果
- 在T=3时,我们最多能转移到第三个结果
- 在T=4时,我们已经掌握了整个世界……oh yeah
好了,可以看出end的变化形式,每个时刻end都可以+2,直到到达最后一个非blank的label,end变成了+1,然后end就不用动了,等着start动就可以了……(怎么感觉有点污?天哪……)
那么end变化的条件是什么呢?
if(t <= (S / 2) + repeats)
end += e_inc[t - 1];
我们还是忽略repeats,那么就十分清楚了,如果当前时刻小于等于label数,那么尽管前进,如果大于了,基本上也就到头了,这时候end就不用动了。
好了,前面我们终于说完了简单模式下start和end的移动规律,下面我们来看看带重复模式下的变化方法。
重复,重复
重复会带来什么样的变化呢?说白了如果有重复的label出现,那么两个连续重复的label中间就要至少出现一个blank。换句话说,每出现一个重复,我们的S/2就要加一,于是我们再看一眼这两个计算公式:
int remain = (S / 2) + repeats - (T - t);
if(remain >= 0)
start += s_inc[remain];
if(t <= (S / 2) + repeats)
end += e_inc[t - 1];
我们把repeats和S/2归到一起,这时候就能看明白了。
同理,在计算s_inc和e_inc的时候,由于有repeats的存在,它们从过去的+2变成了两个+1。也就是说先从label跳到blank,再跳到下一个label。这样就可以解释s_inc和e_inc的初始化策略了:
int e_counter = 0;
int s_counter = 0;
s_inc[s_counter++] = 1;
int repeats = 0;
for (int i = 1; i < L; ++i) {
if (labels[i-1] == labels[i]) {
s_inc[s_counter++] = 1;
s_inc[s_counter++] = 1;
e_inc[e_counter++] = 1;
e_inc[e_counter++] = 1;
++repeats;
}
else {
s_inc[s_counter++] = 2;
e_inc[e_counter++] = 2;
}
}
e_inc[e_counter++] = 1;
好了,到此我们才算把CTC中compute ctc loss这部分介绍完了。教科书上的一个公式看着简单,落实到代码就似乎充满了trick。希望看懂了这个计算的你大脑没有阵亡。