zoukankan      html  css  js  c++  java
  • 用PMML实现机器学习模型的跨平台上线

      在机器学习用于产品的时候,我们经常会遇到跨平台的问题。比如我们用Python基于一系列的机器学习库训练了一个模型,但是有时候其他的产品和项目想把这个模型集成进去,但是这些产品很多只支持某些特定的生产环境比如Java,为了上一个机器学习模型去大动干戈修改环境配置很不划算,此时我们就可以考虑用预测模型标记语言(Predictive Model Markup Language,以下简称PMML)来实现跨平台的机器学习模型部署了。

    1. PMML概述

        PMML是数据挖掘的一种通用的规范,它用统一的XML格式来描述我们生成的机器学习模型。这样无论你的模型是sklearn,R还是Spark MLlib生成的,我们都可以将其转化为标准的XML格式来存储。当我们需要将这个PMML的模型用于部署的时候,可以使用目标环境的解析PMML模型的库来加载模型,并做预测。

        可以看出,要使用PMML,需要两步的工作,第一块是将离线训练得到的模型转化为PMML模型文件,第二块是将PMML模型文件载入在线预测环境,进行预测。这两块都需要相关的库支持。

    2. PMML模型的生成和加载相关类库

        PMML模型的生成相关的库需要看我们使用的离线训练库。如果我们使用的是sklearn,那么可以使用sklearn2pmml这个python库来做模型文件的生成,这个库安装很简单,使用"pip install sklearn2pmml"即可,相关的使用我们后面会有一个demo。如果使用的是Spark MLlib, 这个库有一些模型已经自带了保存PMML模型的方法,可惜并不全。如果是R,则需要安装包"XML"和“PMML”。此外,JAVA库JPMML可以用来生成R,SparkMLlib,xgBoost,Sklearn的模型对应的PMML文件。github地址是:https://github.com/jpmml/jpmml。

        加载PMML模型需要目标环境支持PMML加载的库,如果是JAVA,则可以用JPMML来加载PMML模型文件。相关的使用我们后面会有一个demo。

    3. PMML模型生成和加载示例

        下面我们给一个示例,使用sklearn生成一个决策树模型,用sklearn2pmml生成模型文件,用JPMML加载模型文件,并做预测。

        完整代码参见我的github:https://github.com/ljpzzz/machinelearning/blob/master/model-in-product/sklearn-jpmml

        首先是用用sklearn生成一个决策树模型,由于我们是需要保存PMML文件,所以最好把模型先放到一个Pipeline数组里面。这个数组里面除了我们的决策树模型以外,还可以有归一化,降维等预处理操作,这里作为一个示例,我们Pipeline数组里面只有决策树模型。代码如下:

    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline
    import pandas as pd
    from sklearn import tree
    from sklearn2pmml.pipeline import PMMLPipeline
    from sklearn2pmml import sklearn2pmml

    import os
    os.environ["PATH"] += os.pathsep + 'C:/Program Files/Java/jdk1.8.0_171/bin'

    X=[[1,2,3,1],[2,4,1,5],[7,8,3,6],[4,8,4,7],[2,5,6,9]]
    y=[0,1,0,2,1]
    pipeline = PMMLPipeline([("classifier", tree.DecisionTreeClassifier(random_state=9))]);
    pipeline.fit(X,y)

    sklearn2pmml(pipeline, ".demo.pmml", with_repr = True)

        上面这段代码做了一个非常简单的决策树分类模型,只有5个训练样本,特征有4个,输出类别有3个。实际应用时,我们需要将模型调参完毕后才将其放入PMMLPipeline进行保存。运行代码后,我们在当前目录会得到一个PMML的XML文件,可以直接打开看,内容大概如下:

    复制代码
    <?xml version="1.0" encoding="UTF-8" standalone="yes"?>
    <PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
        <Header>
            <Application name="JPMML-SkLearn" version="1.5.3"/>
            <Timestamp>2018-06-24T05:47:17Z</Timestamp>
        </Header>
        <MiningBuildTask>
            <Extension>PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
                max_features=None, max_leaf_nodes=None,
                min_impurity_decrease=0.0, min_impurity_split=None,
                min_samples_leaf=1, min_samples_split=2,
                min_weight_fraction_leaf=0.0, presort=False, random_state=9,
                splitter='best'))])</Extension>
        </MiningBuildTask>
        <DataDictionary>
            <DataField name="y" optype="categorical" dataType="integer">
                <Value value="0"/>
                <Value value="1"/>
                <Value value="2"/>
            </DataField>
            <DataField name="x3" optype="continuous" dataType="float"/>
            <DataField name="x4" optype="continuous" dataType="float"/>
        </DataDictionary>
        <TransformationDictionary>
            <DerivedField name="double(x3)" optype="continuous" dataType="double">
                <FieldRef field="x3"/>
            </DerivedField>
            <DerivedField name="double(x4)" optype="continuous" dataType="double">
                <FieldRef field="x4"/>
            </DerivedField>
        </TransformationDictionary>
        <TreeModel functionName="classification" missingValueStrategy="nullPrediction" splitCharacteristic="multiSplit">
            <MiningSchema>
                <MiningField name="y" usageType="target"/>
                <MiningField name="x3"/>
                <MiningField name="x4"/>
            </MiningSchema>
            <Output>
                <OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/>
                <OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/>
                <OutputField name="probability(2)" optype="continuous" dataType="double" feature="probability" value="2"/>
            </Output>
            <Node>
                <True/>
                <Node>
                    <SimplePredicate field="double(x3)" operator="lessOrEqual" value="3.5"/>
                    <Node score="1" recordCount="1.0">
                        <SimplePredicate field="double(x3)" operator="lessOrEqual" value="2.0"/>
                        <ScoreDistribution value="0" recordCount="0.0"/>
                        <ScoreDistribution value="1" recordCount="1.0"/>
                        <ScoreDistribution value="2" recordCount="0.0"/>
                    </Node>
                    <Node score="0" recordCount="2.0">
                        <True/>
                        <ScoreDistribution value="0" recordCount="2.0"/>
                        <ScoreDistribution value="1" recordCount="0.0"/>
                        <ScoreDistribution value="2" recordCount="0.0"/>
                    </Node>
                </Node>
                <Node score="2" recordCount="1.0">
                    <SimplePredicate field="double(x4)" operator="lessOrEqual" value="8.0"/>
                    <ScoreDistribution value="0" recordCount="0.0"/>
                    <ScoreDistribution value="1" recordCount="0.0"/>
                    <ScoreDistribution value="2" recordCount="1.0"/>
                </Node>
                <Node score="1" recordCount="1.0">
                    <True/>
                    <ScoreDistribution value="0" recordCount="0.0"/>
                    <ScoreDistribution value="1" recordCount="1.0"/>
                    <ScoreDistribution value="2" recordCount="0.0"/>
                </Node>
            </Node>
        </TreeModel>
    </PMML>
    复制代码

        可以看到里面就是决策树模型的树结构节点的各个参数,以及输入值。我们的输入被定义为x1-x4,输出定义为y。

        有了PMML模型文件,我们就可以写JAVA代码来读取加载这个模型并做预测了。

        我们创建一个Maven或者gradle工程,加入JPMML的依赖,这里给出maven在pom.xml的依赖,gradle的结构是类似的。

    复制代码
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.4.1</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-extension</artifactId>
            <version>1.4.1</version>
        </dependency>
    复制代码

        接着就是读取模型文件并预测的代码了,具体代码如下:

    复制代码
    import org.dmg.pmml.FieldName;
    import org.dmg.pmml.PMML;
    import org.jpmml.evaluator.*;
    import org.xml.sax.SAXException;
    
    import javax.xml.bind.JAXBException;
    import java.io.FileInputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.util.HashMap;
    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;
    
    /**
     * Created by 刘建平Pinard on 2018/6/24.
     */
    public class PMMLDemo {
        private Evaluator loadPmml(){
            PMML pmml = new PMML();
            InputStream inputStream = null;
            try {
                inputStream = new FileInputStream("D:/demo.pmml");
            } catch (IOException e) {
                e.printStackTrace();
            }
            if(inputStream == null){
                return null;
            }
            InputStream is = inputStream;
            try {
                pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
            } catch (SAXException e1) {
                e1.printStackTrace();
            } catch (JAXBException e1) {
                e1.printStackTrace();
            }finally {
                //关闭输入流
                try {
                    is.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
            Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
            pmml = null;
            return evaluator;
        }
        private int predict(Evaluator evaluator,int a, int b, int c, int d) {
            Map<String, Integer> data = new HashMap<String, Integer>();
            data.put("x1", a);
            data.put("x2", b);
            data.put("x3", c);
            data.put("x4", d);
            List<InputField> inputFields = evaluator.getInputFields();
            //过模型的原始特征,从画像中获取数据,作为模型输入
            Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
            for (InputField inputField : inputFields) {
                FieldName inputFieldName = inputField.getName();
                Object rawValue = data.get(inputFieldName.getValue());
                FieldValue inputFieldValue = inputField.prepare(rawValue);
                arguments.put(inputFieldName, inputFieldValue);
            }
    
            Map<FieldName, ?> results = evaluator.evaluate(arguments);
            List<TargetField> targetFields = evaluator.getTargetFields();
    
            TargetField targetField = targetFields.get(0);
            FieldName targetFieldName = targetField.getName();
    
            Object targetFieldValue = results.get(targetFieldName);
            System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);
            int primitiveValue = -1;
            if (targetFieldValue instanceof Computable) {
                Computable computable = (Computable) targetFieldValue;
                primitiveValue = (Integer)computable.getResult();
            }
            System.out.println(a + " " + b + " " + c + " " + d + ":" + primitiveValue);
            return primitiveValue;
        }
        public static void main(String args[]){
            PMMLDemo demo = new PMMLDemo();
            Evaluator model = demo.loadPmml();
            demo.predict(model,1,8,99,1);
            demo.predict(model,111,89,9,11);
    
        }
    }
    复制代码

        代码里有两个函数,第一个loadPmml是加载模型的,第二个predict是读取预测样本并返回预测值的。我的代码运行结果如下:

    target: y value: {result=2, probability_entries=[0=0.0, 1=0.0, 2=1.0], entityId=5, confidence_entries=[]}
    1 8 99 1:2
    target: y value: {result=1, probability_entries=[0=0.0, 1=1.0, 2=0.0], entityId=6, confidence_entries=[]}
    111 89 9 11:1

        也就是样本(1,8,99,1)被预测为类别2,而(111,89,9,11)被预测为类别1。

        以上就是PMML生成和加载的一个示例,使用起来其实门槛并不高,也很简单。

    4. PMML总结与思考

        PMML的确是跨平台的利器,但是是不是就没有缺点呢?肯定是有的!

        第一个就是PMML为了满足跨平台,牺牲了很多平台独有的优化,所以很多时候我们用算法库自己的保存模型的API得到的模型文件,要比生成的PMML模型文件小很多。同时PMML文件加载速度也比算法库自己独有格式的模型文件加载慢很多。

        第二个就是PMML加载得到的模型和算法库自己独有的模型相比,预测会有一点点的偏差,当然这个偏差并不大。比如某一个样本,用sklearn的决策树模型预测为类别1,但是如果我们把这个决策树落盘为一个PMML文件,并用JAVA加载后,继续预测刚才这个样本,有较小的概率出现预测的结果不为类别1.

        第三个就是对于超大模型,比如大规模的集成学习模型,比如xgboost, 随机森林,或者tensorflow,生成的PMML文件很容易得到几个G,甚至上T,这时使用PMML文件加载预测速度会非常慢,此时推荐为模型建立一个专有的环境,就没有必要去考虑跨平台了。

        此外,对于TensorFlow,不推荐使用PMML的方式来跨平台。可能的方法一是TensorFlow serving,自己搭建预测服务,但是会稍有些复杂。另一个方法就是将模型保存为TensorFlow的模型文件,并用TensorFlow独有的JAVA库加载来做预测。

        我们在下一篇会讨论用python+tensorflow训练保存模型,并用tensorflow的JAVA库加载做预测的方法和实例。

  • 相关阅读:
    格式化数字,将字符串格式的数字,如:1000000 改为 1 000 000 这种展示方式
    jquery图片裁剪插件
    前端开发采坑之安卓和ios的兼容问题
    页面消息提示,上下滚动
    可以使用css的方式让input不能输入文字吗?
    智慧农村“三网合一”云平台测绘 大数据 农业 信息平台 应急
    三维虚拟城市平台测绘 大数据 规划 三维 信息平台 智慧城市
    农业大数据“一张图”平台测绘 大数据 房产 国土 农业 信息平台
    应急管理管理局安全生产预警平台应急管理系统不动产登记 测绘 大数据 规划 科教 三维 信息平台
    地下综合管廊管理平台测绘 大数据 地下管线 三维 信息平台
  • 原文地址:https://www.cnblogs.com/tan2810/p/11990005.html
Copyright © 2011-2022 走看看