zoukankan      html  css  js  c++  java
  • 使用java调用python训练出的pmml模型

    记录下自己的过程,以后可以随时用,如果能帮到大家就更好了。

    从安装软件说起,嫌麻烦的就别看了。

    一、下载工具(俗话说得好,预先善其事必先利其器!哈哈)

    我刚开始安装的是eclipse,但有诸多麻烦不能解决,就用了IDEA,和Pycharm一个公司发行的。

    首先进入官网: http://www.jetbrains.com/products.html#lang=java

    选择IDEA下载:

    由于社区版的功能太少,我下载的是企业版的,后边会告诉破解方法。

    IDEA的安装教程网上都有,正常安装就好。

    企业版的激活码大家可以关注一个公众号,我也是在网上找到的。

    http://idea.medeming.com/

    关注公众号后粘贴就行了。

    二、Java环境安装

    参考教程:https://blog.csdn.net/weixin_38381149/article/details/89668578

    写博客时想找当时看的博客,但发现了这个很全的,jdk,maven,tomcat都有。

    想当初我为了装一个maven花了好久。。。

    三、新建Maven项目

      File ==》New==》Project==》Maven

    四、接下来在IDEA中配置Maven,这是当时参考的博客:https://www.cnblogs.com/jiangzhaowei/p/9534393.html

    五、添加依赖

      由于我只是为了调用模型,没有太多依赖,只添加了这么几个

        <dependencies>
    
            <dependency>
                <groupId>org.jpmml</groupId>
                <artifactId>pmml-evaluator</artifactId>
                <version>1.4.1</version>
            </dependency>
            <dependency>
                <groupId>org.jpmml</groupId>
                <artifactId>pmml-evaluator-extension</artifactId>
                <version>1.4.1</version>
            </dependency>
    
            <dependency>
                <groupId>javax.xml.bind</groupId>
                <artifactId>jaxb-api</artifactId>
                <version>2.3.0</version>
            </dependency>
            <dependency>
                <groupId>com.sun.xml.bind</groupId>
                <artifactId>jaxb-core</artifactId>
                <version>2.3.0</version>
            </dependency>
            <dependency>
                <groupId>com.sun.xml.bind</groupId>
                <artifactId>jaxb-impl</artifactId>
                <version>2.3.0</version>
            </dependency>
    
        </dependencies>

    六、java调用Python训练出的pmml模型的代码

    import org.dmg.pmml.FieldName;
    import org.dmg.pmml.PMML;
    import org.jpmml.evaluator.*;
    import org.jpmml.model.PMMLUtil;
    import org.xml.sax.SAXException;
    
    import javax.xml.bind.JAXBException;
    import java.io.FileInputStream;
    import java.io.FileNotFoundException;
    import java.io.IOException;
    import java.io.InputStream;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    
    public class ClassificationModel {
        private Evaluator modelEvaluator;
    
        /**
         * 通过传入 PMML 文件路径来生成机器学习模型
         *
         * @param pmmlFileName pmml 文件路径
         */
        public ClassificationModel(String pmmlFileName) {
            PMML pmml = null;
    
            try {
                if (pmmlFileName != null) {
                    InputStream is = new FileInputStream(pmmlFileName);
                    pmml = PMMLUtil.unmarshal(is);
                    try {
                        is.close();
                    } catch (IOException e) {
                        System.out.println("InputStream close error!");
                    }
    
                    ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
    
                    this.modelEvaluator = (Evaluator) modelEvaluatorFactory.newModelEvaluator(pmml);
                    modelEvaluator.verify();
                    System.out.println("加载模型成功!");
                }
            } catch (SAXException e) {
                e.printStackTrace();
            } catch (JAXBException e) {
                e.printStackTrace();
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }
    
        }
    
        // 获取模型需要的特征名称
        public List<String> getFeatureNames() {
            List<String> featureNames = new ArrayList<String>();
    
            List<InputField> inputFields = modelEvaluator.getInputFields();
    
            for (InputField inputField : inputFields) {
                featureNames.add(inputField.getName().toString());
            }
            return featureNames;
        }
    
        // 获取目标字段名称
        public String getTargetName() {
            return modelEvaluator.getTargetFields().get(0).getName().toString();
        }
    
        // 使用模型生成概率分布
        private ProbabilityDistribution getProbabilityDistribution(Map<FieldName, ?> arguments) {
            Map<FieldName, ?> evaluateResult = modelEvaluator.evaluate(arguments);
    
            FieldName fieldName = new FieldName(getTargetName());
    
            return (ProbabilityDistribution) evaluateResult.get(fieldName);
    
        }
    
        // 预测不同分类的概率
        public ValueMap<String, Number> predictProba(Map<FieldName, Number> arguments) {
            ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
            return probabilityDistribution.getValues();
        }
    
        // 预测结果分类
        public Object predict(Map<FieldName, ?> arguments) {
            ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
    
            return probabilityDistribution.getPrediction();
        }
    
        public static void main(String[] args) {
            ClassificationModel clf = new ClassificationModel("D:/JupyterSpace/RandomForestClassifier_Iris.pmml"); //这里模型地址
    
            List<String> featureNames = clf.getFeatureNames();
            System.out.println("feature: " + featureNames);
    
            // 构建待预测数据
            Map<FieldName, Number> waitPreSample = new HashMap<>();
         #这里的key一定要对应python中的列名 waitPreSample.put(
    new FieldName("sepal length (cm)"), 10); waitPreSample.put(new FieldName("sepal width (cm)"), 1); waitPreSample.put(new FieldName("petal length (cm)"), 3); waitPreSample.put(new FieldName("petal width (cm)"), 2); System.out.println("waitPreSample predict result: " + clf.predict(waitPreSample).toString()); System.out.println("waitPreSample predictProba result: " + clf.predictProba(waitPreSample).toString()); } }

    注意事项:

    1、类名和文件名要一致

    2、打开File  ==》Project Structure

    看你的JDK版本和这里是否一致

    运行程序,查看是否报错。

    这是我报的一个错:

    NoClassDefFoundError: javax/activation/DataSource

      解决方法是下载:activation.jar包。

      下载地址:

        链接:https://pan.baidu.com/s/14D8cQWIJp2d7h2iljAPZ2A
        提取码:6f37

    应该没什么问题了。有问题请留言,一定回复。(有问题一定要告诉我,以后还要用呢。。。)

    https://www.cnblogs.com/zhangzhixing/
  • 相关阅读:
    crontab与系统时间不一致
    MySQL构造测试数据
    【SQL优化】SQL优化工具
    mysql case when then 使用
    update没带where,寻找问题的思路
    线程池
    线程理论
    数据共享
    进程池
    管道
  • 原文地址:https://www.cnblogs.com/zhangzhixing/p/12095815.html
Copyright © 2011-2022 走看看