zoukankan      html  css  js  c++  java
  • tensorflow 1.0 学习:参数和特征的提取

    在tf中,参与训练的参数可用 tf.trainable_variables()提取出来,如:

    #取出所有参与训练的参数
    params=tf.trainable_variables()
    print("Trainable variables:------------------------")
    
    #循环列出参数
    for idx, v in enumerate(params):
         print("  param {:3}: {:15}   {}".format(idx, str(v.get_shape()), v.name))

    这里只能查看参数的shape和name,并没有具体的值。如果要查看参数具体的值的话,必须先初始化,即:

    sess=tf.Session()
    sess.run(tf.global_variables_initializer())

    同理,我们也可以提取图片经过训练后的值。图片经过卷积后变成了特征,要提取这些特征,必须先把图片feed进去。

    具体看实例:

    # -*- coding: utf-8 -*-
    """
    Created on Sat Jun  3 12:07:59 2017
    
    @author: Administrator
    """
    
    import tensorflow as tf
    from skimage import io,transform
    import numpy as np
    
    #-----------------构建网络----------------------
    #占位符
    x=tf.placeholder(tf.float32,shape=[None,100,100,3],name='x')
    y_=tf.placeholder(tf.int32,shape=[None,],name='y_')
    
    #第一个卷积层(100——>50)
    conv1=tf.layers.conv2d(
          inputs=x,
          filters=32,
          kernel_size=[5, 5],
          padding="same",
          activation=tf.nn.relu,
          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
    pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
    
    #第二个卷积层(50->25)
    conv2=tf.layers.conv2d(
          inputs=pool1,
          filters=64,
          kernel_size=[5, 5],
          padding="same",
          activation=tf.nn.relu,
          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
    pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
    
    #第三个卷积层(25->12)
    conv3=tf.layers.conv2d(
          inputs=pool2,
          filters=128,
          kernel_size=[3, 3],
          padding="same",
          activation=tf.nn.relu,
          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
    pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)
    
    #第四个卷积层(12->6)
    conv4=tf.layers.conv2d(
          inputs=pool3,
          filters=128,
          kernel_size=[3, 3],
          padding="same",
          activation=tf.nn.relu,
          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
    pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)
    
    re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])
    
    #全连接层
    dense1 = tf.layers.dense(inputs=re1, 
                          units=1024, 
                          activation=tf.nn.relu,
                          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                          kernel_regularizer=tf.nn.l2_loss)
    dense2= tf.layers.dense(inputs=dense1, 
                          units=512, 
                          activation=tf.nn.relu,
                          kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                          kernel_regularizer=tf.nn.l2_loss)
    logits= tf.layers.dense(inputs=dense2, 
                            units=5, 
                            activation=None,
                            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                            kernel_regularizer=tf.nn.l2_loss)
    
    #---------------------------网络结束---------------------------
    #%%
    #取出所有参与训练的参数
    params=tf.trainable_variables()
    print("Trainable variables:------------------------")
    
    #循环列出参数
    for idx, v in enumerate(params):
         print("  param {:3}: {:15}   {}".format(idx, str(v.get_shape()), v.name))
    
    #%%
    #读取图片
    img=io.imread('d:/cat.jpg')
    #resize成100*100
    img=transform.resize(img,(100,100))
    #三维变四维(100,100,3)-->(1,100,100,3)
    img=img[np.newaxis,:,:,:]
    img=np.asarray(img,np.float32)
    sess=tf.Session()
    sess.run(tf.global_variables_initializer()) 
    
    #提取最后一个全连接层的参数 W和b
    W=sess.run(params[26])
    b=sess.run(params[27])
    
    #提取第二个全连接层的输出值作为特征    
    fea=sess.run(dense2,feed_dict={x:img})

    最后一条语句就是提取某层的数据输出作为特征。

    注意:这个程序并没有经过训练,因此提取出的参数只是初始化的参数。

  • 相关阅读:
    MySQL之pymysql模块
    MySQL 之 索引原理与慢查询优化
    MySQL 之 视图、触发器、存储过程、函数、事物与数据库锁
    MySql之数据操作
    MySQL之多表查询
    MySQL之单表查询
    MySQL之表的约束
    MySQL之表操作
    MySQL之表的数据类型
    pycharm 2016 注册(pycharm-professional-2016.3.2)
  • 原文地址:https://www.cnblogs.com/denny402/p/6937084.html
Copyright © 2011-2022 走看看