zoukankan      html  css  js  c++  java
  • tensorflow 1.0 学习:参数和特征的提取

    在tf中,参与训练的参数可用 tf.trainable_variables()提取出来,如:

    #取出所有参与训练的参数
    params=tf.trainable_variables()
    print("Trainable variables:------------------------")
    
    #循环列出参数
    for idx, v in enumerate(params):
         print("  param {:3}: {:15}   {}".format(idx, str(v.get_shape()), v.name))

    这里只能查看参数的shape和name,并没有具体的值。如果要查看参数具体的值的话,必须先初始化,即:

    sess=tf.Session()
    sess.run(tf.global_variables_initializer())

    同理,我们也可以提取图片经过训练后的值。图片经过卷积后变成了特征,要提取这些特征,必须先把图片feed进去。

    具体看实例:

    # -*- coding: utf-8 -*-
    """
    Created on Sat Jun  3 12:07:59 2017
    
    @author: Administrator
    """
    
    import tensorflow as tf
    from skimage import io,transform
    import numpy as np
    
    #-----------------构建网络----------------------
    #占位符
    x=tf.placeholder(tf.float32,shape=[None,100,100,3],name='x')
    y_=tf.placeholder(tf.int32,shape=[None,],name='y_')
    
    #第一个卷积层(100——>50)
    conv1=tf.layers.conv2d(
          inputs=x,
          filters=32,
          kernel_size=[5, 5],
          padding="same",
          activation=tf.nn.relu,
          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
    pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
    
    #第二个卷积层(50->25)
    conv2=tf.layers.conv2d(
          inputs=pool1,
          filters=64,
          kernel_size=[5, 5],
          padding="same",
          activation=tf.nn.relu,
          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
    pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
    
    #第三个卷积层(25->12)
    conv3=tf.layers.conv2d(
          inputs=pool2,
          filters=128,
          kernel_size=[3, 3],
          padding="same",
          activation=tf.nn.relu,
          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
    pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)
    
    #第四个卷积层(12->6)
    conv4=tf.layers.conv2d(
          inputs=pool3,
          filters=128,
          kernel_size=[3, 3],
          padding="same",
          activation=tf.nn.relu,
          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
    pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)
    
    re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])
    
    #全连接层
    dense1 = tf.layers.dense(inputs=re1, 
                          units=1024, 
                          activation=tf.nn.relu,
                          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                          kernel_regularizer=tf.nn.l2_loss)
    dense2= tf.layers.dense(inputs=dense1, 
                          units=512, 
                          activation=tf.nn.relu,
                          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                          kernel_regularizer=tf.nn.l2_loss)
    logits= tf.layers.dense(inputs=dense2, 
                            units=5, 
                            activation=None,
                            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                            kernel_regularizer=tf.nn.l2_loss)
    
    #---------------------------网络结束---------------------------
    #%%
    #取出所有参与训练的参数
    params=tf.trainable_variables()
    print("Trainable variables:------------------------")
    
    #循环列出参数
    for idx, v in enumerate(params):
         print("  param {:3}: {:15}   {}".format(idx, str(v.get_shape()), v.name))
    
    #%%
    #读取图片
    img=io.imread('d:/cat.jpg')
    #resize成100*100
    img=transform.resize(img,(100,100))
    #三维变四维(100,100,3)-->(1,100,100,3)
    img=img[np.newaxis,:,:,:]
    img=np.asarray(img,np.float32)
    sess=tf.Session()
    sess.run(tf.global_variables_initializer()) 
    
    #提取最后一个全连接层的参数 W和b
    W=sess.run(params[26])
    b=sess.run(params[27])
    
    #提取第二个全连接层的输出值作为特征    
    fea=sess.run(dense2,feed_dict={x:img})

    最后一条语句就是提取某层的数据输出作为特征。

    注意:这个程序并没有经过训练,因此提取出的参数只是初始化的参数。

  • 相关阅读:
    大型网站随着业务的增长架构演进
    springboot日志logback配置
    一些容易出错的细节
    从一个下载优化说起
    徒手优化冒泡排序
    php设计模式之观察者模式
    php设计模式之抽象工厂模式
    phper談談最近重構代碼的感受(3)
    php设计模式----工厂模式
    偏执的我从Linux到Windows的感受
  • 原文地址:https://www.cnblogs.com/denny402/p/6937084.html
Copyright © 2011-2022 走看看