zoukankan      html  css  js  c++  java
  • TF版网络模型搭建常用代码备忘

    本文主要介绍如何搭建一个网络并训练

           最近,我在写代码时经常碰到这样的情况,明明记得代码应该怎么写,在写出来的代码调试时,总是有些小错误。原因不是接口参数个数不对,就是位置不对。为了节约上网查找时间,现记录下常用操作,以备需要时快速查看。

           根据网络结构不同功能,主要分这几大块:网络基本结构元组件,网络常用结构,Tensorboard调试接口,数据预处理常用操作,后处理常用操作。

    1、搭建一个基础网络所需的元组件:

    import tensorflow as tf
    import glog as log
    
    class basenet(object):
    '''
    base model for other specific cnn
    '''
    def __init__(self):
          pass
    
    def weight_variable(self,shape,sdtdev=0.1,name):
    initial
    =tf.truncated_normal(shape=shape,mean=0.0,stddev=sdtdev)
    if name is None:
    return tf.Variable(initial)
    else:
    return tf.get_variable(name=name,initial=initialdef bias_variable(self,shape,name): initial=tf.constant(.01,shape=shape)
         if name is None:
    return tf.Variable(initial,name=name)
    else:
    return tf.get_variable(name=name,initial=initial)
    @staticmethod def conv2d(self,x,w,s=1,name=None,padding='SAME'): #with tf.variable_scope(name): if s == 1: x = tf.nn.conv2d(x,w,strides=[1,s,s,1],padding=padding) else: x = tf.nn.conv2d(x,w,strides=[1,s,s,1],padding=padding) #log.info('basenet conv2d x:{:}'.format(x.get_shape().as_list())) return x

    def conv2d_transpose(x,w,b,output_shape,stride=2):
    if output_shape is None:
    output_shape =x.get_shape().as_list()
    output_shape[1]*=2
    output_shape[2]*=2
    output_shape[3]=w.get_shape().as_list()[2]
    conv = tf.nn.conv2d_transpose(x,w,output_shape,strides=[1,stride,stride,1],padding='SAME')
    return tf.nn.bias_add(conv,b)
    def maxpool(self,x,k=2,s=2,padding='SAME'): x= tf.nn.max_pool(x,ksize=[1,k,k,1],strides=[1,s,s,1],padding=padding) return x

    def avgpool(self,x,k,s,padding='SAME'):
    x= tf.nn.avg_pool(x,ksize=[1,k,k,1],strides=[1,s,s,1],padding=padding)
    def local_response_norm(x):
    return tf.nn.lrn(x,depth_radius=5,bias=2,alpha=1e-4,beta=0.75)
    def relu(self,x,name): x = tf.nn.relu(x) return x

    def leaky_relu(x,alpha=0.0,name=""):
    return tf.maximum(alpha*x,x,name)
    def relu6(self,x): x= min(max(0,x), 6) x = tf.nn.relu(x) return x
    def batch_norm(x,output,phrase,scope='bn',decay=0.9,eps=1e-5):
    with tf.variable_scope(scope):
    beta=tf.get_variable(name='beta',shape=[output],initializer=tf.constant_initializer(0.05))
    gamma=tf.get_variable(name='gamma',shape=[output],initialzer=tf.random_normal_initializer(1.0,0.02)
    batch_mean,batch_var=tf.nn.moment(x,[0,1,2],name='moment')

    def mean_var_2_update():
    ema_apply_op = ema.apply([batch_mean,batch_var])
    with tf.control_dependencies([ema_apply_op]):
    return tf.identity(batch_mean),tf.identity(batch_var)
    mean,var = tf.cond(phrase,mean_var_2_update,lambda:(ema.average(batch_mean),ema.average(batch_var))
    normed = tf.nn.batch_normalization(x,mean,var,beta,gamma,eps)
    return normed
    def wx_b(self,x,w,b): x = tf.matmul(x,w)+b log.info('basenet wx_b x:{:}'.format(x.get_shape().as_list())) return x def fc(self,x,w,b): x = tf.add(tf.matmul(x,w),b) return x

    2、常用网络结构

    网络结构1:
    
    def bottleneck_unit(x,out_chan1,out_chan2,down_stride=False,up_stride=False,name=None):

    3、数据预处理常用操作

    def save_img(image,save_dir,name,mean=None):
        if mean:
            image=unprocess_image(image,mean)
        misc.imsave(os.path.join(save_dir,name+'.png'),image)
    
    def process_image(image,mean_pixel):
        return image-mean_pixel

    4、Tensorboard常用接口

    def add_regular_to_summary(var):
        if var is not None:
            tf.summary.histogram(var.op.name,var)
            tf.add_to_collection('reg_loss',tf.nn.l2_loss(var))
    
    
    def add_activation_to_summary(var):
        if var is not None: 
            tf.summary.histogram(var.op.name+'/activation',var)
            tf.summary.scalar(var.op.name+'/sparsity',tf.nn.zero_fraction(var))
    
    def add_gradient_to_summary(grad,var):
        if grad is not None:
            tf.summary.histogram(var.op.name+'/gradient',grad)
  • 相关阅读:
    十步完全理解SQL
    c#退出应用程序办法
    几个有意思的算法题
    GeoServer不同服务器安装配置、数据发布及客户端访问
    开启httpd服务的时候 显示Could not reliably determine the server`s fully qualified domain name
    Working With OpenLayers(Section 1: Creating a Basic Map)
    GeoServer地图开发解决方案(五):基于Silverlight技术的地图客户端实现
    模拟远程HTTP的POST请求
    模拟提交带附件的表单
    支付宝手机网站接口对接
  • 原文地址:https://www.cnblogs.com/jimchen1218/p/12047944.html
Copyright © 2011-2022 走看看