zoukankan      html  css  js  c++  java
  • tensorflow源码分析——CTC

    CTC是2006年的论文Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks中提到的,论文地址: http://www.cs.toronto.edu/~graves/icml_2006.pdf

    论文中CTC的定义是这样的:把对未分割的序列数据label的任务叫做Temporal Classification,把使用RNNs对未分割的序列数据label叫做Connectionist Temporal Classification(CTC) 。与之相对的是,把对数据序列的每一个time-step或者frame独立label 叫做framewise classification

    tensorflow中的相关实现在 /tensorflow/python/ops/ctc_ops.py

    1. ctc_loss, 计算ctc loss

    def ctc_loss(labels, inputs, sequence_length,
                 preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True, time_major=True):

    这个类执行softmax操作,所以输入应该是LSTM输出的线性映射

    inputs, 最内部维度大小是num_classes,代表“num_labels +1” 个类别,其中num_labels是真实的balebs的数目,最大值“num_labels-1”是为blank label保留的

    例如,如果一个单词包含3个labels ‘[a, b, c]’,则num_classes =4, 且labels的索引号是 ‘{a:0, b:1, c:2, blank:3}’

    至于参数 preprocess_collapse_repeated 和 ctc_merge_repeated:

    如果 preprocess_collapse_repeated = True ,在计算ctc之前,重复的labels会被合并为一个labels。这种预处理对下面这种情况是有用的:如果训练数据是强制对齐得到的,会包含不必要的重复。

    如果 ctc_merge_repeated = False,那么伴随ctc计算的深入,重复的非blank将不会被合并,会被解释为独立的labels。这是ctc的简化的非标准的版本

    具体见下表

    • preprocess_collapse_repeated = False,ctc_merge_repeated = True:经典CTC,输出的真实的重复的中间带有blanks类别,也可以通过解码器解码,输出不带有blanks的重复类别
    • preprocess_collapse_repeated = True,ctc_merge_repeated = False:因为在training之前,input 的labels已经合并重复项了,所以不会输出重复的类
    • preprocess_collapse_repeated = False,ctc_merge_repeated = False:输出重复的中间带有blank的类别,但是通常不需要解码器合并重复项
    • preprocess_collapse_repeated = True,ctc_merge_repeated = True: 未测试,非常可能不会学会输出重复类

    参数:

    labels: int32 SparseTensor, 标准的输出,稀疏矩阵

    inputs: 3-D float tensor . 计算得到的logits。 如果time_major = False, shape:batch_size x max_time x num_classes. 如果 time_major = True, shape:max_time x batch_size x num_classes

    sequence_length: 1-D int32 向量, batch_size

    输出:

    1-D float tensor,size:[batch], 概率的负对数

     2. ctc_beam_search_decoder: 对输入的logits执行beam search 解码

    def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100,
                                top_paths=1, merge_repeated=True):

     如果 merge_repeated = True, 在输出的beam中合并重复类。这意味着如果一个beam中的连续项( consecutive entries) 相同,只有第一个提交。即,如果top path 是‘A B B B ’,返回值是‘A B’(当merge_repeated = True),‘A B B B ’ (当merge_repeated = False)

    参数:

    inputs: 3-D float tensor , shape:max_time x batch_size x num_classes

    sequence_length: 1-D int32 向量, batch_size

    beam_ int scalar>=0

    top_paths: int scalar>=0, <= beam_width, 输出解码后的数目

    输出:

    元组:(decoded, log_prob)

    其中:

    decoded : a list of length top_paths, 每一个是一个稀疏矩阵

    log_prob : matrix , shape (batch_size x top_paths)

  • 相关阅读:
    TLS1.3&TLS1.2形式化分析(二)
    浏览器代理设置和取消代理
    jdk在window系统中的配置
    pycharm2017.3版本永久激活
    Scyther 形式化分析工具资料整理(三)
    百度快照的检索和反馈删除
    Scyther-Semantics and verification of Security Protocol 翻译 (第二章 2.2.2----2.3)
    双一流学校名单
    Scyther tools 协议形式化分析帮助文档翻译
    全国书画艺术之乡-----通渭
  • 原文地址:https://www.cnblogs.com/yuetz/p/6678243.html
Copyright © 2011-2022 走看看