zoukankan      html  css  js  c++  java
  • TensorFlow 模型推理的 Java 实现（基于 BERT 的命名实体识别、基于 Transformer 的文本分类）

    最近在做文本分类任务。由于在实际工程中需要以服务的形式对外提供功能，故采用 Java 加载 pb 模型（TensorFlow 冻结图）完成推理，特将过程记录如下：

    1. 基于 Transformer 的文本分类

    package com.techwolf.transformer;
    
    import com.alibaba.fastjson.*;
    import com.alibaba.fastjson.parser.Feature;
    import org.tensorflow.Graph;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;
    //import com.alibaba.fastjson.JSONPObject;
    
    //import org.json.JSONObject;
    
    import java.io.*;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;
    import java.util.*;
    
    public class JobPredict {
        private static String jsonPath = "src/main/resources/resource.json";
        private static String modelPath = "src/main/resources/model.pb";
        private static Map<String, Object> positionToFeature = new HashMap<String, Object>();
        private static Map<String, Object> jobMapping = new HashMap<String, Object>();
        private static Map<String, Object> mergeMapping = new HashMap<String, Object>();
        private static Map<String, Object> featureToId = new HashMap<String, Object>();
        private static Map<String, Object> idToCode = new HashMap<String, Object>();
        private static Map<String, Object> codeToLabel = new HashMap<String, Object>();
    
        public static String readJsonFile(String fileName) throws FileNotFoundException {
            String jsonStr = "";
            try {
                File jsonFile = new File(fileName);
                FileReader fileReader = new FileReader(jsonFile);
                Reader reader = new InputStreamReader(new FileInputStream(jsonFile), "utf-8");
                int ch = 0;
                StringBuffer sb = new StringBuffer();
                while ((ch = reader.read()) != -1) {
                    sb.append((char) ch);
                }
                fileReader.close();
                reader.close();
                jsonStr = sb.toString();
                return jsonStr;
            } catch (IOException e) {
                e.printStackTrace();
                return null;
            }
        }
    
        private static Map<String, Object> jsonTOMap(JSONObject jsobj) {
            Map<String, Object> data = new HashMap<String, Object>();
            Iterator it = jsobj.entrySet().iterator();
            while (it.hasNext()) {
                Map.Entry<String, Object> entry = (Map.Entry<String, Object>) it.next();
                data.put(entry.getKey(), entry.getValue());
            }
            return data;
        }
    
        private static void getConfig() throws FileNotFoundException {
            String jsonStr = readJsonFile(jsonPath);
            JSONObject obj = JSON.parseObject(jsonStr);
    
            positionToFeature = jsonTOMap(obj.getJSONObject("position2feature"));
            featureToId = jsonTOMap(obj.getJSONObject("feature2id"));
            jobMapping = jsonTOMap(obj.getJSONObject("job_mapping"));
            mergeMapping = jsonTOMap(obj.getJSONObject("merge_mapping"));
            idToCode = jsonTOMap(obj.getJSONObject("id2position"));
            codeToLabel = jsonTOMap(obj.getJSONObject("position_mapping"));
            System.out.println("config data loaded!");
        }
    
        public static String convert(String utfString) {
            StringBuilder sb = new StringBuilder();
            int i = -1;
            int pos = 0;
            int iint = 0;
            while ((i = utfString.indexOf("\u", pos)) != -1) {
                String sd = utfString.substring(pos, i);
                sb.append(sd);
                iint = i + 5;
    
                if (iint < utfString.length()) {
                    pos = i + 6;
                    sb.append((char) Integer.parseInt(utfString.substring(i + 2, i + 6), 16));
                }
            }
            String endStr = utfString.substring(iint + 1, utfString.length());
            return sb + "" + endStr;
        }
    
        private static Map<String, List> getCodeAndScore(JSONArray jsonArray) throws FileNotFoundException {
            List<Integer> codes = new ArrayList<Integer>();
            List<Float> scores = new ArrayList<Float>();
            Integer codeFlag = -1;
            float scoreFlag = (float) .0;
    
            for (int i = 0; i < jsonArray.size(); i++) {
                JSONObject skillsItem = (JSONObject) jsonArray.get(i);
                String code = (skillsItem.get("code")).toString();
                Float score = Float.parseFloat((String) skillsItem.get("score"));
                boolean isReplace = mergeMapping.containsKey(code);
                if (isReplace) {
                    code = (mergeMapping.get(code)).toString();
                    System.out.println("replace id" + code);
                }
                String position = (jobMapping.get(code)).toString();
                Integer featSeq = (Integer) positionToFeature.get(position);
                if (featSeq == null) {
                    codes.add((Integer) featureToId.get(codeFlag.toString()));
                    scores.add(scoreFlag);
                } else {
                    Integer x = (Integer) featureToId.get(featSeq.toString());
                    codes.add((Integer) featureToId.get(featSeq.toString()));
                    scores.add(score);
                }
            }
            if (jsonArray.size() < 3) {
                for(int i=0; i< (3-jsonArray.size()); i++) {
                    codes.add((Integer) featureToId.get(codeFlag.toString()));
                    scores.add(scoreFlag);
                }
            }
            Map<String, List> result = new HashMap<String, List>();
            result.put("codes", codes);
            result.put("scores", scores);
            return result;
        }
    
        private static byte[] readAllByteOrExit(Path path){
            try{
                return Files.readAllBytes(path);
            }catch (IOException e){
                System.out.println("Failed to read[" + path + "]:" + e.getMessage());
                System.exit(1);
            }
            return null;
        }
    
        private static Map<String, List> getDataContent(String testFile) throws FileNotFoundException {
            String jsonStr = readJsonFile(testFile);
            JSONObject obj = JSON.parseObject(jsonStr, Feature.OrderedField);
            JSONObject objNew = JSON.parseObject(obj.toJSONString(), Feature.OrderedField);
            ArrayList<List> sampleCode = new ArrayList<List>();
            ArrayList<List> sampleScore = new ArrayList<List>();
            Map<String, List> samples = new HashMap<String, List>();
    
            for (String userId: objNew.keySet()) {
                ArrayList<List> codeList = new ArrayList<List>();
                ArrayList<Double> scoresList = new ArrayList<Double>();
                JSONObject itemTags = (JSONObject) ((JSONObject)((JSONObject)objNew.get(userId)).get("_source")).get("tags");
                JSONArray skills = (JSONArray) itemTags.get("skills");
                JSONArray title = (JSONArray) itemTags.get("title");
                JSONArray desc = (JSONArray) itemTags.get("desc");
                Map<String, List> skillsResult = getCodeAndScore(skills);
                Map<String, List> titleResult = getCodeAndScore(title);
                Map<String, List> descResult = getCodeAndScore(desc);
                codeList.addAll(skillsResult.get("codes"));
                codeList.addAll(titleResult.get("codes"));
                codeList.addAll(descResult.get("codes"));
                scoresList.addAll(skillsResult.get("scores"));
                scoresList.addAll(titleResult.get("scores"));
                scoresList.addAll(descResult.get("scores"));
                sampleCode.add(codeList);
                sampleScore.add(scoresList);
            }
            samples.put("sampleCode", sampleCode);
            samples.put("sampleScore", sampleScore);
            System.out.println("ok! sample feature created.");
            return samples;
        }
    
        public static int[] arraySort(float[] arr, boolean desc) {
            float temp;
            int index;
            int k = arr.length;
            int[] Index = new int[k];
            for (int i = 0; i < k; i++) {
                Index[i] = i;
            }
    
            for (int i = 0; i < arr.length; i++) {
                for (int j = 0; j < arr.length - i - 1; j++) {
                    if (desc) {
                        if (arr[j] < arr[j + 1]) {
                            temp = arr[j];
                            arr[j] = arr[j + 1];
                            arr[j + 1] = temp;
    
                            index = Index[j];
                            Index[j] = Index[j + 1];
                            Index[j + 1] = index;
                        }
                    } else {
                        if (arr[j] > arr[j + 1]) {
                            temp = arr[j];
                            arr[j] = arr[j + 1];
                            arr[j + 1] = temp;
    
                            index = Index[j];
                            Index[j] = Index[j + 1];
                            Index[j + 1] = index;
                        }
                    }
                }
            }
            return Index;
        }
    
    
        private static void featToTensor(float[][][] indexes, int[][] codes, float[][] scores, Map<String, List> data) {
    
            List<Integer> featCode = data.get("sampleCode");
            List<Float> featScore = data.get("sampleScore");
            int size = 9;
            for(int i=0; i < featCode.size(); i++) {
                Object eachCode = featCode.get(i);
                Object eachScore = featScore.get(i);
                float [][] positionResult = new float[size][];
                for(int step=0; step < size; step++) {
                    float[] positionVector = new float[size];
                    positionVector[step] = 1;
                    positionResult[step] = positionVector;
                }
                indexes[i] = positionResult;
                Integer[] targetInter = ((List<Integer>)eachCode).toArray(new Integer[size]);
                int[] codeResult = Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
                Float[] targetFloat = ((List<Float>)eachScore).toArray(new Float[size]);
                double[] scoreResult = Arrays.stream(targetFloat).mapToDouble(Double::valueOf).toArray();
                float[] scoreFloat = new float[size];
                for(int j=0; j < scoreResult.length; j++) {
                    scoreFloat[j] = (float) scoreResult[j];
                }
                System.arraycopy(codeResult,0,codes[i], 0, codeResult.length);
                System.arraycopy(scoreFloat,0,scores[i], 0, scoreResult.length);
    
            }
    
        }
    
    
        private static List<HashMap<String, Float>> modelInfer(Map<String, List> data) {
    
            int batchSize = data.get("sampleCode").size();
            int padLength = 9;
            int returnNum = 5;
            int classNum = 868;
            float[][][] indexes = new float[batchSize][padLength][padLength];
            int[][] codes = new int[batchSize][padLength];
            float[][] scores = new float[batchSize][padLength];
            float transKeepProb = (float) 1.0;
            float multiKeepProb = (float) 1.0;
    
            byte[] graphDef = readAllByteOrExit(Paths.get(modelPath));
            Graph g = new Graph();
            g.importGraphDef(graphDef);
            Session sess = new Session(g);
    
            featToTensor(indexes, codes, scores, data);
            Tensor tensorIndex = Tensor.create(indexes);
            Tensor tensorCode = Tensor.create(codes);
            Tensor tensorScore = Tensor.create(scores);
            Tensor tensorTransProb = Tensor.create(transKeepProb);
            Tensor tensorMultiProb = Tensor.create(multiKeepProb);
            Tensor tensorClassResult = sess.runner().
                    feed("input_x:0", tensorCode).
                    feed("input_x_score:0", tensorScore).
                    feed("embed_position:0", tensorIndex).
                    feed("trans_keep_prob:0", tensorTransProb).
                    feed("multi_keep_prob:0", tensorMultiProb).
                    fetch("discriminator/softmax_score:0").run().get(0);
    
            float[][] result = (float[][]) tensorClassResult.copyTo(new float[batchSize][classNum]);
            List<HashMap<String, Float>> predictResult = new ArrayList();
            for(int i=0; i<result.length; i++){
                float[] resultVec = result[i];
                int[] resultIndex = new int[classNum];
                HashMap<String, Float> predictSample = new HashMap<String, Float>();
                resultIndex = arraySort(resultVec, true);
                for(int s=0; s < returnNum; s++) {
                    String sampleCode = Integer.toString(resultIndex[s]);
                    String label = (String) codeToLabel.get(Integer.toString((Integer) idToCode.get(sampleCode)));
                    predictSample.put(label, resultVec[s]);
                }
                predictResult.add(predictSample);
            }
            tensorClassResult.close();
            tensorMultiProb.close();
            tensorTransProb.close();
            tensorScore.close();
            tensorCode.close();
            tensorIndex.close();
            return predictResult;
        }
    
            public static void main (String[]args) throws IOException {
                String testFile = "src/main/data/predict_data.json";
    
                getConfig();
                Map<String, List> samples = getDataContent(testFile);
                List<HashMap<String, Float>> result = modelInfer(samples);
    
                System.out.println(result);
            }
    
    }

    2. 基于 BERT 的命名实体识别（NER）

    package com.techwolf.bert;
    
    import org.tensorflow.Graph;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;
    
    import java.io.BufferedReader;
    import java.io.FileInputStream;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;
    import java.util.*;
    
    
    public class BertNerPredict {
        private static String vocabPath = "src/main/resources/vocab.txt";
        private static Map<String, Integer> word2id = new HashMap<String, Integer>();
        static {
            try {
                BufferedReader buffer = null;
                buffer = new BufferedReader(new InputStreamReader(new FileInputStream(vocabPath)));
                int i = 0;
                String line = buffer.readLine().trim();
                while (line!=null){
                    word2id.put(line, i++);
                    line = buffer.readLine().trim();
                }
                buffer.close();
            }catch (Exception e){
            }
    //        System.out.println("word2id size is:"+word2id.size());
    
        }
    
        private static byte[] readAllByteOrExit(Path path){
            try{
                return Files.readAllBytes(path);
            }catch (IOException e){
                System.out.println("Failed to read[" + path + "]:" + e.getMessage());
                System.exit(1);
            }
            return null;
        }
    
        public static void getTextToId(int[][] inputIds, int[][] inputMask, String[] text){
            for(int i=0; i<text.length; i++){
                char[] chs = text[i].trim().toLowerCase().toCharArray();
    
                List<Integer> list = new ArrayList<>();
                List<Integer> mask = new ArrayList<>();
                list.add(word2id.get("[CLS]"));
                mask.add(1);
                for(int j=0; j<chs.length; j++){
                    String element = Character.toString(chs[j]);
                    if(word2id.containsKey(element)){
                        list.add(word2id.get(element));
                        mask.add(1);
                    }
                }
                list.add(word2id.get("[SEP]"));
                mask.add(1);
    
                int size = list.size();
                Integer[] targetInter = list.toArray(new Integer[size]);
                int[] idResult = Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
                Integer[] maskInter = mask.toArray(new Integer[size]);
                int[] maskResult = Arrays.stream(maskInter).mapToInt(Integer::valueOf).toArray();
                System.arraycopy(idResult,0,inputIds[i], 0, idResult.length);
                System.arraycopy(maskResult,0,inputMask[i], 0, maskResult.length);
            }
        }
    
        public static void main(String[] args) {
            String[] query = new String[]{"中华人民共和国", "新疆大学"};
            String resourceDir = "src/main/resources";
            String modelName = "model.pb";
    
            int batchSize = query.length;
            int padLength = 25;
            int[][] indexes = new int[batchSize][padLength];
            int[][] mask = new int[batchSize][padLength];
    
            byte[] graphDef = readAllByteOrExit(Paths.get(resourceDir, modelName));
            Graph g = new Graph();
            g.importGraphDef(graphDef);
            Session sess = new Session(g);
    
            if (query.length>0){
                System.out.println("Ok! Start predicting...
    ");
            }else {
                System.exit(0);
            }
    
            getTextToId(indexes, mask, query);
            Tensor tensorInputIds = Tensor.create(indexes);
            Tensor tensorMask = Tensor.create(mask);
            Tensor tensorSeqResult = sess.runner().feed("input_ids:0", tensorInputIds).
                    feed("input_mask:0", tensorMask).fetch("viterbi/ReverseSequence_1:0").run().get(0);
            Tensor tensorScoreResult = sess.runner().feed("input_ids:0", tensorInputIds).
                    feed("input_mask:0", tensorMask).fetch("viterbi/Max:0").run().get(0);
            int[][] sequenceId = (int[][]) tensorSeqResult.copyTo(new int[batchSize][padLength]);
            float[] sequenceScore = (float[]) tensorScoreResult.copyTo(new float[batchSize]);
            for(int i=0; i<sequenceId.length; i++){
                System.out.println("query: "+query[i]);
                System.out.println("sequence result: "+ Arrays.toString(sequenceId[i]));
                System.out.println("sequence score: "+ sequenceScore[i]+'
    ');
            }
            tensorScoreResult.close();
            tensorSeqResult.close();
            tensorMask.close();
            tensorInputIds.close();
        }
    }
  • 相关阅读:
    js 注意
    JS学习大作业-Excel
    js继承
    转载:margin外边距合并问题以及解决方式
    CSS属性选择器和部分伪类
    HTML使用CSS样式的方法
    link元素 rel src href属性
    【2020.01.06】SDN大作业
    【2019.12.11】SDN上机第7次作业
    【2019.12.04】SDN上机第6次作业
  • 原文地址:https://www.cnblogs.com/demo-deng/p/13903264.html
Copyright © 2011-2022 走看看