zoukankan      html  css  js  c++  java
  • Alink漫谈(十八) :源码解析 之 多列字符串编码MultiStringIndexer

    Alink漫谈(十八) :源码解析 之 多列字符串编码MultiStringIndexer

    0x00 摘要

    Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文将带领大家来分析Alink中 MultiStringIndexer 的实现。

    因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

    本文缘由是想分析GBDT,发现GBDT涉及到MultiStringIndexer的使用,所以只能先分析MultiStringIndexer 。

    0x01 概念

    Alink的官方介绍是:MultiStringIndexer训练组件的作用是训练一个模型用于将多列字符串映射为整数。

    具体来说,StringIndexer(字符串-索引变换)将标签的"字符串列"编码为"标签索引的列"。

    • 标签索引序列的取值范围是[0,numLabels(字符串中所有出现的单词去掉重复的词后的总和)],按照标签出现频率排序,出现最多的标签索引为0(具体为升序降序是可以配置的)。
    • 如果输入是数值型,我们先将数值映射到字符串,再对字符串进行索引化。
    • 如果下游的pipeline(例如:Estimator或者Transformer)需要用到索引化后的标签序列,则需要将这个pipeline的输入列名字指定为索引化序列的名字。大部分情况下,通过setSelectedCols设置输入的列名。

    以这些输入为例:

    ("football", "can"),
    ("football", "hhh"),
    ("football", "zzz"),
    ("basketball", "zzz"),
    ("basketball", "can"),
    ("tennis", "can")
    

    对于第一列,MultiStringIndexer 对数据集的label进行重新编号。按label出现的频次,转换成0 ~ numOfLabels - 1(分类个数)。如果是按照从高到低排序,则频次最高的转换为0,以此类推,比如:

    • football,出现次数最多,出现了3次,转换(编号)为0
    • 其次是basketball,出现了2次,编号为1,以此类推。

    在应用StringIndexer对labels进行重新编号后,带着这些编号后的label对数据进行了训练,并接着对其他数据进行了预测,得到预测结果,预测结果的label也是重新编号过的,因此需要转换回来。

    0x02 示例代码

    示例代码如下,本示例代码中,是按照升序排列,即football总数为3,则其idx为3,tennis个数为1,其idx为0:

    public class MultiStringIndexerExample {
        static AlgoOperator getData(boolean isBatch) {
            Row[] array = new Row[] {
                    Row.of("football", "can"),
                    Row.of("football", "hhh"),
                    Row.of("football", "zzz"),
                    Row.of("basketball", "zzz"),
                    Row.of("basketball", "can"),
                    Row.of("tennis", "can")
            };
    
            if (isBatch) {
                return new MemSourceBatchOp(
                        Arrays.asList(array), new String[] {"a", "b"});
            } else {
                return new MemSourceStreamOp(
                        Arrays.asList(array), new String[] {"a", "b"});
            }
        }
    
        public static void main(String[] args) throws Exception {
            BatchOperator data = (BatchOperator)getData(true);
            MultiStringIndexer stringindexer = new MultiStringIndexer()
                    .setSelectedCols("a", "b")
                    .setOutputCols("a_indexed", "b_indexed")
                    .setStringOrderType("frequency_asc");
            stringindexer.fit(data).transform(data).print();
        }
    }
    

    输出如下:

    a|b|a_indexed|b_indexed
    -|-|---------|---------
    football|can|2|2
    football|hhh|2|0
    football|zzz|2|1
    basketball|zzz|1|1
    basketball|can|1|2
    tennis|can|0|2
    

    转换成表格看的更清楚。

    a b a_indexed b_indexed
    football can 2 2
    football hhh 2 0
    football zzz 2 1
    basketball zzz 1 1
    basketball can 1 2
    tennis can 0 2

    0x03 总体逻辑

    我们先给出一个流程图

    老套路,我们从 MultiStringIndexerTrainBatchOp.linkFrom开始挖掘。

    @Override
    public MultiStringIndexerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
    
        // 示例中有 .setSelectedCols("a", "b"),这里是取出具体列名字
        final String[] selectedColNames = getSelectedCols();
        // 获取列的类型
        final String[] selectedColSqlType = new String[selectedColNames.length];
        for (int i = 0; i < selectedColNames.length; i++) {
            selectedColSqlType[i] = FlinkTypeConverter.getTypeString(
                TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedColNames[i]));
        }
    
    // runtime打印数据
    selectedColNames = {String[2]@2536} 
     0 = "a"
     1 = "b"
    selectedColSqlType = {String[2]@2537} 
     0 = "VARCHAR"
     1 = "VARCHAR"
      
        // 获取选取列对应的数据
        DataSet<Row> inputRows = in.select(selectedColNames).getDataSet();
        // 
        DataSet<Tuple3<Integer, String, Long>> indexedToken =
            StringIndexerUtil.indexTokens(inputRows, getStringOrderType(), 0L, true);
    
        DataSet<Row> values = indexedToken
            .mapPartition(new RichMapPartitionFunction<Tuple3<Integer, String, Long>, Row>() {
                @Override
                public void mapPartition(Iterable<Tuple3<Integer, String, Long>> values, Collector<Row> out)
                    throws Exception {
                    Params meta = null;
                    if (getRuntimeContext().getIndexOfThisSubtask() == 0) {           
                        // 第一个task会做这个计算,就是把列名,列类型作为元数据传送
                        meta = new Params().set(HasSelectedCols.SELECTED_COLS, selectedColNames)
                            .set(HasSelectedColTypes.SELECTED_COL_TYPES, selectedColSqlType);
                    }
     
    // runtime打印数据              
    meta = {Params@9311} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
     params = {HashMap@9316}  size = 2              
                  
                    new MultiStringIndexerModelDataConverter().save(Tuple2.of(meta, values), out);
                }
            })
            .name("build_model");
    
        this.setOutput(values, new MultiStringIndexerModelDataConverter().getModelSchema());
        return this;
    }
    

    训练过程总体逻辑总结如下:

    • 取出具体列名字,列的类型;
    • 获取"选取列"对应的数据;
    • 把列名,列类型作为元数据传送;
    • StringIndexerUtil.indexTokens 给各个列的不同字串赋予连续的indices。每列的 indices 彼此不相关;
      • 调用到 indexSortedByFreq(data, startIndex, ignoreNull, true),作用是给各个列的不同字串赋予连续的indices,indices是按照字符串出现的频率排序;
        • 调用到 countTokens的作用是按照 "列idx","word" 来合并计算单词个数,得到<"列idx","word",单词个数>,比如第一列中,football这个单词的个数是3,则返回三元组是 <0,football,3>,其中列的idx从0开始计算。
          • 调用 flattenTokens 把输入数据 Row 给打散,返回 A DataSet of tuples of column index and token,即<"列idx","word">。比如对于 Row.of("football", "can") 这个输入,flattenTokens 输出两个Tuple2 ,<0, "football"> 和 <1, "can">。
          • 对上面结构进行map操作,输出<column idx, word, 1L>,比如 <0, "football", 1L> ;
          • 按照 "列idx","word" 来分组;
          • 按照 "列idx","word" 来合并计算单词个数;
        • indexSortedByFreq会对countTokens返回的结果<"列idx","word",词频>处理;
          • 首先按照 列idx 做分组;
          • 然后在上面结果基础上,按照单词个数排序;
          • 排序的index是以输入参数startIndex开始,startIndex在这里是0;
          • 最后得到 第一列的 (0,football,0),(0,basketball,1),(0,football,2);第二列的数据 (1,hhh,0),(1,zzz,1),(1,can,2);
    • 把indexTokens的结果存储为模型,其中使用之前提到的 "把列名,列类型作为元数据"。

    下面具体剖析后两个阶段。

    0x04 Add Index to Token

    这部分就是给各个列的不同字串赋予连续的indices。每列的 indices 彼此不相关。

    具体是由StringIndexerUtil.indexTokens 做到的。

    public static DataSet<Tuple3<Integer, String, Long>> indexTokens(
        DataSet<Row> data, HasStringOrderTypeDefaultAsRandom.StringOrderType orderType,
        final long startIndex, final boolean ignoreNull) {
        		case FREQUENCY_ASC:
                    return indexSortedByFreq(data, startIndex, ignoreNull, true);
    }
    

    4.1 合并计算单词个数

    indexSortedByFreq会调用countTokens来计算单词个数,所以我们先看countTokens。

    countTokens的作用是按照 "列idx","word" 来合并计算单词个数,比如第一列中,football这个单词的个数是3,则返回三元组是 <0,football,3>,其中列的idx从0开始计算。

    具体逻辑如下:

    • 调用 flattenTokens 把输入数据 Row 给打散,返回 A DataSet of tuples of column index and token,即<"列idx","word">。比如对于 Row.of("football", "can") 这个输入,flattenTokens 输出两个Tuple2 ,<0, "football"> 和 <1, "can">。
    • 对上面结果进行map操作,输出<column idx, word, 1L>,比如 <0, "football", 1L> ,这个是计数的常规操作。
    • 按照 "列idx","word" 来分组;
    • 按照 "列idx","word" 来合并计算单词个数,就是不停归并上面的 1L。

    4.1.1 打散输入数据

    其中 flattenTokens 的作用是把输入数据 Row 给打散,返回 A DataSet of tuples of column index and token.。

    比如对于 Row.of("football", "can") 这个输入,flattenTokens 使用 out.collect(Tuple2.of(i, String.valueOf(o))); 输出两个Tuple2。

    value = {Row@9212} "football,can"
     fields = {Object[2]@9215} 
      0 = "football"
      1 = "can"
      
    输出 <0, "football"> 和 <1, "can">
    

    4.1.2 分组计算个数

    这是通过flattenTokens的结果进行 map,groupBy,reduce的一系列操作完成的。

    具体代码如下:

    public static DataSet<Tuple3<Integer, String, Long>> countTokens(DataSet<Row> data, final boolean ignoreNull) {
        return flattenTokens(data, ignoreNull) // 把输入数据 Row 给打散
            .map(new MapFunction<Tuple2<Integer, String>, Tuple3<Integer, String, Long>>() {
                @Override
                public Tuple3<Integer, String, Long> map(Tuple2<Integer, String> value) throws Exception {
                    return Tuple3.of(value.f0, value.f1, 1L); // 输出<column idx, word, 1L>,比如 <0, "football", 1L> 
                }
            })
            .groupBy(0, 1) // 按照 "列idx","word" 来分组
            .reduce(new ReduceFunction<Tuple3<Integer, String, Long>>() {
                @Override
                public Tuple3<Integer, String, Long> reduce(Tuple3<Integer, String, Long> value1, Tuple3<Integer, String, Long> value2) throws Exception {
                    value1.f2 += value2.f2;
                    return value1; // 按照 "列idx","word" 来合并计算单词个数
                }
            })
            .name("count_tokens");
    }
    
    // reduce之后发出
    value1 = {Tuple3@9284} "(0,football,3)"
     f0 = {Integer@9226} 0
     f1 = "football"
     f2 = {Long@9295} 3
    

    4.2 合并计算单词个数

    前面 countTokens的 返回三元组是 <列idx","word" ,词频>,其中列的idx从0开始计算。

    indexSortedByFreq会对countTokens返回的结果<"列idx","word",词频>处理;

    • 首先按照 列idx 做分组;
    • 然后在上面结果基础上,按照单词个数排序;
    • 排序的index是以输入参数startIndex开始,startIndex在这里是0;
    • 最后得到 第一列的 (0,tennis,0),(0,basketball,1),(0,football,2);第二列的数据 (1,hhh,0),(1,zzz,1),(1,can,2);

    具体代码如下:

    public static DataSet<Tuple3<Integer, String, Long>> indexSortedByFreq(
        DataSet<Row> data, final long startIndex, final boolean ignoreNull, final boolean isAscending) {
        return countTokens(data, ignoreNull)
            .groupBy(0) //按照 列idx 做分组
            .sortGroup(2, isAscending ? Order.ASCENDING : Order.DESCENDING) //按照单词个数排序
            .reduceGroup(new GroupReduceFunction<Tuple3<Integer, String, Long>, Tuple3<Integer, String, Long>>() {
                @Override
                public void reduce(Iterable<Tuple3<Integer, String, Long>> values,
                                   Collector<Tuple3<Integer, String, Long>> out) {
                    long id = startIndex;
                    for (Tuple3<Integer, String, Long> value : values) {
                        out.collect(Tuple3.of(value.f0, value.f1, id++)); // 归并
                    }
                }
            });
    }
    

    0x05 输出模型

    这部分分为两部分:

    • 输出元数据,就是之前得到的 "把列名,列类型作为元数据"。
    • 输出具体每一列的每一个单词信息,比如 第一列的 (0,tennis,0),(0,basketball,1),(0,football,2);第二列的数据 (1,hhh,0),(1,zzz,1),(1,can,2);
    public class MultiStringIndexerModelDataConverter implements
        ModelDataConverter<Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>>, MultiStringIndexerModelData> {
        @Override
        public void save(Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>> modelData, Collector<Row> collector) {
            if (modelData.f0 != null) {
                collector.collect(Row.of(-1L, modelData.f0.toJson(), null));
            }
            modelData.f1.forEach(tuple -> {
                collector.collect(Row.of(tuple.f0.longValue(), tuple.f1, tuple.f2));
            });
        }  
    }
    
    tuple = {Tuple3@9405} "(0,tennis,0)"
     f0 = {Integer@9406} 0
     f1 = "tennis"
     f2 = {Long@9408} 0
    

    0x06 预测

    预测功能是在 ModelMapperAdapter 完成的。

    public class ModelMapperAdapter extends RichMapFunction<Row, Row> implements Serializable {
        private final ModelMapper mapper;
        private final ModelSource modelSource;
    
        @Override
        public void open(Configuration parameters) throws Exception {
            List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext());
            this.mapper.loadModel(modelRows); //加载模型
        }
    
        @Override
        public Row map(Row row) throws Exception {
            return this.mapper.map(row); //预测
        }
    }
    

    6.1 加载模型

    MultiStringIndexerModelDataConverter中我们会进行模型加载。

    • 首先会加载元信息
    • 其次会逐条加载模型信息
    public MultiStringIndexerModelData load(List<Row> rows) {
        MultiStringIndexerModelData modelData = new MultiStringIndexerModelData();
        modelData.tokenAndIndex = new ArrayList<>();
        modelData.tokenNumber = new HashMap<>();
        for (Row row : rows) {
            long colIndex = (Long) row.getField(0);
            if (colIndex < 0L) { // 元数据
                modelData.meta = Params.fromJson((String) row.getField(1));
            } else { // 具体模型信息
                int columnIndex = ((Long) row.getField(0)).intValue();
                Long tokenIndex = Long.valueOf(String.valueOf(row.getField(2)));
                modelData.tokenAndIndex.add(Tuple3.of(columnIndex, (String) row.getField(1), tokenIndex));
                modelData.tokenNumber.merge(columnIndex, 1L, Long::sum); // 合并列数据个数
            }
        }
    
        // To ensure that every columns has token number.
        int numFields = 0;
        if (modelData.meta != null) {
            numFields = modelData.meta.get(HasSelectedCols.SELECTED_COLS).length;
        }
        for (int i = 0; i < numFields; i++) {
            modelData.tokenNumber.merge(i, 0L, Long::sum);
        }
        return modelData;
    }
    

    最后模型内容如下,其中 tokenNumber 表示每列的数据有几个,tokenAndIndex表示具体信息,比如(0,tennis,0),(0,basketball,1),(0,football,2) 就表示他们都是第一列的,basketball转换后的数据是 1:

    modelData = {MultiStringIndexerModelData@9348} 
     meta = {Params@9440} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
     tokenAndIndex = {ArrayList@9360}  size = 6
      0 = {Tuple3@9472} "(0,football,2)"
      1 = {Tuple3@9511} "(0,tennis,0)"
      2 = {Tuple3@9512} "(1,zzz,1)"
      3 = {Tuple3@9513} "(1,hhh,0)"
      4 = {Tuple3@9514} "(0,basketball,1)"
      5 = {Tuple3@9515} "(1,can,2)"
     tokenNumber = {HashMap@9385}  size = 2
      {Integer@9507} 0 -> {Long@9508} 3
      {Integer@9509} 1 -> {Long@9508} 3
    numFields = 2
    

    6.2 预测

    预测是在 MultiStringIndexerModelMapper 完成的。

    // 假设输入是:row = {Row@9309} "football,can"
    // 选择的列是:selectedColNames = {String[2]@9314}  0 = "a" 1 = "b"
    // 模型映射器是:
    this = {MultiStringIndexerModelMapper@9309} 
     indexMapper = {HashMap@9318}  size = 2
      {Integer@9357} 0 -> {HashMap@9314}  size = 3
       key = {Integer@9357} 0
        value = 0
       value = {HashMap@9314}  size = 3
        "basketball" -> {Long@9386} 1
        "football" -> {Long@9332} 2
        "tennis" -> {Long@9384} 0
      {Integer@9352} 1 -> {HashMap@9358}  size = 3
       key = {Integer@9352} 1
        value = 1
       value = {HashMap@9358}  size = 3
        "can" -> {Long@9332} 2
        "hhh" -> {Long@9384} 0
        "zzz" -> {Long@9386} 1
    

    则经历过下列代码,最后就可以进行预测

    public Row map(Row row) throws Exception {
        Row result = new Row(selectedColNames.length);
        for (int i = 0; i < selectedColNames.length; i++) {
            Map<String, Long> mapper = indexMapper.get(i);
            int colIdxInData = selectedColIndicesInData[i];
            Object val = row.getField(colIdxInData);
            String key = val == null ? null : String.valueOf(val);
            Long index = mapper.get(key);
            if (index != null) {
                result.setField(i, index); // 我们主要执行在这里
            } else {
            }
        }
      
    // 最后预测结果是:
    row = {Row@9308} "football,can"
    result = {Row@9313} "2,2"
        
        return outputColsHelper.getResultRow(row, result);
    }
    

    0xFF 参考

    Spark之特征预处理

  • 相关阅读:
    elk系列1之入门安装与基本操作【转】
    elk系列3之通过json格式采集Nginx日志【转】
    mysql开启GTID跳过错误的方法【转】
    curl: (6) Couldn’t resolve host ‘www.ttlsa.com’【转】
    离线下载pip包进行安装【转】
    初学Memcached安装及使用【转】
    http 错误代码解释 && nginx 自定义错误【转】
    有关mysql的innodb_flush_log_at_trx_commit参数【转】
    mysqldump 逻辑备份的正确方法【转】
    谁说运维用ELK没用?我就说很有用,只是你之前不会用【转】
  • 原文地址:https://www.cnblogs.com/rossiXYZ/p/13429876.html
Copyright © 2011-2022 走看看