今天我们尝试了 Paddle-Lite 框架，发现出乎意料地可行。
下面我们来编写图像预处理函数（以及调用 Paddle-Lite 进行推理的封装类）：
package com.example.ironfarm; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.graphics.Matrix; import android.util.Log; import com.baidu.paddle.lite.MobileConfig; import com.baidu.paddle.lite.PaddlePredictor; import com.baidu.paddle.lite.PowerMode; import com.baidu.paddle.lite.Tensor; import java.io.File; import java.io.FileInputStream; import java.util.Arrays; public class PaddleLiteClassification { private static final String TAG = PaddleLiteClassification.class.getName(); private PaddlePredictor paddlePredictor; private Tensor inputTensor; private long[] inputShape = new long[]{1, 3, 224, 224}; private static float[] scale = new float[]{1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f}; private static float[] inputMean = new float[]{0.485f, 0.456f, 0.406f}; private static float[] inputStd = new float[]{0.229f, 0.224f, 0.225f}; private static final int NUM_THREADS = 4; /** * @param modelPath model path */ public PaddleLiteClassification(String modelPath) throws Exception { File file = new File(modelPath); if (!file.exists()) { throw new Exception("model file is not exists!"); } try { MobileConfig config = new MobileConfig(); config.setModelFromFile(modelPath); config.setThreads(NUM_THREADS); config.setPowerMode(PowerMode.LITE_POWER_HIGH); paddlePredictor = PaddlePredictor.createPaddlePredictor(config); inputTensor = paddlePredictor.getInput(0); inputTensor.resize(inputShape); } catch (Exception e) { e.printStackTrace(); throw new Exception("load model fail!"); } } public float[] predictImage(String image_path) throws Exception { if (!new File(image_path).exists()) { throw new Exception("image file is not exists!"); } FileInputStream fis = new FileInputStream(image_path); Bitmap bitmap = BitmapFactory.decodeStream(fis); float[] result = predictImage(bitmap); if (bitmap.isRecycled()) { bitmap.recycle(); } return result; } public float[] predictImage(Bitmap bitmap) throws Exception { return predict(bitmap); } public static int 
getMaxResult(float[] result) { float probability = 0; int r = 0; for (int i = 0; i < result.length; i++) { if (probability < result[i]) { probability = result[i]; r = i; } } return r; } private static float[] getScaledMatrix(Bitmap bitmap, int desWidth, int desHeight) { float[] dataBuf = new float[3 * desWidth * desHeight]; int rIndex; int gIndex; int bIndex; int[] pixels = new int[desWidth * desHeight]; Bitmap bm = Bitmap.createScaledBitmap(bitmap, desWidth, desHeight, false); bm.getPixels(pixels, 0, desWidth, 0, 0, desWidth, desHeight); int j = 0; int k = 0; for (int i = 0; i < pixels.length; i++) { int clr = pixels[i]; j = i / desHeight; k = i % desWidth; rIndex = j * desWidth + k; gIndex = rIndex + desHeight * desWidth; bIndex = gIndex + desHeight * desWidth; // 转成RGB通道顺序 dataBuf[bIndex] = (float) (((clr & 0x00ff0000) >> 16) / 255.0); dataBuf[gIndex] = (float) (((clr & 0x0000ff00) >> 8) / 255.0); dataBuf[rIndex] = (float) (((clr & 0x000000ff)) / 255.0); } if (bm.isRecycled()) { bm.recycle(); } Log.d("sss", Arrays.toString(dataBuf)); return dataBuf; } private Bitmap getScaleBitmap(Bitmap bitmap) { int bmpWidth = bitmap.getWidth(); int bmpHeight = bitmap.getHeight(); int size = (int) inputShape[2]; float scaleWidth = (float) size / bitmap.getWidth(); float scaleHeight = (float) size / bitmap.getHeight(); Matrix matrix = new Matrix(); matrix.postScale(scaleWidth, scaleHeight); return Bitmap.createBitmap(bitmap, 0, 0, bmpWidth, bmpHeight, matrix, true); } private float[] predict(Bitmap bmp) throws Exception { Bitmap b = getScaleBitmap(bmp); float[] inputData = getScaledMatrix(b, (int) inputShape[2], (int) inputShape[3]); b.recycle(); bmp.recycle(); inputTensor.setData(inputData); try { paddlePredictor.run(); } catch (Exception e) { throw new Exception("predict image fail! 
log:" + e); } Tensor outputTensor = paddlePredictor.getOutput(0); float[] result = outputTensor.getFloatData(); Log.d(TAG, Arrays.toString(result)); int l = getMaxResult(result); return new float[]{l, result[l]}; // return result; } }
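这段代码的核心是对拍照或从相册选取的图像进行预处理：先缩放到模型的输入尺寸（224×224），再把每个像素的颜色分量除以 255 归一化到 [0, 1]，并按通道依次排列成 float 数组作为模型输入。也就是说，模型实际上是根据图像各像素的颜色通道数值来进行预测的。（注意：代码中写入的通道顺序实际上是 B、G、R，需要与模型训练时的预处理保持一致。）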
好，今天就先到这里啦。