  • TensorFlow: classifying images with the classification models in the slim framework

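    TensorFlow's slim framework lets you define network architectures with code as concise as Keras (although Keras is now also bundled in tf.contrib). In addition, models/slim provides an image classification interface similar to the object detection interface covered earlier, which makes it easy to fine-tune a model on your own dataset.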

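    The official documentation covers the workflow in reasonable detail: data preparation, the model zoo of pre-trained models, fine-tuning, freezing the model, and so on. It does not document inference, however. Since all TF models are loaded the same way, a slim model is called just like any other pb model.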

    According to the TF developers, saving and loading a model in TensorFlow generally follows these steps: Build your graph --> write your graph --> import from written graph --> run compute etc

    Below we use inception-resnet-v2, one of the networks provided by slim, as an example:

    1. export inference graph

    import tensorflow as tf
    import nets.inception_resnet_v2 as net  # from models/slim, which must be importable
    
    slim = tf.contrib.slim
    
    # checkpoint path
    checkpoint_path = "/your/path/to/inception_resnet_v2.ckpt" # ckpt file obtained during model training or fine-tuning
    
    # set up and load session
    sess = tf.Session()
    arg_scope = net.inception_resnet_v2_arg_scope()
    # placeholder for the model input; named "input" so it can be fetched as "input:0" at inference time
    input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3], name="input")
    with slim.arg_scope(arg_scope):
        # is_training=False builds the inference graph (no dropout, batch norm uses its moving averages)
        logits, end_points = net.inception_resnet_v2(inputs=input_tensor, is_training=False)
    
    # set up model saver
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)
    # save the graph definition to the given pb file; binary mode, since the serialized protobuf is bytes
    with tf.gfile.GFile('/your/path/to/model_graph.pb', 'wb') as f:
        f.write(sess.graph_def.SerializeToString())
    

    2. freeze model

    Here we use the freeze_graph tool that TF provides under tensorflow/python/tools:

    $ bazel build tensorflow/python/tools:freeze_graph
    # --input_graph is the pb file written in step 1;
    # --output_node_names is the output node defined in the inception resnet v2 net
    $ bazel-bin/tensorflow/python/tools/freeze_graph \
        --input_graph=/your/path/to/model_graph.pb \
        --input_checkpoint=/your/path/to/inception_resnet_v2.ckpt \
        --input_binary=true \
        --output_graph=/your/path/to/frozen_graph.pb \
        --output_node_names=InceptionResnetV2/Logits/Predictions
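
    If building the tool with bazel is inconvenient, the same freezing step can probably be done from Python. The sketch below assumes a TF 1.x install and a session whose variables have already been restored, as in step 1. The helper names freeze_session and split_node_names are made up for illustration; tf.graph_util.convert_variables_to_constants is the TF 1.x call that does the actual work.

```python
def split_node_names(names):
    """Split a comma-separated string of node names, the format --output_node_names takes."""
    return [n.strip() for n in names.split(",") if n.strip()]


def freeze_session(sess, output_node_names, output_path):
    """Write a frozen copy of sess.graph to output_path (hypothetical helper, TF 1.x assumed)."""
    import tensorflow as tf  # imported lazily; assumes a TF 1.x install

    # Replace every variable that the output nodes depend on with a constant
    # holding the variable's current value in the session.
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, split_node_names(output_node_names))
    with tf.gfile.GFile(output_path, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())
    return frozen_graph_def
```

    With the session from step 1, after saver.restore has run, a call might look like freeze_session(sess, "InceptionResnetV2/Logits/Predictions", "/your/path/to/frozen_graph.pb").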
    

    (Optional) visualize frozen graph

    import tensorflow as tf
    
    LOG_DIR = '/tmp/graphdeflogdir'
    model_filename = '/your/path/to/frozen_graph.pb'
    
    with tf.Session() as sess:
        with tf.gfile.GFile(model_filename, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            # import_graph_def adds the nodes to the default graph (sess.graph)
            tf.import_graph_def(graph_def, name='')
        # write the imported graph so that TensorBoard can display it
        writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    writer.close()
    

    Then run tensorboard --logdir=/tmp/graphdeflogdir (the LOG_DIR above) and select the graph view to inspect the structure of the frozen network.
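
    TensorBoard is mainly useful here for finding the input and output node names needed at inference time. Those names can also be listed programmatically. The sketch below is one possible approach, not an official API: find_io_nodes is a made-up helper that relies only on the node, name, op and input fields that a GraphDef exposes, and the fake graph at the bottom is a stand-in so that the example runs without TensorFlow.

```python
from types import SimpleNamespace


def find_io_nodes(graph_def):
    """Return (placeholder names, names of nodes that no other node consumes)."""
    consumed = set()
    for node in graph_def.node:
        for inp in node.input:
            # Inputs look like "name", "name:1" (output index) or "^name" (control dependency).
            consumed.add(inp.lstrip("^").split(":")[0])
    inputs = [n.name for n in graph_def.node if n.op == "Placeholder"]
    outputs = [n.name for n in graph_def.node if n.name not in consumed]
    return inputs, outputs


# Tiny stand-in for a GraphDef, so the sketch runs without TensorFlow installed.
fake_graph_def = SimpleNamespace(node=[
    SimpleNamespace(name="input", op="Placeholder", input=[]),
    SimpleNamespace(name="weights", op="Const", input=[]),
    SimpleNamespace(name="conv", op="Conv2D", input=["input", "weights"]),
    SimpleNamespace(name="Predictions", op="Softmax", input=["conv:0"]),
])
print(find_io_nodes(fake_graph_def))  # (['input'], ['Predictions'])
```

    On a real frozen graph, pass the graph_def parsed from frozen_graph.pb instead of the stand-in. The second list may contain more than the true outputs, so treat it as a list of candidates.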

    3. inference

    import cv2
    import numpy as np
    import tensorflow as tf
    
    def preprocess_inception(image_np, central_fraction=0.875):
        image_height, image_width, image_channel = image_np.shape
        if central_fraction:
            bbox_start_h = int(image_height * (1 - central_fraction) / 2)
            bbox_end_h = int(image_height - bbox_start_h)
            bbox_start_w = int(image_width * (1 - central_fraction) / 2)
            bbox_end_w = int(image_width - bbox_start_w)
            image_np = image_np[bbox_start_h:bbox_end_h, bbox_start_w:bbox_end_w]
        # normalize pixel values from [0, 255] to [-1, 1]
        image_np = 2 * (image_np / 255.) - 1
        return image_np
    
    image_np = cv2.imread("test.jpg")
    # OpenCV loads images as BGR, while the slim models expect RGB
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    # preprocess image as inception resnet v2 does
    image_np = preprocess_inception(image_np)
    # resize to model input image size
    image_np = cv2.resize(image_np, (299, 299))
    # expand dims to shape [1, 299, 299, 3]
    image_np = np.expand_dims(image_np, 0)
    # load model
    with tf.gfile.GFile('/your/path/to/frozen_graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
    with tf.Session(graph=graph) as sess:
        # get input tensor; the name must match the placeholder name used when the graph was exported
        input_tensor = sess.graph.get_tensor_by_name("input:0")
        output_tensor = sess.graph.get_tensor_by_name("InceptionResnetV2/Logits/Predictions:0")  # get output tensor
        # the Predictions node outputs softmax probabilities with shape [1, num_classes]
        predictions = sess.run(output_tensor, feed_dict={input_tensor: image_np})
        print("Prediction label index:", np.argmax(predictions[0]))
        print("Top 3 prediction label indices:", np.argsort(predictions[0])[::-1][:3])
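
    The values returned by sess.run above are one row of class probabilities per image, so the printed indices still have to be mapped to class names. A minimal sketch of that last step follows. It assumes a list of labels ordered by class index, and both the helper top_k_predictions and the demo values are made up for illustration.

```python
import numpy as np


def top_k_predictions(probabilities, k=3, labels=None):
    """Return [(class_index, label_or_None, probability)] for the k most likely classes."""
    probabilities = np.asarray(probabilities)
    top = np.argsort(probabilities)[::-1][:k]  # indices sorted by descending probability
    return [(int(i), labels[i] if labels is not None else None, float(probabilities[i]))
            for i in top]


# Made-up probabilities and labels, standing in for one row of the sess.run
# output and for a real label file.
demo_probs = [0.05, 0.10, 0.60, 0.25]
demo_labels = ["background", "cat", "dog", "bird"]
for index, label, prob in top_k_predictions(demo_probs, k=3, labels=demo_labels):
    print(index, label, prob)
```

    For a real model, pass the first row of the sess.run result together with the label list that matches the dataset the checkpoint was trained on.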
    

    References:

    1. https://stackoverflow.com/questions/42961243/using-pre-trained-inception-v4-model
    2. https://gist.github.com/cchadowitz-pf/f1c3e781c125813f9976f6e69c06fec2
    3. https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
    4. https://github.com/tensorflow/models/blob/master/slim/README.md
    5. https://gist.github.com/tokestermw/795cc1fd6d0c9069b20204cbd133e36b
  • Original post: https://www.cnblogs.com/arkenstone/p/7551270.html