  • TensorFlow: from training a custom CNN model to deploying a tflite model on Android

    There are many tutorials online about deploying TensorFlow Lite on Android, but most of them only cover how to deploy an already-trained model and say nothing about how to train it. In practice, deployment requires knowing the preprocessing details used during training, and missing them is why self-trained models run into all kinds of problems on Android. This post therefore records the details of the whole pipeline: training on the PC, exporting the model, and deploying it on Android. Discussion and corrections are welcome.

    PC operating system: Ubuntu 14

    TensorFlow version: TensorFlow 1.14

    Android version: 9.0

    Training on the PC

    Dataset: custom generated

    Training framework: tensorflow slim. How to install tensorflow slim is not covered here; please look it up yourself.
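
    For reference, in TensorFlow 1.14 the slim API ships bundled as tf.contrib.slim (the separate image-classification model library lives in the tensorflow/models repository). A quick check that it is importable:

    # sanity check: TF 1.14 bundles slim under tf.contrib
    import tensorflow as tf

    slim = tf.contrib.slim
    print(tf.__version__)  # expect 1.14.x
    print(slim.conv2d)     # should print the slim conv2d function without raising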

    Data generation code: generate 50,000 28*28 three-channel captcha images split into 10 classes (digits 0-9); the generated data is saved under datasets/images/.

    # -*- coding: utf-8 -*-
    
    import cv2
    import numpy as np
    
    from captcha.image import ImageCaptcha
    
    
    def generate_captcha(text='1'):
        """Generate a digit image."""
        capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
        image = capt.generate_image(text)
        image = np.array(image, dtype=np.uint8)
        return image
        
        
    if __name__ == '__main__':
        output_dir = './datasets/images/'
        for i in range(50000):
            label = np.random.randint(0, 10)
            image = generate_captcha(str(label))
            image_name = 'image{}_{}.jpg'.format(i+1, label)
            output_path = output_dir + image_name
            cv2.imwrite(output_path, image)
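
    As a quick sanity check (my addition, not part of the original post), you can load one generated file back and recover its label from the file name, which is the convention the training script below relies on:

    import glob
    import cv2

    # pick one generated captcha and parse its label out of the imagexxx_label.jpg name
    sample = sorted(glob.glob('./datasets/images/*.jpg'))[0]
    img = cv2.imread(sample)
    label = int(sample.split('_')[-1].split('.')[0])
    print(img.shape, label)  # expected: (28, 28, 3) and an integer in 0-9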

    Training: I built a seven-layer convolutional network with tensorflow slim. The final test accuracy is about 96%~99% and the model is 1.2 MB, which makes it suitable for mobile deployment. During training I did two things:

    1. Named the model's input and output nodes explicitly. These names are needed when testing the model on the PC, and they also make it easy to confirm exactly what format the model's output data has, which the mobile code must match.

    inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
    .......
    .......
    prob_ = tf.identity(prob, name='prob')

    2. Saved the model directly in PB format when training finished.

            constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['inputs','prob'])  # save the model as a PB file as soon as training finishes
            with tf.gfile.FastGFile('model3.pb', mode='wb') as f:  # the model file is model3.pb
                f.write(constant_graph.SerializeToString())

    The training code is as follows:

    # -*- coding: utf-8 -*-
    
    """Train a CNN model to classifying 10 digits.
    
    Example Usage:
    ---------------
    python3 train.py 
        --images_path: Path to the training images (directory).
        --model_output_path: Path to model.ckpt.
    """
    
    import cv2
    import glob
    import numpy as np
    import os
    import tensorflow as tf
    
    import model
    from tensorflow.python.framework import graph_util
    
    flags = tf.app.flags
    
    flags.DEFINE_string('images_path', None, 'Path to training images.')
    flags.DEFINE_string('model_output_path', None, 'Path to model checkpoint.')
    FLAGS = flags.FLAGS
    
    
    def get_train_data(images_path):
        """Get the training images from images_path.
        
        Args: 
            images_path: Path to training images.
            
        Returns:
            images: A list of images.
            labels: A list of integers representing the classes of images.
            
        Raises:
            ValueError: If images_path does not exist.
        """
        if not os.path.exists(images_path):
            raise ValueError('images_path does not exist.')
            
        images = []
        labels = []
        images_path = os.path.join(images_path, '*.jpg')
        count = 0
        for image_file in glob.glob(images_path):
            count += 1
            if count % 100 == 0:
                print('Load {} images.'.format(count))
            image = cv2.imread(image_file)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Assume the name of each image is imagexxx_label.jpg
            label = float(image_file.split('_')[-1].split('.')[0])
            images.append(image)
            labels.append(label)
        images = np.array(images)
        labels = np.array(labels)
        return images, labels
    
    
    def next_batch_set(images, labels, batch_size=128):
        """Generate a batch training data.
        
        Args:
            images: A 4-D array representing the training images.
            labels: A 1-D array representing the classes of images.
            batch_size: An integer.
            
        Return:
            batch_images: A batch of images.
            batch_labels: A batch of labels.
        """
        indices = np.random.choice(len(images), batch_size)
        batch_images = images[indices]
        batch_labels = labels[indices]
        return batch_images, batch_labels
    
    
    def main(_):
        inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
        labels = tf.placeholder(tf.int32, shape=[None], name='labels')
        
        cls_model = model.Model(is_training=True, num_classes=10)
        preprocessed_inputs = cls_model.preprocess(inputs)  # preprocessing (mean subtraction and normalization, baked into the graph)
        prediction_dict = cls_model.predict(preprocessed_inputs)
        loss_dict = cls_model.loss(prediction_dict, labels)
        loss = loss_dict['loss']
        postprocessed_dict = cls_model.postprocess(prediction_dict)
        classes = postprocessed_dict['classes']
        prob = postprocessed_dict['prob']
        classes_ = tf.identity(classes, name='classes')
        prob_ = tf.identity(prob, name='prob')
        acc = tf.reduce_mean(tf.cast(tf.equal(classes, labels), 'float'))
        
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(0.05, global_step, 150, 0.9)
        
        optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
        train_step = optimizer.minimize(loss, global_step)
        
        saver = tf.train.Saver()
        
        images, targets = get_train_data(FLAGS.images_path)
        
        init = tf.global_variables_initializer()
        
        with tf.Session() as sess:
            sess.run(init)
            
            for i in range(6000):
                batch_images, batch_labels = next_batch_set(images, targets)
                train_dict = {inputs: batch_images, labels: batch_labels}
                
                sess.run(train_step, feed_dict=train_dict)
                
                loss_, acc_,prob__,classes__ = sess.run([loss, acc, prob_,classes_], feed_dict=train_dict)
                
                train_text = 'step: {}, loss: {}, acc: {},classes:{}'.format(
                    i+1, loss_, acc_,classes__)
                print(train_text)
                print (prob__)
            saver.save(sess, FLAGS.model_output_path)
            constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['inputs','prob'])  # save the model as a PB file as soon as training finishes
            with tf.gfile.FastGFile('model3.pb', mode='wb') as f:  # the model file is model3.pb
                f.write(constant_graph.SerializeToString())


    if __name__ == '__main__':
        tf.app.run()

    Pay special attention to whether the images were preprocessed during training, for example mean subtraction and division for normalization, because the mobile side must perform exactly the same operations. In my case the training-time preprocessing included mean subtraction and normalization, and these two ops were packed directly into the model; that is, image data entering the model is preprocessed first and only then goes through the convolutions and the rest of the network. As a result, the mobile side does not need any separate preprocessing code. In many other cases the exported model does not include the preprocessing ops, and the mobile side then has to add a few lines of mean-subtraction and normalization code before feeding the data to the classification model.
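
    The post never shows model.py, so below is a minimal sketch of what a slim-based Model class with the preprocessing packed into the graph could look like. Everything in it is an assumption for illustration (a small two-conv-block network and a (x - 128) / 128 scaling that mirrors the commented-out Android code further down), not the author's actual seven-layer model:

    # model.py -- hypothetical sketch, not the original implementation
    import tensorflow as tf

    slim = tf.contrib.slim


    class Model(object):
        def __init__(self, is_training, num_classes):
            self._is_training = is_training
            self._num_classes = num_classes

        def preprocess(self, inputs):
            # Mean subtraction and scaling live inside the graph, so the Android
            # side can feed raw 0-255 pixel values without extra code.
            return (inputs - 128.0) / 128.0

        def predict(self, preprocessed_inputs):
            with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu):
                net = slim.repeat(preprocessed_inputs, 2, slim.conv2d, 32, [3, 3])
                net = slim.max_pool2d(net, [2, 2])
                net = slim.flatten(net)
                logits = slim.fully_connected(net, self._num_classes,
                                              activation_fn=None)
            return {'logits': logits}

        def postprocess(self, prediction_dict):
            logits = prediction_dict['logits']
            prob = tf.nn.softmax(logits)
            classes = tf.cast(tf.argmax(prob, axis=1), tf.int32)
            return {'classes': classes, 'prob': prob}

        def loss(self, prediction_dict, labels):
            logits = prediction_dict['logits']
            loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits))
            return {'loss': loss}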

    There is also another way to export a ckpt model to a pb model; the code is as follows:

    import tensorflow as tf
    from tensorflow.python.framework import graph_util


    def freeze_graph(input_checkpoint, output_graph):
        '''
        :param input_checkpoint: path to the ckpt checkpoint
        :param output_graph: path where the PB model will be saved
        :return:
        '''
        # checkpoint = tf.train.get_checkpoint_state(model_folder)  # check whether the ckpt files in the directory are usable
        # input_checkpoint = checkpoint.model_checkpoint_path  # get the ckpt file path

        # Specify the output node names; they must be nodes that exist in the original model
        #input_node_names = "inputs"
        output_node_names = "inputs,classes"
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
        graph = tf.get_default_graph()  # get the default graph
        input_graph_def = graph.as_graph_def()  # a serialized GraphDef representing the current graph

        with tf.Session() as sess:
            saver.restore(sess, input_checkpoint)  # restore the graph and its weights
            output_graph_def = graph_util.convert_variables_to_constants(  # freeze the variables into constants
                sess=sess,
                input_graph_def=input_graph_def,  # equivalent to sess.graph_def
                output_node_names=output_node_names.split(","))  # separate multiple output nodes with commas

            with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
                f.write(output_graph_def.SerializeToString())  # serialize and write to disk
            print("%d ops in the final graph." % len(output_graph_def.node))  # number of op nodes in the final graph

            # for op in graph.get_operations():
            #     print(op.name, op.values())


    # path to the input ckpt model
    input_checkpoint = 'model/model.ckpt'
    # path for the output pb model
    out_pb_path = "frozen_model.pb"
    # call freeze_graph to convert the ckpt to pb
    freeze_graph(input_checkpoint, out_pb_path)

    Code to convert the PB model to tflite format:

    import tensorflow as tf

    # change this to the path of your own pb file
    path = "model2.pb"

    # if you do not know your model's input/output node names, visualize the
    # graph with tensorboard; the input and output node names are shown there
    inputs = ["inputs"]
    outputs = ["prob"]
    # convert the pb model to a tflite model
    converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs, outputs)
    #converter.post_training_quantize = True
    tflite_model = converter.convert()
    open("model3.tflite", "wb").write(tflite_model)

    There is also another way to export the model to tflite, using bazel:

    Go into the tensorflow source directory and build the two tools:

    bazel build tensorflow/python/tools:freeze_graph
    bazel build tensorflow/lite/toco:toco

    ./bazel-bin/tensorflow/lite/toco/toco \
    --input_file=/media/bayes/69da5b29-ae56-4feb-93a1-2ce24323aa78/project/model2.pb \
    --output_file=/media/bayes/69da5b29-ae56-4feb-93a1-2ce24323aa78/project/model2.tflite \
    --input_format=TENSORFLOW_GRAPHDEF \
    --output_format=TFLITE \
    --inference_type=FLOAT \
    --input_shape=1,28,28,3 \
    --input_array=inputs \
    --output_array=prob

    Testing model accuracy with the PB model

    # -*- coding: utf-8 -*-
    
    """Evaluate the trained CNN model.
    Example Usage:
    ---------------
    python3 infrence_pb.py 
        --frozen_graph_path: Path to model frozen graph.
    """
    
    import numpy as np
    import tensorflow as tf
    
    from captcha.image import ImageCaptcha
    
    flags = tf.app.flags
    flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
    FLAGS = flags.FLAGS
    
    
    def generate_captcha(text='1'):
        capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
        image = capt.generate_image(text)
        image = np.array(image, dtype=np.uint8)
        return image
    
    
    def main(_):
        model_graph = tf.Graph()
        with model_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        
        with model_graph.as_default():
            with tf.Session(graph=model_graph) as sess:
                inputs = model_graph.get_tensor_by_name('inputs:0')
                classes = model_graph.get_tensor_by_name('classes:0')
                prob = model_graph.get_tensor_by_name('prob:0')
                for i in range(10):
                    label = np.random.randint(0, 10)
                    image = generate_captcha(str(label))
                    # preprocessing is baked into the graph, so the raw image is fed directly
                    image_np = np.expand_dims(image, axis=0)
                    predicted_label,probs = sess.run([classes,prob], 
                                               feed_dict={inputs: image_np})
                    print(predicted_label, ' vs ', label)
                    print(probs)
                
                
    if __name__ == '__main__':
        tf.app.run()

    Testing model accuracy with the tflite model

    # -*- coding:utf-8 -*-
    import os
    import cv2
    import numpy as np
    import time
    
    import tensorflow as tf
    
    test_image_dir = './test_images/'
    #model_path = "./model/quantize_frozen_graph.tflite"
    model_path = "./model3.tflite"
    
    # Load TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    
    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    print(str(input_details))
    output_details = interpreter.get_output_details()
    print(str(output_details))
    
    #with tf.Session( ) as sess:
    if 1:
        file_list = os.listdir(test_image_dir)
        
        model_interpreter_time = 0
        start_time = time.time()
        # iterate over the test images
        for file in file_list:
            print('=========================')
            full_path = os.path.join(test_image_dir, file)
            print('full_path:{}'.format(full_path))
            
    
            img = cv2.imread(full_path )
            res_img = cv2.resize(img,(28,28),interpolation=cv2.INTER_CUBIC) 
            # flatten into a length-784 vector (only needed for a fully-connected input, not used here)
            #new_img = res_img.reshape((784))
            new_img = np.array(res_img, dtype=np.uint8)
            # add a batch dimension: the tensor becomes [1, 28, 28, 3]
            image_np_expanded = np.expand_dims(new_img, axis=0)
            image_np_expanded = image_np_expanded.astype('float32')  # the dtype must match the model's input
            
            # load the input tensor
            model_interpreter_start_time = time.time()
            interpreter.set_tensor(input_details[0]['index'], image_np_expanded)
            
            # run the model
            interpreter.invoke()
            output_data = interpreter.get_tensor(output_details[0]['index'])
            model_interpreter_time += time.time() - model_interpreter_start_time
            
            # squeeze the batch dimension off the result
            result = np.squeeze(output_data)
            print('result:{}'.format(result))
            #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))
            
            # the output is a length-10 vector (digits 0-9); the index of the max value is the predicted digit
            #print('result:{}'.format( (np.where(result==np.max(result)))[0][0]  ))
        used_time = time.time() - start_time
        print('used_time:{}'.format(used_time))
        print('model_interpreter_time:{}'.format(model_interpreter_time))
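
    The loop above only prints the raw probabilities. A minimal extension (my own sketch, assuming the test images follow the same imagexxx_label.jpg naming convention as the training data) that turns it into an accuracy measurement:

    import os

    import cv2
    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path='./model3.tflite')
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    correct, total = 0, 0
    for name in os.listdir('./test_images/'):
        if not name.endswith('.jpg'):
            continue
        label = int(name.split('_')[-1].split('.')[0])  # label encoded in the file name
        img = cv2.imread(os.path.join('./test_images/', name))
        # note: train.py converted BGR to RGB after cv2.imread; add
        # cv2.cvtColor(img, cv2.COLOR_BGR2RGB) here if your model expects RGB
        img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
        batch = np.expand_dims(img, axis=0).astype('float32')
        interpreter.set_tensor(input_details[0]['index'], batch)
        interpreter.invoke()
        probs = np.squeeze(interpreter.get_tensor(output_details[0]['index']))
        correct += int(np.argmax(probs) == label)
        total += 1
    print('accuracy: {}/{}'.format(correct, total))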

    Once the model is trained, the next step is deploying it to Android. This step is actually simple: just replace the corresponding parts of the Android code. I will upload the Android code to CSDN; download it if you need it. So which parts need changing?

    // input shape of the model
    private int[] ddims = {1, 3, 28, 28};
    // model names
    private static final String[] PADDLE_MODEL = {
            "model3",
            "mobilenet_quant_v1_224",
            "mobilenet_v1_1.0_224",
            "mobilenet_v2"
    };

    // name of the label file
    BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel1.txt")));
    // data type and shape of the model output, which can be seen clearly on the PC side
    float[][] labelProbArray = new float[1][10];

    // whether input preprocessing is already included in the model
    //  imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));
    //  imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));
    //  imgData.putFloat((((val & 0xFF) - 128f) / 128f));
    imgData.putFloat(((val >> 16) & 0xFF));
    imgData.putFloat(((val >> 8) & 0xFF));
    imgData.putFloat((val & 0xFF));
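
    To make it concrete what those Java lines feed the model, here is a small Python sketch (my addition) that unpacks a packed ARGB pixel the same way, showing that the raw 0-255 channel values go in unchanged, exactly like the Python tflite test script above:

    def unpack_argb(val):
        """Return (R, G, B) floats from a packed 0xAARRGGBB integer."""
        r = float((val >> 16) & 0xFF)
        g = float((val >> 8) & 0xFF)
        b = float(val & 0xFF)
        return r, g, b


    if __name__ == '__main__':
        pixel = 0xFF204080           # A=255, R=0x20, G=0x40, B=0x80
        print(unpack_argb(pixel))    # (32.0, 64.0, 128.0) -- no normalization applied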

    Here is a test image you can use for testing; the correct result should be 0.0. The Android code address is here; to download from CSDN, search for anquangan.

    Code for inspecting the PB model's nodes:

    #coding:utf-8
     
    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    tf.reset_default_graph()  # reset the default graph
    output_graph_path = 'model3.pb'
    with tf.Session() as sess:
     
        tf.global_variables_initializer().run()
        output_graph_def = tf.GraphDef()
        # get the default graph
        graph = tf.get_default_graph()
        with open(output_graph_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")
            # number of op nodes in the graph
            print("%d ops in the final graph." % len(output_graph_def.node))
     
            tensor_name = [tensor.name for tensor in output_graph_def.node]
            print(tensor_name)
            print('---------------------------')
            # write an event file under log_graph/ so the model can be visualized in tensorboard
            #summaryWriter = tf.summary.FileWriter('log_graph/', graph)
     
     
            for op in graph.get_operations():
            # print each op's name and output tensors
                print(op.name, op.values())
  • Original post: https://www.cnblogs.com/cnugis/p/11910278.html