zoukankan      html  css  js  c++  java
  • Python+Android进行TensorFlow开发

    Tensorflow是Google开源的一套机器学习框架,支持GPU、CPU、Android等多种计算平台。本文将介绍在Tensorflow在Android上的使用。

    Android使用Tensorflow框架需要引入两个文件libtensorflow_inference.so、libandroid_tensorflow_inference_java.jar。这两个文件可以使用官方预编译的文件。如果预编译的so不满足要求(比如不支持训练模型中的某些操作符运算),也可以自己通过bazel编译生成这两个文件。
    将libandroid_tensorflow_inference_java.jar放在app下的libs目录下,so文件命名为libtensorflow_jni.so放在src/main/jniLibs目录下对应的ABI文件夹下。目录结构如下:
    在这里插入图片描述
    Android目录结构

    同时在app的build.gradle中的dependencies模块下添加如下配置:

    dependencies {
        ...
        compile files('libs/libandroid_tensorflow_inference_java.jar')
        ...
    }
    

    使用tensorflow框架进行机器学习分为四个步骤:

    • 构造神经网络

    • 训练神经网络模型

    • 将训练好的模型输出为pb文件

    • ndroid上加载pb模型进行计算

    前三步是模型的构造,我们通过python实现,下面给出了一个二分类的简单模型的构造过程,首先是训练过程:

    # -*-coding:utf-8 -*-
    from __future__ import print_function
    import os
    import tensorflow as tf
    from numpy.random import RandomState
    
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    """
    训练模型
    """
    def train():
        # 定义训练数据集batch大小为8
        batch_size = 8
    
        # 定义神经网络参数,参数体现出神经网络结构,一个输入层,一个输出层,一个隐藏层
        w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1), name="w1_val")
        w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1), name="w2_val")
    
        # 定义输入输出格式
        x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input')
        y_ = tf.placeholder(tf.float32, shape=(None, 1))
    
        # 定义神经网络前向传播过程
        a = tf.matmul(x, w1)
        y = tf.matmul(a, w2, name="cal_node")
    
        # 定义交叉熵和反向传播算法
        cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
        train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy)
    
        # 生成随机训练集
        rdm = RandomState(1)
        dataset_size = 128
    
        # 定义映射关系
        X = rdm.rand(dataset_size, 2)
        Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]
    
        with tf.Session() as sess:
            # 初始化所有参数
            init_op = tf.global_variables_initializer()
            sess.run(init_op)
    
            # print sess.run(w1)
            # print sess.run(w2)
    
            STEPS = 500
            for i in range(STEPS):
                start = (i * batch_size) % dataset_size
                end = min(start + batch_size, dataset_size)
    
                # 训练神经网络,更新神经网络参数
                sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
    
                if i % 100 == 0:
                    total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
                    print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy))
    
                print(sess.run(w1))
                print(sess.run(w2))
    
            # 保存check point
            saver = tf.train.Saver(tf.trainable_variables())
            saver.save(sess, './model/checpt')
    

    上面的代码首先定义神经网络,初始化训练数据,进行500次训练过程,并将训练结果checkpoints保存到model文件夹下,checkpoints包含了训练模型得到的参数信息,共生成四个相关的文件,如下图:
    checkpoint相关文件

    由于checkpoint文件众多,为了方便使用,我们通过下面的代码将它们生成一个pb文件,在android上只需要这个pb文件即可使用这个训练好的模型:

    """
    存储pb模型
    """
    def dump_graph_to_pb(pb_path):
        with tf.Session() as sess:
            check_point = tf.train.get_checkpoint_state("./model/")
            if check_point:
                saver = tf.train.import_meta_graph(check_point.model_checkpoint_path + '.meta')
                saver.restore(sess, check_point.model_checkpoint_path)
            else:
                raise ValueError("Model load failed from {}".format(check_point.model_checkpoint_path))
    
            graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(","))
    
            with tf.gfile.GFile(pb_path, "wb") as f:
                f.write(graph_def.SerializeToString())
    

    拿到生成的pb模型,我们可以在android上使用了。将pb文件在这main/assets下:
    在这里插入图片描述

    接下来就可以载入pb,进行计算了:

    public class MainActivity extends AppCompatActivity {
        private Graph graph_;
        private Session session_;
        private AssetManager assetManager;
    
        private static ExecutorService executorService;
        private static Handler handler;
        @Override
        protected void onCreate(Bundle savedInstanceState) {
            super.onCreate(savedInstanceState);
            setContentView(R.layout.activity_main);
    
            executorService = Executors.newFixedThreadPool(5);
    
            // 初始化tensorflow
            initTensorFlow("outmodel.pb");
    
            // 使用tensorflow进行计算
            runTensorFlow();
        }
        ...
    }
    

    通过如下方式载入pb模型,初始化tensorflow:

    private boolean initTensorFlow(String modelFile) {
            assetManager = getAssets();
            // 新建Graph
            graph_ = new Graph();
    
            InputStream is = null;
            try {
                // 读取Assets pb文件
                is = assetManager.open(modelFile);
            } catch (IOException e) {
                e.printStackTrace();
                return false;
            }
    
            try {
                // 加载pb到Graph
                TensorUtil.loadGraph(is, graph_);
                is.close();
            } catch (IOException e) {
                e.printStackTrace();
                return false;
            }
            // 初始化session
            session_ = new Session(graph_);
            if (session_ == null) {
                return false;
            }
    
            return true;
        }
    

    然后就可以使用tensorflow API进行运算了:

    private void runTensorFlow() {
            executorService.execute(generatePredictRunnable(handler));
        }
    
        private Runnable generatePredictRunnable(Handler handler) {
            return new Runnable() {
                @Override
                public void run() {
                    float[][] input = new float[1][2];
    
                    input[0][0] = 1;
                    input[0][1] = 2;
    
                    // 定义输入tensor
                    Tensor inputTensor = Tensor.create(input);
    
                    // 指定输入,输出节点,运行并得到结果
                    Tensor resultTensor = session_.runner()
                            .feed("x_input", inputTensor)
                            .fetch("cal_node")
                            .run()
                            .get(0);
    
                    float[][] dst = new float[1][1];
                    resultTensor.copyTo(dst);
    
                    // 处理结果
                    ArrayList<Float> resultList = new ArrayList<>();
                    for (float val : dst[0]) {
                        if (val != 0) {
                            resultList.add(val);
                        } else {
                            break;
                        }
                    }
                }
            };
        }
    

    上面就是通过python训练机器学习模型,并在android平台进行调用的完整流程。

    原创作者:JackMeGo,原文链接:https://www.jianshu.com/p/eef4ab014a12

    欢迎关注我的微信公众号「码农突围」,分享Python、Java、大数据、机器学习、人工智能等技术,关注码农技术提升•职场突围•思维跃迁,20万+码农成长充电第一站,陪有梦想的你一起成长。

  • 相关阅读:
    call()和apply( )
    String.prototype.replace( )
    Global对象和浏览器的window对象
    ros qt 項目增加新的线程
    ubuntu18.04 在QT中添加ros环境搭建 亲测可用
    ubuntu18.04系统下安装Nvidia驱动 + cuda10.0 + cudnn7
    【ROS学习】发布自定义数据结构的话题
    Autoware快速使用资料
    TX2-ubuntu无外接显示器远程桌面时分辨率过低
    Jetson TX2 安装 远程桌面软件 NoMachine
  • 原文地址:https://www.cnblogs.com/hejunlin/p/12507132.html
Copyright © 2011-2022 走看看