zoukankan      html  css  js  c++  java
  • tensorflow: a Implementation of rotation ops (旋转的函数实现方法)

    tensorflow 旋转矩阵的函数实现方法

    关键字: rot90, tensorflow

    1. 背景

    在做数据增强的操作过程中, 很多情况需要对图像旋转和平移等操作, 针对一些特殊的卷积(garbo conv)操作,还需要对卷积核进行旋转操作.
    在tensorflow中似乎没有实现对4D tensor的旋转操作.
    严格的说: tensorflow对tensor的翻转操作并未实现, 仅有针对3D tensor的tf.image.rot()
    而在大多数的情况下使用的是4D形式的tensor, [B,W,H,C] 或者是3D的图像组成的batchs.

    通过查看这篇文章的代码可以知道[1] 可以使用numpy的rot90()函数旋转, 但是rot90对象是ndarray, 针对tensorflow.tensor对象而言显然是无法使用的, 会抛出类似: 无法找到m.dim属性的异常.
    也就是说无法使用numpy.rot90() 函数.

    又知, tensorflow中提供有对矩阵的翻转, 转置,切片操作的函数,但是没有提供旋转90°, 180°,270°的操作.
    因此可以参照numpy.rot90(m, k=1, axes=(0,1)) 的程序片段去自己动手实现.
    rot90中的第一个参数m是操作对象, k是旋转的次数,k=1 代表逆时针旋转90度, k=2 代表逆时针旋转180度,以此类推
    axes是代表旋转的操作在哪两个维度构成的平面上.

    rot90的源代码如下:

    def rot90(m, k=1, axes=(0,1)):
        '''
        ......
        '''
        # 省略检测参数的操作
        k %= 4
    
        if k == 0:
            return m[:]
        if k == 2:
            return flip(flip(m, axes[0]), axes[1])
    
        axes_list = arange(0, m.ndim)
        (axes_list[axes[0]], axes_list[axes[1]]) = (axes_list[axes[1]],
                                                    axes_list[axes[0]])
    
        if k == 1:
            return transpose(flip(m,axes[1]), axes_list)
        else:
            # k == 3
            return flip(transpose(m, axes_list), axes[1])
    

    PS: 通过阅读上述的代码,也可以发现在tensorflow中直接使用rot90所抛出的异常是在这里出现的

    if axes[0] == axes[1] or absolute(axes[0] - axes[1]) == m.ndim
    

    原因是: 程序把tensor对象当成np.ndarray操作了, 而tensor对象没有m.dim属性

    2. 实现rot90操作

    2.1 梳理程序流程

    通过查看源代码可以梳理出程序流程图:

    程序流程图

    2.2 tensorflow 实现旋转操作

    根据上述的流程图, 可以实现对tensorflow的rot90操作;

    def rot90(tensor,k=1,axes=[1,2],name=None):
        '''
        autor:lizh
        tensor: a tensor 4 or more dimensions
        k: integer, Number of times the array is rotated by 90 degrees.
        axes: (2,) array_like
            The array is rotated in the plane defined by the axes.
            Axes must be different.
        
        -----
        Returns
        -------
        tensor : tf.tensor
                 A rotated view of `tensor`.
        See Also: https://www.tensorflow.org/api_docs/python/tf/image/rot90 
        '''
        axes = tuple(axes)
        if len(axes) != 2:
            raise ValueError("len(axes) must be 2.")
            
        tenor_shape = (tensor.get_shape().as_list())
        dim = len(tenor_shape)
        
        if axes[0] == axes[1] or np.absolute(axes[0] - axes[1]) == dim:
            raise ValueError("Axes must be different.")
            
        if (axes[0] >= dim or axes[0] < -dim 
            or axes[1] >= dim or axes[1] < -dim):
            
            raise ValueError("Axes={} out of range for tensor of ndim={}."
                .format(axes, dim))
        k%=4
        if k==0:
            return tensor
        if k==2:
            img180 = tf.reverse(tf.reverse(tensor, axis=[axes[0]]),axis=[axes[1]],name=name)
            return img180
        
        axes_list = np.arange(0, dim)
        (axes_list[axes[0]], axes_list[axes[1]]) = (axes_list[axes[1]],axes_list[axes[0]]) # 替换
        
        print(axes_list)
        if k==1:
            img90=tf.transpose(tf.reverse(tensor,axis=[axes[1]]), perm=axes_list, name=name)
            return img90
        if k==3:
            img270=tf.reverse( tf.transpose(tensor, perm=axes_list),axis=[axes[1]],name=name)
            return img270
    

    2.3 代码测试

    # 加载库
    import numpy as np
    import matplotlib.pyplot as plt
    import tensorflow as tf
    
    # 手写体数据集 加载
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("/home/lizhen/data/MNIST/", one_hot=True)
    
    sess=tf.Session()
    #选取数据 4D
    images = mnist.train.images
    img_raw = images[0,:] # [0,784]
    img=tf.reshape(img_raw,[-1,28,28,1]) # img 现在是tensor
    # 绘图
    def fig_2D_tensor(tensor):# 绘图
        #plt.matshow(tensor, cmap=plt.get_cmap('gray'))
        plt.matshow(tensor) # 彩色图像
        # plt.colorbar() # 颜色条
        plt.show()
    # 显 显示 待旋转的图片
    fig_2D_tensor(sess.run(img)[0,:,:,0]) # 提取ndarray
    
    

    待操作的图片

    简单的测试一下代码:

    img11_rot=rot90(img,2) # 旋转两次90
    fig_2D_tensor(sess.run(img11_rot)[0,:,:,0]) # 打印图像
    
    img12_rot=rot90(img,1,[1,1]) # 抛出异常,  测试 Axes must be different.
    img13_rot=rot90(img,1,[0,6]) # 抛出异常,  测试 Axes must be different.
    
    img14_rot=rot90(img,axes=[1,5])# 抛出异常,测试out of range.
    
    img14_rot=rot90(img,axes=[-1,2]) # -1的下标是倒数第二个,测试out of range.
    

    测试结果:

    3总结

    okey了,现在可以用了.
    .....

    额,,,,,最近才发现tensorflow的最新版本,大约就在前几天发布的新版本(14天前, 1.10.1 )上已经添加了对2D,3D图像的操作,支持[B,W,H,C]格式的tensor做出旋转[2]

    星期五, 07. 九月 2018 02:49下午

    参考文献


    1. Understanding 2D Dilated Convolution Operation with Examples in Numpy and Tensorflow with Interactive Code ↩︎

    2. tensorflow/python/ops/image_ops#rot90 ↩︎

  • 相关阅读:
    JS事件学习笔记(思维导图)
    [logstash-input-file]插件使用详解
    echarts折线图,纵坐标数值显示不准确的问题解决
    IDEA 创建maven jar、war、 pom项目
    Lombok介绍、使用方法和总结
    Springboot2.0访问Redis集群
    springboot2.x 整合redis集群的几种方式
    SpringBoot 2.x 使用Redis作为项目数据缓存
    Springboot2.x使用redis作为缓存
    SpringBoot中application.yml基本配置详情
  • 原文地址:https://www.cnblogs.com/greentomlee/p/9604806.html
Copyright © 2011-2022 走看看