zoukankan      html  css  js  c++  java
  • TensorFlow模型部署到服务器---TensorFlow2.0

    前言

    ​ 当一个TensorFlow模型训练出来的时候,为了投入到实际应用,所以就需要部署到服务器上。由于我本次所做的项目是一个javaweb的图像识别项目。所有我就想去寻找一下java调用TensorFlow训练模型的办法。

    image-20210801132207199

    由于TensorFlow很久没更新的缘故,网上的博客大都是18/19年的,并且是基于TensorFlow1.0的,对于现在使用的TensorFlow2.0不太友好。

    下面我简述一下TensorFlow1.0时期的方法:

    1.动态模型生成不便

    需要将训练的.h5模型转换成.pb模型,并且需要自己定义.pb模型的输入输出参数。(pb模型是一种基于动态图的模型)

    pb的生成代码冗长、而且对初学者真滴不太友好

    image-20210801132825083

    相比之下.h5模型的生成代码就一行

    image-20210801133250449

    此外,这个生成pb模型的代码是否能照搬使用,还是一个问题,并且还可能报一些奇奇怪怪的错误。

    2.maven导包不便

    查阅资料发现java上的TensorFlow的jar包都是TensorFlow1.0的

    image-20210801144751826

    现状:

    image-20210801144856860

    并且maven官网上的TensorFlow2.0的api已经改名成了tensorflow-core-api,并且网上相关方面的教程十分难找。由于网上都是导入的1.0的包,自己导入2.0的包之后,详细的调用教程可以说是没有。从上面也可以看出来TensorFlow对java的调用也不怎么重视了。所以这又给学习的途中徒增了很多困难。

    全新思路

    思路一

    用java直接调用训练好的模型很困难,那么我们想办法让java调用python脚本,让python脚本去调用.h5模型会不会更简单呢?

    代码如下

    package com.guard.service;
    
    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    
    public class api_service {
    
        public String recognize(String path){
            //此处的path是图片路径
            Process proc;
            String res = null;
            try {
                System.out.println("接受到的参数"+path);
                String[] cmd = new String[] { "python", "E:\machine_learning\predict.py", path};
                proc = Runtime.getRuntime().exec(cmd);
                BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream()));
                String line = null;
                while ((line = in.readLine()) != null) {
                    System.out.println(line);
                    res = line;
                }
                in.close();
                proc.waitFor();
            } catch (IOException e) {
                e.printStackTrace();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(res+">>>>>>>>>>>");
            return res;
        }
    }
    
    

    但是我们可以看出,这个其实是用java在win上跑了这样一个指令

    image-20210801183239294

    虽然这个确实是一个好办法,但是这个路径参数需要事先知道服务器上的路径,并且在协作开发的时候,每个人的路径和环境就不同,虽然该方法能用,但是我认为还不够好。

    思路二

    我们可以直接用python的flask框架,直接生成一个api接口,就可以远程直接调用TensorFlow训练好的模型进行结果预测。

    image-20210801184122679

    image-20210801184226699

    个人认为,这种方法相较于用java调用命令行,这种方法还是更加直观的

    并且flask仅仅需要加个@app.route的注解就能实现,可谓是十分方便

    下面是模型调用代码

    model.py

    import glob
    import sys
    import os
    import cv2
    import numpy as np
    import tensorflow as tf
    import image_processing
    
    def model_ues(path):
        # 缩放图片大小为100*100
        w = 100
        h = 100
    
        # 测试图像的地址 (改为自己的)
    
        # path_test = "resource/test24.jpg"
        api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda"
        path_test = image_processing.download_img(path,api_token)
    
        # 创建保存图像的空列表
        imgs = []
        img = cv2.imread(path_test)
        img = cv2.resize(img, (w, h))
        # 将每张经过处理的图像数据保存在之前创建的imgs空列表当中
        imgs.append(img)
        imgs = np.asarray(imgs, np.float32)
        # print("shape of data:",imgs.shape)
    
        # 导入模型
        model = tf.keras.models.load_model(r"resource/rice_0.93.h5")
        # 创建图像标签列表
        rice_dict = {0: 'Rice blast', 1: 'Rice fleck',
                 2: 'Rice koji disease', 3: 'Sheath blight'}
    
        # 将图像导入模型进行预测
        prediction = model.predict_classes(imgs)
        # prediction = np.argmax(model.predict(imgs), axis=-1)
    
    
        # 绘制预测图像
        for i in range(np.size(prediction)):
            # 打印每张图像的预测结果
            print(rice_dict[prediction[i]])
        return rice_dict[prediction[0]]
    
    

    为了实现图片外链接受,下面是图片下载脚本

    image_processing.py

    # coding: utf8
    import requests
    import random
    
    def download_img(img_url, api_token):
        print (img_url)
        header = {"Authorization": "Bearer " + api_token} # 设置http header,视情况加需要的条目,这里的token是用来鉴权的一种方式
        r = requests.get(img_url, headers=header, stream=True)
        print(r.status_code) # 返回状态码
        file_img = 'resource/img.png'
    
        # file_img = 'resource/'
        print(file_img)
        if r.status_code == 200:
            open(file_img, 'wb').write(r.content) # 将内容写入图片
            print("done")
        del r
    
        return file_img
    # if __name__ == '__main__':
    #     # 下载要的图片
    #     img_url = "https://z3.ax1x.com/2021/07/27/W5l6Qe.png"
    #     api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda"
    #     download_img(img_url, api_token)
    

    主程序脚本

    app.py

    from flask import Flask,render_template, url_for, request, json,jsonify
    import model
    app = Flask(__name__)
    
    #设置编码
    app.config['JSON_AS_ASCII'] = False
    
    @app.route('/test')
    def hello_world():
    
        return "hello world"
    
    @app.route('/predict', methods=['GET', 'POST'])
    def form_data():
        my_path = request.form['path']
        print(my_path)
        str = model.model_ues(my_path)
        print("http://127.0.0.1:5000/predict")
        return jsonify({'result':str,'msg':'200'})
    
    if __name__ == '__main__':
        app.run()
    

    数据解析

    虽然我们能够通过postman进行测试接受到回传的结果,但是我们要怎么用java实现呢??

    1.使用postman生成大致代码框架(postman生成的代码可能不能直接运行)

    image-20210801185332234

    这里我选用的是java-okhttp的方法,但其实使用Unirest写出来的代码更加简洁易懂。

    public class Get_result {
    
        public  String getResult(String path) throws IOException {
    //        String path = "https://i.loli.net/2021/07/29/badDNR2OCironUf.jpg";
            OkHttpClient client = new OkHttpClient().newBuilder()
                    .build();
            MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
            RequestBody body = RequestBody.create(mediaType, "path="+path);
            Request request = new Request.Builder()
                    .url("http://127.0.0.1:8000/predict")
                    .method("POST", body)
                    .addHeader("Content-Type", "application/x-www-form-urlencoded")
                    .build();
            Response response = client.newCall(request).execute();
            String result = response.body().string();
            System.out.println(result);
                }
    }
    
    
    {
      "msg": "200",
      "result": "Rice fleck"
    }
    

    获取到json数据之后,就需要对json数据进行解析

    java上的解析原理是,先按照json编写一个类,之后用Gson对接受到的数据按照这个类进行规范化

    (这里可以用GsonFormatPlus插件来自动生成这个实体类)

    //Rice_result.java---为该json的实体类
    package com.guard.tool;
    
    import lombok.Data;
    import lombok.NoArgsConstructor;
    
    @NoArgsConstructor
    @Data
    public class Rice_result {
        private String msg;
        private String result;
    
    }
    

    下面是数据解析代码(和上面的okhttp获取json数据的代码连起来看)

    //json数据解析
            Gson gson = new Gson();
            java.lang.reflect.Type type = new TypeToken<Rice_result>(){}.getType();
            Rice_result rice_result = gson.fromJson(result, type);
            System.out.println(rice_result);
            if("200".equals(rice_result.getMsg())){
    //            System.out.println(rice_result.getResult());
                return Rice_result.convertdata(rice_result.getResult());
            }else {
    //            System.out.println("获取结果出错!!");
                return "获取结果出错!!";
            }
    

    这样的话就可以进行json数据的解析了。

    图链制作

    由于需要使用java发送post请求给flask的预测端口,那么就需要把本地上传的数据做成图链,把图链作为数据传给flask的预测端口,从而来接收结果。

    由于前端js的知识大多遗忘,这里就选用了用java来发送一个post请求,获得回传的信息。

    这里我使用的是sm.ms的图床(该图床无需登录,且速度快,算得上是一个好的选择)

    //sm.ms的使用方法,建议看官方文档
    package com.guard.tool;
    
    import com.google.gson.Gson;
    import com.google.gson.reflect.TypeToken;
    import okhttp3.*;
    
    import java.io.File;
    import java.io.IOException;
    
    
    public class CloudUpload {
    
      public String toUrl(String path) throws IOException {
    
    //    String file_path = "E:/machine_learning/test8.jpg";
    
        String file_path = path;
        OkHttpClient client = new OkHttpClient().newBuilder()
                .build();
        MediaType mediaType = MediaType.parse("multipart/form-data");
        RequestBody body = new MultipartBody.Builder().setType(MultipartBody.FORM)
                .addFormDataPart("smfile",file_path,
                        RequestBody.create(MediaType.parse("application/octet-stream"),
                                new File(file_path)))
                .addFormDataPart("format","json")
                .build();
        Request request = new Request.Builder()
                .url("https://sm.ms/api/v2/upload")
                .method("POST", body)
                .addHeader("Content-Type", "multipart/form-data")
                .addHeader("Authorization", "TlxzRSaVJj0o7HFZOd9sgdf4Jl60RA00")
       //这里的user-agent和Cookie需要自己打开网站,到网站的页面去拿取
                .addHeader("user-agent","Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36")
                .addHeader("Cookie", "SMMSrememberme=42417%3A10e8e9cb5281082b493fdee73381aeb2dca0bd3d; PHPSESSID=1gjog2em3ogof23vrqi79vd41m; SM_FC=runWNk3mPIiL8mzl%2FrlEfzM940LRKjLm182cm2qDrm4%3D")
                .build();
        Response response = client.newCall(request).execute();
        String result = response.body().string();
        System.out.println(result);
    //    String result = response.body().string();
    
        Gson gson = new Gson();
        java.lang.reflect.Type type = new TypeToken<Image_data>(){}.getType();
        Image_data imge_data = gson.fromJson(result, type);
        System.out.println(imge_data);
        if (imge_data.getSuccess()){
          System.out.println(imge_data.getData().getUrl());
          return imge_data.getData().getUrl();
        }
        else{
          System.out.println("图片已经上传过一次!!");
          System.out.println(imge_data.getImages());
          return imge_data.getImages();
        }
      }
    }
    
    

    回传的json结果--这个就需要使用上面的插件来进行处理

    {
        "success": true,
        "code": "success",
        "message": "Upload success.",
        "data": {
            "file_id": 0,
            "width": 192,
            "height": 454,
            "filename": "test25.jpg",
            "storename": "xICPNzFsfth5uJk.png",
            "size": 124993,
            "path": "/2021/08/01/xICPNzFsfth5uJk.png",
            "hash": "2exIdQGvBru46RKMyNjg3DhCTO",
            "url": "https://i.loli.net/2021/08/01/xICPNzFsfth5uJk.png",
            "delete": "https://sm.ms/delete/2exIdQGvBru46RKMyNjg3DhCTO",
            "page": "https://sm.ms/image/xICPNzFsfth5uJk"
        },
        "RequestId": "9BFE9DEB-8370-44C8-A8AF-AAB2DB753A18"
    }
    

    总结

    以上就是我这次在小组编写<基于CNN图像分类的水稻病虫害识别>这个项目中的收获。在此记录下学习路上踩过的一些坑和一些解决方法。

  • 相关阅读:
    使用SAEPython在虾米网自动签到
    Python的SimpleHTTPServer
    人人控 40行python搭出来的远程控制程序 支持插件
    吐血解决python中文写入文件问题
    JavaScript 响应选中文字并获取
    对WPS的吐槽
    Powerful Sleep 笔记[如何睡得好]
    Python极轻量HTTP服务器&框架 Bottle
    打印二维数组
    电梯的测试用例
  • 原文地址:https://www.cnblogs.com/printwangzhe/p/15087444.html
Copyright © 2011-2022 走看看