zoukankan      html  css  js  c++  java
  • tensorflow LSTM+CTC使用详解

      最近用tensorflow写了个OCR的程序,在实现的过程中,发现自己还是跳了不少坑,在这里做一个记录,便于以后回忆。主要的内容有lstm+ctc具体的输入输出,以及TF中的CTC和百度开源的warpCTC在具体使用中的区别。

    正文

    输入输出

    因为我最后要最小化的目标函数就是ctc_loss,所以下面就从如何构造输入输出说起。

    tf.nn.ctc_loss

    先从TF自带的tf.nn.ctc_loss说起,官方给的定义如下,因此我们需要做的就是将图片的label(需要OCR出的结果),图片,以及图片的长度转换为label,input,和sequence_length。

    ctc_loss(
    labels,
    inputs,
    sequence_length,
    preprocess_collapse_repeated=False,
    ctc_merge_repeated=True,
    time_major=True
    )
    input: 输入(训练)数据,是一个三维float型的数据结构[max_time_step , batch_size , num_classes],当修改time_major = False时,[batch_size,max_time_step,num_classes]
    总体的数据流:
    image_batch
    ->[batch_size,max_time_step,num_features]->lstm
    ->[batch_size,max_time_step,cell.output_size]->reshape
    ->[batch_size*max_time_step,num_hidden]->affine projection A*W+b
    ->[batch_size*max_time_step,num_classes]->reshape
    ->[batch_size,max_time_step,num_classes]->transpose
    ->[max_time_step,batch_size,num_classes]
    下面详细解释一下,
    假如一张图片有如下shape:[60,160,3],我们如果读取灰度图则shape=[60,160],此时,我们将其一列作为feature,那么共有60个features,160个time_step,这时假设一个batch为64,那么我们此时获得到了一个[batch_size,max_time_step,num_features] = [64,160,60]的训练数据。
    然后将该训练数据送入构建的lstm网络中,(需要注意的是dynamic_rnn的输入数据在一个batch内的长度是固定的,但是不同batch之间可以不同,我们需要给他一个sequence_length(长度为batch_size的向量)来记录本次batch数据的长度,对于OCR这个问题,sequence_length就是长度为64,而值为160的一维向量)
    得到形如[batch_size,max_time_step,cell.output_size]的输出,其中cell.output_size == num_hidden。
    下面我们需要做一个线性变换将其送入ctc_loos中进行计算,lstm中不同time_step之间共享权值,所以我们只需定义W的结构为[num_hidden,num_classes]b的结构为[num_classes]。而tf.matmul操作中,两个矩阵相乘阶数应当匹配,所以我们将上一步的输出reshape成[batch_size*max_time_step,num_hidden](num_hidden为自己定义的lstm的unit个数)记为A,然后将其做一个线性变换,于是A*w+b得到形如[batch_size*max_time_step,num_classes]然后在reshape回来得到[batch_size,max_time_step,num_classes]最后由于ctc_loss的要求,我们再做一次转置,得到[max_time_step,batch_size,num_classes]形状的数据作为input

    labels: 标签序列
    由于OCR的结果是不定长的,所以label实际上是一个稀疏矩阵SparseTensor
    其中:

    • indices:二维int64的矩阵,代表非0的坐标点
    • values:二维tensor,代表indice位置的数据值
    • dense_shape:一维,代表稀疏矩阵的大小
      比如有两幅图,分别是123,和4567那么
      indecs = [[0,0],[0,1],[0,2],[1,0],[1,1],[1,2],[1,3]]
      values = [1,2,3,4,5,6,7]
      dense_shape = [2,4]
      代表dense tensor:
      1
      2
      [[1,2,3,0]
      [4,5,6,7]]

    seq_len: 在input一节中已经讲过,一维数据,[time_step,…,time_step]长度为batch_size,值为time_step

     
  • 相关阅读:
    python虚拟环境使用
    虚拟化网络进阶管理
    虚拟化进阶管理
    KVM虚拟化
    Xen虚拟化
    Virtualization基础
    Virnish使用
    CentOS配置FTP服务器
    Ajax结合Json进行交互数据(四)
    将 数据库中的结果集转换为json格式(三)
  • 原文地址:https://www.cnblogs.com/Libo-Master/p/8109691.html
Copyright © 2011-2022 走看看