使用 JPMML 解析 Random Forest 模型并进行预测分析
导入Jar包
在 Maven 的 pom.xml 文件中添加 JPMML 依赖：
<dependency> <groupId>org.jpmml</groupId> <artifactId>pmml-evaluator</artifactId> <version>1.3.7</version> </dependency>
具体实现代码
模型读取类
import java.io.*; import java.nio.charset.Charset; import java.util.*; import com.google.common.io.Files; import org.dmg.pmml.FieldName; /** * 使用模型 * @author biantech * */ public class PmmlCalc { final static String utf8="utf-8"; public static void main(String[] args) throws IOException { if(args.length < 2){ System.out.println("参数个数不匹配"); } //文件生成路径 String pmmlPath = args[0]; String modelArgsFilePath = args[1]; PmmlInvoker invoker = new PmmlInvoker(pmmlPath); List<Map<FieldName, String>> paramList = readInParams(modelArgsFilePath); int lineNum = 0; //当前处理行数 File file = new File("result.txt"); for(Map<FieldName, String> param : paramList){ lineNum++; //System.out.println("======当前行: " + lineNum + "======="); Files.append("======当前行: " + lineNum + "=======",file,Charset.forName(utf8)); Map<FieldName, ?> result = invoker.invoke(param); Set<FieldName> keySet = result.keySet(); //获取结果的keySet for(FieldName fn : keySet){ String tempString = result.get(fn).toString()+" "; Files.append(tempString,file,Charset.forName(utf8)); } } System.out.println("resultFile="+file.getAbsolutePath()); } /** * 读取参数文件 * @param filePath 文件路径 * @return * @throws IOException */ public static List<Map<FieldName,String>> readInParams(String filePath) throws IOException{ InputStream is; is = PmmlCalc.class.getClassLoader().getResourceAsStream(filePath); if(is==null){ is = new FileInputStream(filePath); } InputStreamReader isreader = new InputStreamReader(is); BufferedReader br = new BufferedReader(isreader); String[] nameArr = br.readLine().split(","); //读取表头的名字 ArrayList<Map<FieldName,String>> list = new ArrayList<>(); String paramLine; //一行参数 //循环读取 每次读取一行数据 while((paramLine = br.readLine()) != null){ Map<FieldName,String> map = new HashMap<>(); String[] paramLineArr = paramLine.split(","); for(int i=0; i<paramLineArr.length; i++){//一次循环处理一行数据 map.put(new FieldName(nameArr[i]), paramLineArr[i]); //将表头和值组成map 加入list中 } list.add(map); } is.close(); return list; } }
调用执行类:PmmlInvoker
import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.util.Map; import javax.xml.bind.JAXBException; import org.dmg.pmml.FieldName; import org.dmg.pmml.PMML; import org.jpmml.evaluator.ModelEvaluator; import org.jpmml.evaluator.ModelEvaluatorFactory; import org.jpmml.model.PMMLUtil; import org.xml.sax.SAXException; /** * 读取pmml 获取模型 * @author biantech * */ public class PmmlInvoker { private ModelEvaluator modelEvaluator; // 通过文件读取模型 public PmmlInvoker(String pmmlFileName) { PMML pmml = null; InputStream is = null; try { if (pmmlFileName != null) { is = PmmlInvoker.class.getClassLoader().getResourceAsStream(pmmlFileName); if(is==null){ is = new FileInputStream(pmmlFileName); } pmml = PMMLUtil.unmarshal(is); } this.modelEvaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml); } catch (Exception e) { e.printStackTrace(); } finally { try { if(is!=null) is.close(); } catch (Exception localIOException3) { localIOException3.printStackTrace(); } } this.modelEvaluator.verify(); System.out.println("模型读取成功"); } // 通过输入流读取模型 public PmmlInvoker(InputStream is) { PMML pmml; try { pmml = PMMLUtil.unmarshal(is); try { is.close(); } catch (IOException localIOException) { } this.modelEvaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml); } catch (SAXException e) { pmml = null; } catch (JAXBException e) { pmml = null; } finally { try { is.close(); } catch (IOException localIOException3) { } } this.modelEvaluator.verify(); } public Map<FieldName, String> invoke(Map<FieldName, String> paramsMap) { return this.modelEvaluator.evaluate(paramsMap); } }
如何运行
- 执行 mvn package 命令，生成 jpmml-parser-1-jar-with-dependencies.jar
- 将 PMML 模型文件、数据集文件和 jar 包放在同一目录下（如 demo-model.pmml、demo-data.csv）。数据集文件为以英文逗号分隔的 CSV，第一行必须是字段名表头。
- 使用命令行运行
java -jar jpmml-parser-1-jar-with-dependencies.jar demo-model.pmml demo-data.csv
- 运行结束后会在当前工作目录下生成 result.txt，其中保存了每一行数据的预测分析结果，示例如下：
======当前行: 1=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]} setosa 1.0 0.0 0.0 ======当前行: 2=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]} setosa 1.0 0.0 0.0 ======当前行: 3=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]} setosa 1.0 0.0 0.0 ======当前行: 4=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]} setosa 1.0 0.0 0.0 ======当前行: 5=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]} setosa 1.0 0.0 0.0 ======当前行: 6=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]} setosa 1.0 0.0 0.0
完整源代码请见以下地址：
https://github.com/biantech/jpmml-parser