zoukankan      html  css  js  c++  java
  • MindSpore模型推理

    MindSpore模型推理

    如果想在应用中使用自定义的MindSpore Lite模型,需要告知推理器模型所在的位置。推理器加载模型的方式有以下三种:

    • 加载本地模型。
    • 加载远程模型。
    • 混合加载本地和远程模型。

    加载模型

    方式一:加载并初始化本地模型。

    1. 加载模型。
    • Assets目录
    1. MLCustomLocalModel localModel = new MLCustomLocalModel.Factory("yourmodelname")
    2.         .setAssetPathFile("assetpathname")
    3.         .create();
    • 自定义目录

     .         MLCustomLocalModel localModel = new MLCustomLocalModel.Factory("yourmodelname")

    1.         .setLocalFullPathFile("sdfullpathname")
    2.         .create();
    3. 根据模型创建推理器。
    4. final MLModelExecutorSettings settings = new MLModelExecutorSettings.Factory(localModel).create();
    5. final MLModelExecutor modelExecutor = MLModelExecutor.getInstance(settings);
    6. // 调用模型推理,实现细节见下节模型推理;Bitmap待处理的图片。
    7. executorImpl(modelExecutor, bitmap);

    方式二:加载并初始化远程模型。

    加载远程模型时需先判断远程模型是否已经下载完成:

    1.     final MLCustomRemoteModel remoteModel =new MLCustomRemoteModel.Factory("yourmodelname")
    2.         .create();
    3.     MLLocalModelManager.getInstance()
    4.         .isModelExist(remoteModel)
    5.         .addOnSuccessListener(new OnSuccessListener<Boolean>() {
    6.             @Override
    7.             public void onSuccess(Boolean isDownload) {
    8.                 if (isDownload) {
    9.                     final MLModelExecutorSettings settings =
    10. 10.                             new MLModelExecutorSettings.Factory(remoteModel).create();
    11. 11.                     final MLModelExecutor modelExecutor = MLModelExecutor.getInstance(settings);
    12. 12.                     executorImpl(modelExecutor, bitmap);
    13. 13.                 }
    14. 14.             }
    15. 15.         })
    16. 16.         .addOnFailureListener(new OnFailureListener() {
    17. 17.             @Override
    18. 18.             public void onFailure(Exception e) {
    19. 19.                 // 异常处理。
    20. 20.             }
    21. 21.         });

    方式三:混合加载本地和远程模型。推荐使用这种方式,此方法可以确保当远程模型未下载时加载本地模型。

    1.     localModel = new MLCustomLocalModel.Factory("localModelName")
    2.         .setAssetPathFile("assetpathname")
    3.         .create();
    4.     remoteModel =new MLCustomRemoteModel.Factory("yourremotemodelname").create();
    5.     MLLocalModelManager.getInstance()
    6.         // 判断远程模型是否存在。
    7.         .isModelExist(remoteModel)
    8.         .addOnSuccessListener(new OnSuccessListener<Boolean>() {
    9.             @Override
    10. 10.             public void onSuccess(Boolean isDownloaded) {
    11. 11.                 MLModelExecutorSettings settings;
    12. 12.                 // 如果远程模型存在,优先加载本地已有的远程模型,否则加载本地已有的本地模型。
    13. 13.                 if (isDownloaded) {
    14. 14.                     settings = new MLModelExecutorSettings.Factory(remoteModel).create();
    15. 15.                 } else {
    16. 16.                     settings = new MLModelExecutorSettings.Factory(localModel).create();
    17. 17.                 }
    18. 18.                 final MLModelExecutor modelExecutor = MLModelExecutor.getInstance(settings);
    19. 19.                 executorImpl(modelExecutor, bitmap);
    20. 20.             }
    21. 21.         })
    22. 22.         .addOnFailureListener(new OnFailureListener() {
    23. 23.             @Override
    24. 24.             public void onFailure(Exception e) {
    25. 25.                 // 异常处理。
    26. 26.             }
    27. 27.         });

    模型推理器进行推理

    本章示例中的“executorImpl”方法为模型推理的详细流程,声明如下:

    1. void executorImpl(final MLModelExecutor modelExecutor, Bitmap bitmap)

    以下示例会借助“executorImpl”方法详细演示推理器调用的自定义模型推理的整个过程,此方法内主要包含如下关键处理流程:

    1. 设置输入输出格式。

    需要知道模型的输入输出格式。通过MLModelInputOutputSettings.Factory把输入输出格式设置到模型推理器。比如,一个图片分类模型的输入格式为一个float类型的1x224x224x3数组(表示只推理一张大小为224x224的三通道 (RGB)图片),输出格式为一个长度为1001的float型列表(每个值表示该图片经模型推理后1001个类别中各类别的可能性)。对于此模型,请按照以下方式设置输入输出格式:

    1. inOutSettings = new MLModelInputOutputSettings.Factory()
    2.     .setInputFormat(0, MLModelDataType.FLOAT32, new int[] {1, 224, 224, 3})
    3.     .setOutputFormat(0, MLModelDataType.FLOAT32, new int[] {1, 1001})
    4.     .create();
    5. 把图片数据输入到推理器。

    注意

    当前版本MindSpore生成的模型使用的数据格式与tflite类型的模型使用的数据格式相同,均为NHWC,caffe类型的模型使用的数据格式为NCHW。若需要将模型由caffe转换到MindSpore,请设置为NHWC格式。如下NHWC示例:1*224*224*3表示一张(batch N),大小为224(height H)*224(width W),3通道(channels C)的图片。

    1. private void executorImpl(final MLModelExecutor modelExecutor, Bitmap bitmap){
    2.     // 准备输入数据。
    3.     final Bitmap inputBitmap = Bitmap.createScaledBitmap(srcBitmap, 224, 224, true);
    4.     final float[][][][] input = new float[1][224][224][3];
    5.     for (int i = 0; i < 224; i++) {
    6.         for (int j = 0; j < 224; j++) {
    7.             int pixel = inputBitmap.getPixel(i, j);
    8.             input[batchNum][j][i][0] = (Color.red(pixel) - 127) / 128.0f;
    9.             input[batchNum][j][i][1] = (Color.green(pixel) - 127) / 128.0f;
    10. 10.             input[batchNum][j][i][2] = (Color.blue(pixel) - 127) / 128.0f;
    11. 11.         }
    12. 12.     }
    13. 13.     MLModelInputs inputs = null;
    14. 14.     try {
    15. 15.         inputs = new MLModelInputs.Factory().add(input).create();
    16. 16.         // 若模型需要多路输入,需要多次调用add()以便图片数据能够一次输入到推理器。
    17. 17.     } catch (MLException e) {
    18. 18.         // 处理输入数据格式化异常。
    19. 19.     }
    20. 20.  

    21. // 执行推理。可以通过“addOnSuccessListener”来监听推理成功,在“onSuccess”回调中处理推理成功。同时,可以通过“addOnFailureListener”来监听推理失败,在“onFailure”中处理推理失败。

    1. 22.     modelExecutor.exec(inputs, inOutSettings).addOnSuccessListener(new OnSuccessListener<MLModelOutputs>() {
    2. 23.         @Override
    3. 24.         public void onSuccess(MLModelOutputs mlModelOutputs) {
    4. 25.             float[][] output = mlModelOutputs.getOutput(0);
    5. 26.                 // 这里推理的返回结果在output数组里,可以进一步处理。
    6. 27.                 }
    7. 28.         }).addOnFailureListener(new OnFailureListener() {
    8. 29.         @Override
    9. 30.         public void onFailure(Exception e) {
    10. 31.             // 推理异常。
    11. 32.         }
    12. 33.     });

    34. }

    人工智能芯片与自动驾驶
  • 相关阅读:
    python中if __name__ == '__main__': 的解析
    CPPUTest 单元测试框架(针对 C 单元测试的使用说明)
    哈希表详解
    使用RSS提升DPDK应用的性能(转)
    DPDK内存管理-----rte_mbuf(转)
    DPDK内存管理-----(二)rte_mempool内存管理
    DPDK内存管理(1)(转)
    Scala + IntelliJ IDEA
    什么是消息队列中间件
    微信小程序直播
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/14320004.html
Copyright © 2011-2022 走看看