zoukankan      html  css  js  c++  java
  • TensorFlow conv / deconv 的 padding

    1. padding 测试

    # Experiment 1a: 3x3 conv with stride 1 on a 2x2 input, with the kernel forced to all
    # ones so each output is (ignoring any bias term) the sum of the inputs its window covers.
    # NOTE(review): `sess` and `graph` are not defined in this snippet; assumed to be an
    # existing tf.Session and the tf.Graph it runs -- confirm in the surrounding script.
    input = tf.placeholder(tf.float32, shape=(1,2, 2,1))
    # No `padding` argument is passed, so slim's default is used (presumably 'SAME' -- confirm).
    simpleconv=slim.conv2d(input,1,[3,3],stride = 1,activation_fn = None,scope = 'simpleconv3')
    sess.run(tf.global_variables_initializer())
    # Overwrite the randomly initialised kernel with ones so results can be checked by hand.
    weights=graph.get_tensor_by_name("simpleconv3/weights:0")
    sess.run(tf.assign(weights,tf.constant(1.0,shape=weights.shape)))
    # Input filled row-major from 1..4, i.e. the 2x2 image [[1, 2], [3, 4]].
    a=np.ndarray(shape=(1,2,2,1),dtype='float',buffer=np.array([1.0,2,3,4]))
    simpleconvout=sess.run(simpleconv,feed_dict={input:a.astype('float32')})
    print simpleconvout
    [[[[ 10.000000]
    [ 10.000000]]
    
    [[ 10.000000]
    [ 10.000000]]]]
    
    # Experiment 1b: the same all-ones 3x3 kernel, now with stride 2 on a 4x4 input.
    input1 = tf.placeholder(tf.float32, shape=(1,4, 4,1))
    # NOTE(review): scope 'simpleconv3' is reused from experiment 1a. In the same graph
    # without reuse=True, slim would presumably refuse to create the variable again, so this
    # snippet appears to have been run in a fresh graph -- confirm before running sequentially.
    simpleconv=slim.conv2d(input1,1,[3,3],stride = 2,activation_fn = None,scope = 'simpleconv3')
    sess.run(tf.global_variables_initializer())
    # Force the kernel to all ones so every output is a plain window sum (bias aside).
    weights=graph.get_tensor_by_name("simpleconv3/weights:0")
    sess.run(tf.assign(weights,tf.constant(1.0,shape=weights.shape)))
    # Row-major 4x4 image: [[1,2,3,4],[2,3,4,5],[3,4,5,6],[4,5,6,7]].
    a=np.ndarray(shape=(1,4,4,1),dtype='float',buffer=np.array([1.0,2,3,4,2,3,4,5,3,4,5,6,4,5,6,7]))
    simpleconvout=sess.run(simpleconv,feed_dict={input1:a.astype('float32')})
    
    print simpleconvout
    
    [[[[ 27.]
    [ 27.]]
    
    [[ 27.]
    [ 24.]]]]
    
    # Experiment 1c: slim transposed conv (3x3 all-ones kernel, stride 2) on the 2x2
    # placeholder `input` from experiment 1a; the printed result is a 4x4 map.
    # No `padding` argument is passed, so slim's default is used (presumably 'SAME' -- confirm).
    simpledeconv=slim.conv2d_transpose(input,1,[3,3],stride = 2,activation_fn = None,scope = 'simpledeconv')
    sess.run(tf.global_variables_initializer())
    # Force the kernel to all ones so each output is a sum of input contributions.
    weights=graph.get_tensor_by_name("simpledeconv/weights:0")
    sess.run(tf.assign(weights,tf.constant(1.0,shape=weights.shape)))
    # Input image [[1, 2], [3, 4]] (row-major fill).
    a=np.ndarray(shape=(1,2,2,1),dtype='float',buffer=np.array([1.0,2,3,4]))
    simpleconvout=sess.run(simpledeconv,feed_dict={input:a.astype('float32')})
    print simpleconvout
    
    [[[[ 1.000000]
    [ 1.000000]
    [ 3.000000]
    [ 2.000000]]
    
    [[ 1.000000]
    [ 1.000000]
    [ 3.000000]
    [ 2.000000]]
    
    [[ 4.000000]
    [ 4.000000]
    [ 10.000000]
    [ 6.000000]]
    
    [[ 3.000000]
    [ 3.000000]
    [ 7.000000]
    [ 4.000000]]]]
    
     
    
    conv：stride=1 时四周各补一圈 0；stride=2 时（本例 4x4 输入）只在下方和右方补 0。
    
    deconv（stride=2）可以看成先在输入像素之间插 0、再做普通卷积：本例中 TF 在上方和左方各补两行/列 0，下方和右方各补一行/列 0。
    
    而 torch 中的 deconv（以 kernel=3、padding=1 为例）相当于四周各补一圈 0。
    View Code

    参考http://blog.csdn.net/lujiandong1/article/details/53728053 

    'SAME' padding 方式下，如果需要补的 0 的总数是奇数，则多出的那一个补在右边（下边）。

    2. 实现自定义 padding（custom padding）

    https://stackoverflow.com/questions/37659538/custom-padding-for-convolutions-in-tensorflow 

    实现自定义的 conv 和 deconv：
    def conv(input,num_outputs,kernel_size,stride=1,padW=0,padH=0,activation_fn=None,scope=None):
        """Convolve after explicit, symmetric zero padding (Torch-style padding).

        Pads ``padH`` zeros above and below and ``padW`` zeros left and right of
        the input, then calls ``slim.conv2d`` with ``padding="VALID"`` so that no
        further implicit padding is added on top of the explicit one.

        Args:
            input: 4-D tensor; axis 1 is padded by ``padH`` and axis 2 by
                ``padW``, i.e. it is treated as (N, H, W, C).
            num_outputs: number of output channels.
            kernel_size: kernel size forwarded to ``slim.conv2d``.
            stride: convolution stride forwarded to ``slim.conv2d``.
            padW: zeros added on each of the left and right sides.
            padH: zeros added on each of the top and bottom sides.
            activation_fn: activation forwarded to ``slim.conv2d``.
            scope: variable scope forwarded to ``slim.conv2d``.

        Returns:
            The tensor produced by ``slim.conv2d``.
        """
        # One [before, after] pair per axis; batch and channel axes stay unpadded.
        pad_spec = [[0, 0], [padH, padH], [padW, padW], [0, 0]]
        padded = tf.pad(input, pad_spec, "CONSTANT")
        return slim.conv2d(
            padded,
            num_outputs,
            kernel_size,
            stride=stride,
            padding="VALID",
            activation_fn=activation_fn,
            scope=scope,
        )
    # Exercise the custom `conv`: pad one zero on every side, then a VALID 3x3 conv with
    # stride 2 on the 4x4 image [[1,2,3,4],[2,3,4,5],[3,4,5,6],[4,5,6,7]].
    input1 = tf.placeholder(tf.float32, shape=(1,4, 4,1))
    a=np.ndarray(shape=(1,4,4,1),dtype='float',buffer=np.array([1.0,2,3,4,2,3,4,5,3,4,5,6,4,5,6,7]))
    simpleconv=conv(input1,1,[3,3],stride = 2,padW=1,padH=1,activation_fn = None,scope = 'conv')
    sess.run(tf.global_variables_initializer())
    # Force the kernel to all ones so outputs are window sums (bias aside).
    weights=graph.get_tensor_by_name("conv/weights:0")
    sess.run(tf.assign(weights,tf.constant(1.0,shape=weights.shape)))
    simpleconvout=sess.run(simpleconv,feed_dict={input1:a.astype('float32')})
    print simpleconvout
    [[[[ 8.]
    [ 21.]]
    
    [[ 21.]
    [ 45.]]]]
    
     
    
     
    
     
    
    def deconv(input,num_outputs,kernel_size,stride=2,activation_fn=None,scope=None):
        """Transposed convolution cropped to exactly ``stride`` times the input size.

        Runs ``slim.conv2d_transpose`` with ``padding="VALID"``. From the result
        it keeps an ``(H*stride, W*stride)`` window that starts ``kh // 2`` rows
        down and ``kw // 2`` columns in from the top-left corner.

        Args:
            input: 4-D tensor unpacked as (N, H, W, C). H and W must be known
                statically because they size the crop; N may be dynamic.
            num_outputs: number of output channels.
            kernel_size: ``[kh, kw]``, or a single int for a square kernel.
            stride: upsampling stride applied to both spatial axes.
            activation_fn: activation forwarded to ``slim.conv2d_transpose``.
            scope: variable scope forwarded to ``slim.conv2d_transpose``.

        Returns:
            Tensor of shape (N, H*stride, W*stride, num_outputs).
        """
        if isinstance(kernel_size, int):
            kh, kw = kernel_size, kernel_size
        else:
            kh, kw = kernel_size
        # as_list() yields plain ints (or None) in both TF1 and TF2, whereas
        # Dimension.value only exists on TF1 shapes.
        _, H, W, _ = input.get_shape().as_list()
        out = slim.conv2d_transpose(input,num_outputs,kernel_size,stride = stride,padding="VALID",activation_fn = activation_fn ,scope = scope)
        # Floor division keeps the slice offsets integral; with "/" they become
        # floats under Python 3 and tf.slice rejects them. A size of -1 keeps the
        # whole batch and channel axes, so a dynamic batch size also works.
        # NOTE(review): the window only fits if the VALID output is at least
        # H*stride + kh//2 tall (likewise for width); assuming slim's VALID size
        # (H-1)*stride + kh, this needs kh - kh//2 >= stride -- confirm.
        return tf.slice(out, [0, kh // 2, kw // 2, 0], [-1, H * stride, W * stride, -1])
    
    # Exercise the custom `deconv`: VALID 3x3 transposed conv, stride 2, on the 2x2 image
    # [[1, 2], [3, 4]], cropped to a 4x4 result.
    input = tf.placeholder(tf.float32, shape=(1,2, 2,1))
    a=np.ndarray(shape=(1,2,2,1),dtype='float',buffer=np.array([1.0,2,3,4]))
    simpledeconv=deconv(input,1,[3,3],stride = 2,activation_fn = None,scope = 'simpledeconv1')
    sess.run(tf.global_variables_initializer())
    # Force the kernel to all ones so each output is a plain sum of input contributions.
    weights=graph.get_tensor_by_name("simpledeconv1/weights:0")
    sess.run(tf.assign(weights,tf.constant(1.0,shape=weights.shape)))
    out=sess.run(simpledeconv,feed_dict={input:a.astype('float32')})
    print out
    
    [[[[ 1.]
    [ 3.]
    [ 2.]
    [ 2.]]
    
    [[ 4.]
    [ 10.]
    [ 6.]
    [ 6.]]
    
    [[ 3.]
    [ 7.]
    [ 4.]
    [ 4.]]
    
    [[ 3.]
    [ 7.]
    [ 4.]
    [ 4.]]]]
    View Code

  • 相关阅读:
    循环调用spring的dao,数个过后无响应
    WebEx如何录制电脑内的声音
    java对象转换String类型的三种方法
    使用Hibernate+MySql+native SQL的BUG,以及解决办法
    mysql之触发器trigger
    mysql 触发器学习
    Java对比两个数据库中的表和字段,写个冷门的东西
    PHP几个快速读取大文件例子
    Java安全中的“大坑”,跨平台真“浮云”
    国内一些大公司的开源项目
  • 原文地址:https://www.cnblogs.com/mlj318/p/7122874.html
Copyright © 2011-2022 走看看