zoukankan      html  css  js  c++  java
  • 【ONNX Operator】ReverseSequence解析

    最近在写ONNX的算子测试代码,照着github上ONNX的算子描述文档,先写单个算子生成模型的代码,再设计算子的输入数据、属性以及参数。

    简单的算子如abs,明明白白的求绝对值操作,不需要看算子说明就可以设计出测试用例(正负数、零以及浮点数和定点数)。复杂点的如Scan,算子说明一大堆英文,到现在还没搞明白,决定放到最后再做。

    还有的算子如这篇文章所要解析的——ReverseSequence,除了需要仔细看算子描述外,还需查查网上别人写的文章,在纸上写写画画才能理解。

    既然都在纸上写写画画了,不如写篇文章记录下,既可以分享给有需要的朋友,也可以供自己以后温习。

    废话说完了,正文开始:

    ReverseSequence

    定义

    先引用ONNX的英文说明:

    Reverse batch of sequences having different lengths specified by sequence_lens.
    For each slice i iterating on batch axis, the operator reverses the first sequence_lens[i] elements on time axis, and copies elements whose index's beyond sequence_lens[i] to the output. So the output slice i contains reversed sequences on the first sequence_lens[i] elements, then have original values copied for the other elements.

    简单说就是将一个输入张量先按照行或者列进行分割,然后将分割后的子张量在指定的维度上进行翻转操作

    参数

    Inputs:

    • input 输入数据,rank大于或等于2
    • sequence_lens, 子张量i的翻转长度,int型数字,小于或等于子张量的长度

    Attributes:

    • batch_axis, batch的维度,参数为0或1,分别表示沿着行或者沿着列对input进行分割,分割后的切片设为子张量i
    • time_axis, 分割后的子张量i的前sequence_lens个元素在第time_axis维上进行翻转操作,time_axis的参数为0或者1,分别表示行或者列

    Outputs:

    • Output 输出shape和输入shape一致

    举例说明

    (1)

    input = [[0.0, 4.0, 8.0, 12.0], [1.0, 5.0, 9.0, 13.0], [2.0, 6.0, 10.0, 14.0], [3.0, 7.0, 11.0, 15.0]] 
    sequence_lens = [4, 3, 2, 1] 
    time_axis = 0 
    batch_axis = 1
    

    batch_axis设为1,表示沿着列进行分割;
    time_axis设为0,表示将分割后的子张量的行数据进行翻转;
    sequence_lens设为[4,3,2,1]表示分割后的子张量i对应的元素翻转数量是sequence_lens[i]。

    (2)

    input = [[0.0, 1.0, 2.0, 3.0 ], [4.0, 5.0, 6.0, 7.0 ], [8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]] 
    sequence_lens = [1, 2, 3, 4] 
    time_axis = 1 
    batch_axis = 0
    

    batch_axis设为0,表示沿着行进行分割;
    time_axis设为1,表示将分割后的子张量的列数据进行翻转;
    sequence_lens设为[1,2,3,4]表示分割后的子张量i对应的元素翻转数量是sequence_lens[i]。

    补充:dims大于2维的input如何处理?

    上个部分的示例里使用的input是2维的——(4,4),也就是一个4行4列的矩阵。而ONNX的ReverseSequence算子的两个属性参数batch_axis和time_axis都固定了值只能是非0即1. 而这里的0和1刚好对应上矩阵的行和列,sequence_lens中的长度以及元素的大小只要低于input的长和宽也就是h和w(4和4)即可。

    但是,当input的dims大于2的时候,如何确定sequence_lens的长度和元素的限制范围呢?

    通过使用Tensorflow的reverse_sequence做了个小实验,代码如下:

    import tensorflow.compat.v1 as tf
    # shape of b: (2,3,2,2)
    b = tf.constant([[[[1, 2],
             [3, 4]],
    
            [[5, 6],
             [7, 8]],
    
            [[9, 10],
             [11, 12]]],
    
    
           [[[13, 14],
             [15, 16]],
    
            [[17, 18],
             [19, 20]],
    
            [[21, 22],
             [23, 24]]]])
    
    l = tf.constant([1, 3], tf.int64)
    z = tf.reverse_sequence(b, seq_lengths=l, seq_axis=1, batch_axis=0)
    
    l2 = tf.constant([2, 1, 2], tf.int64)
    z2 = tf.reverse_sequence(b, seq_lengths=l2, seq_axis=0, batch_axis=1)
    
    with tf.Session() as sess:
        print(sess.run(b))
        print("-"*60)
        print(sess.run(z))
        print("-"*60)
        print(sess.run(z2))
    

    z的输出如下:

    [[[[ 1  2]
       [ 3  4]]
    
      [[ 5  6]
       [ 7  8]]
    
      [[ 9 10]
       [11 12]]]
    
    
     [[[21 22]
       [23 24]]
    
      [[17 18]
       [19 20]]
    
      [[13 14]
       [15 16]]]]
    

    z2的输出如下:

    [[[[13 14]
       [15 16]]
    
      [[ 5  6]
       [ 7  8]]
    
      [[21 22]
       [23 24]]]
    
    
     [[[ 1  2]
       [ 3  4]]
    
      [[17 18]
       [19 20]]
    
      [[ 9 10]
       [11 12]]]]
    

    所以,sequence_lens的长度和元素的范围是:
    The length of sequence_lens = input_shape[batch_axis]
    The value of sequence_lens <= input_shape[time_axis]


    References:
    https://blog.csdn.net/Cerisier/article/details/80118611
    https://github.com/onnx/onnx/blob/master/docs/Operators.md#reversesequence
    https://www.tensorflow.org/api_docs/python/tf/reverse_sequence

  • 相关阅读:
    【vue】------ 路由创建 ------ 【William】
    【vue】------------@vue/cli3.7.0创建项目+ts-------------【William】
    【svn】--------------svn介绍------------------【William】
    【vue】------浅谈vue------【William】
    node创建服务器
    vue项目搭建
    利用vw做rem适配(纯css)
    nodejs实现md5和SHA256加密
    Cookie、session、localStorage、sessionStorage的区别
    tpc三次握手与四次挥手
  • 原文地址:https://www.cnblogs.com/liushengchieh/p/14971752.html
Copyright © 2011-2022 走看看