zoukankan      html  css  js  c++  java
  • 2 (自我拓展)部署花的识别模型(学习tensorflow实战google深度学习框架)

    kaggle竞赛的inception模型已经能够提取图像很好的特征,后续训练出一个针对当前图片数据的全连接层,进行花的识别和分类。这里见书即可,不再赘述。

    书中使用google参加Kaggle竞赛的inception模型重新训练一个全连接神经网络,对五种花进行识别,我姑且命名为模型flower_photos_model。我进一步拓展,将lower_photos_model模型进一步保存,然后部署和应用。然后,我们直接调用迁移之后又训练好的模型,对花片进行预测。

    这里讨论两种方式:使用import_meta_graph和使用saver()

    首先,原书的迁移学习的代码需要做一些改动。

    writer = tf.summary.FileWriter('./graphs/flower_photos_model_graph', sess.graph)
    saver.save(sess, "Saved_model/flower_photos_model.ckpt")

     Saver()方式

    我相较于训练flower_photos_model模型时,增添了一个变量的定义:

    即label_index=tf.argmax(final_tensor,1)

    def main():
        #先定义相同的计算图再加载迁移学习的模型
        bottleneck_input = tf.placeholder(tf.float32, [None, BOTTLENECK_TENSOR_SIZE], name='BottleneckInputPlaceholder')
        with tf.name_scope('final_training_ops'):
            weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.001))
            biases = tf.Variable(tf.zeros([n_classes]))
            logits = tf.matmul(bottleneck_input, weights) + biases
            final_tensor = tf.nn.softmax(logits)
            label_index=tf.argmax(final_tensor,1)
    #利用import_meta_graph和import_graph_def加载的变量均不允许与当前定义计算图有冲突。
    #saver = tf.train.Saver()则只加载当前计算图中定义的。
        saver = tf.train.Saver()
        
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.700)  
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            saver.restore(sess, "Saved_model/flower_photos_model.ckpt")
            #还是要加载一下inception模型 
            MODEL_DIR = './inception_dec_2015'
            MODEL_FILE= 'tensorflow_inception_graph.pb'
            with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
            bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])
            print (bottleneck_tensor)
            print (jpeg_data_tensor)
            #为了在tensorboard中观察加载的计算图。
            writer = tf.summary.FileWriter('./graphs/flower_photos_model_graph_use', sess.graph)
            writer.close()
            #image_path='./data/xiaojie_application/xiaojie_rose.jpg'
            image_path='./data/xiaojie_application/xiaojie_sunflowers.jpg'
            #image_path='./data/xiaojie_application/5547758_eea9edfd54_n.jpg'
            
            """测试一张图片,能否获取瓶颈向量。
            image_data = gfile.FastGFile(image_path, 'rb').read()
            print (sess.run(jpeg_data_tensor,{jpeg_data_tensor:image_data}))
            print ("xiaojie1")
            print (sess.run(bottleneck_tensor,{jpeg_data_tensor:image_data}))
            """
            label_index_value=evalution_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor,bottleneck_input,label_index)
            #print (label_index_value)
            classes=['daisy','dandelion','roses','sunflowers','tulips']
            print ("预测的花的类型:",classes[label_index_value[0]])

    相关的函数的定义:

    evalution_xiaojie输出预测的分类index。
    def evalution_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor,bottleneck_input,label_index):
    #输出一张图片的预测结果    bottleneck_values=get_bottleneck_values_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor)
        bottlenecks = []
        bottlenecks.append(bottleneck_values)
        label_index_value = sess.run(label_index, feed_dict={
                bottleneck_input: bottlenecks})
        return label_index_value

    获取瓶颈向量(关于瓶颈向量,见书)

    def get_bottleneck_values_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor):
        #瓶颈向量
        if not os.path.exists(CACHE_DIR): os.makedirs(CACHE_DIR)
        bottleneck_path = get_bottleneck_path_xiaojie(CACHE_DIR,image_path)
        print (bottleneck_path)
        if not os.path.exists(bottleneck_path):
            image_data = gfile.FastGFile(image_path, 'rb').read()
            bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor, bottleneck_tensor)
            bottleneck_string = ','.join(str(x) for x in bottleneck_values)
            with open(bottleneck_path, 'w') as bottleneck_file:
                bottleneck_file.write(bottleneck_string)
        else:
            with open(bottleneck_path, 'r') as bottleneck_file:
                bottleneck_string = bottleneck_file.read()
                bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
        return bottleneck_values

    使用inception模型计算瓶颈向量

    def run_bottleneck_on_image(sess, image_data, image_data_tensor, bottleneck_tensor):
        print("yes")
        bottleneck_values = sess.run(bottleneck_tensor, {image_data_tensor: image_data})
        bottleneck_values = np.squeeze(bottleneck_values)
        print("no")
        return bottleneck_values

    瓶颈向量有一个缓存文件,这也是类似于原书训练迁移学习模型时的做法

    def get_bottleneck_path_xiaojie(CACHE_DIR,image_path):
        file_name_suffix=image_path.split('/')[-1]
        file_name_no_suffix=file_name_suffix.split('.')[0]
        bottleneck_file_name=file_name_no_suffix+('_cache.txt')
        bottleneck_path=os.path.join(CACHE_DIR, bottleneck_file_name)
        return bottleneck_path

    定义的全局变量

    BOTTLENECK_TENSOR_SIZE = 2048
    n_classes = 5
    BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
    JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
    CACHE_DIR='./data/xiaojie_application/cache_bottleneck/'

    Saver方式的说明:

    Saver只能导出持久化模型中与当前代码定义计算图相匹配的部分。

    因此,对于之前inception也需要再一次重新加载。

    此外,当前代码定义计算图,比持久化模型flower_photos_model多定义了一个变量,即label_index=tf.argmax(final_tensor,1),即输出预测的分类index。

    import_meta_graph方式

    import_meta_graph方式与saver方式的不同点在于会导入完整的计算图,因此当前代码不能定义和要加载计算图相互冲突的部分。

    相关函数定义的代码均不变。只将main函数的内容和全局变量改为:

    def main():
        #如果使用tf.train.import_meta_graph的话,就会重复加载计算图。因此,避免重复,当前代码中不能定义重复的。
         
        #saver = tf.train.Saver()
        saver = tf.train.import_meta_graph("Saved_model/flower_photos_model.ckpt.meta")
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.700)  
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        #with tf.Session() as sess:
            #如果直接使用saver = tf.train.Saver()和restore还原一个model.ckpt文件,是不可能将之前迁移学习那个模型利用import_graph_def加载的inception模型加载进来的。
            saver.restore(sess, "Saved_model/flower_photos_model.ckpt")
    
            bottleneck_tensor= sess.graph.get_tensor_by_name(import_BOTTLENECK_TENSOR_NAME)
            jpeg_data_tensor = sess.graph.get_tensor_by_name(import_JPEG_DATA_TENSOR_NAME)
            print (bottleneck_tensor)
            print (jpeg_data_tensor)
    
            writer = tf.summary.FileWriter('./graphs/flower_photos_model_graph_use', sess.graph)
            writer.close()
            image_path='./data/xiaojie_application/xiaojie_rose.jpg'
            #image_path='./data/xiaojie_application/xiaojie_sunflowers.jpg'
            #image_path='./data/xiaojie_application/5547758_eea9edfd54_n.jpg'
            
            """测试一张图片
            image_data = gfile.FastGFile(image_path, 'rb').read()
            print (sess.run(jpeg_data_tensor,{jpeg_data_tensor:image_data}))
            print ("xiaojie1")
            print (sess.run(bottleneck_tensor,{jpeg_data_tensor:image_data}))
            """
            bottleneck_input= sess.graph.get_tensor_by_name("BottleneckInputPlaceholder:0")
            final_tensor = sess.graph.get_tensor_by_name("final_training_ops/Softmax:0")
            label_index=tf.argmax(final_tensor,1)
            label_index_value=evalution_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor,bottleneck_input,label_index)
            print (label_index_value)
            classes=['daisy','dandelion','roses','sunflowers','tulips']
            print ("预测的花的类型:",classes[label_index_value[0]]) 

    全局变量改为:

    import_BOTTLENECK_TENSOR_NAME = 'import/pool_3/_reshape:0'

    import_JPEG_DATA_TENSOR_NAME = 'import/DecodeJpeg/contents:0'

    这是因为,使用import_meta_graph方式的话,当前代码不能定义任何与持久化模型中计算图冲突的节点。此外,在flower_photos_model模型对全连接层进行训练的过程中,已经利用import_graph_def的方式导入google Inception v3的持久化模型pb文件,因此,已经包括了google的模型。通过在tensorboard中查看,会发现,所有导入的模块节点之前会带上import节点。因此,在训练flower_photos_model模型时,使用的是pool_3/_reshape:0获取张量,而此时,只能使用import/pool_3/_reshape:0'获取张量。

    只能使用import/pool_3/_reshape:0'获取张量。

            final_tensor = sess.graph.get_tensor_by_name("final_training_ops/Softmax:0")

    然后,我们再定义一个label_index

            label_index=tf.argmax(final_tensor,1)

    因此,同saver模型一样,所有的其它函数接口和实现都不用变。

    最后的结果很nice。可以识别五种花朵,可以直接部署应用。

    程序附件

    链接:https://pan.baidu.com/s/11YtyDEyV84jONPi9tO2TCw 密码:8mfj

  • 相关阅读:
    SkyWalking结合Logback获取全局唯一标识 trace-id 记录到日志中
    Mysql数据库优化技术
    MySQL中集合的差的运算方法
    深入理解Java ClassLoader及在 JavaAgent 中的应用
    自制吸锡带
    Ubuntu下双显示器设定
    ffmpeg 命令的使用
    ifeq ifneq ifdef ifndef
    字符对齐
    ruby on rails使用gmail的smtp发送邮件
  • 原文地址:https://www.cnblogs.com/xiaojieshisilang/p/9232685.html
Copyright © 2011-2022 走看看