  • Developing an Elasticsearch plugin: computing feature-vector similarity

    Customizing the Elasticsearch score

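      In some cases we need to customize the score in order to personalize search results. For example, machine learning can give us a feature vector for each user and a feature vector for each product. How do we compute the similarity between these two vectors? The more similar the two vectors are, the higher the score, so products that closely match a user are recommended to that user first.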
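
      The similarity used by the scripts in this post is the cosine of the angle between the two vectors. The following is a minimal sketch in plain Java, assuming each vector is a map from dimension name to weight and that a dimension missing from one vector counts as 0. The class and method names are illustrative and are not part of the plugin.

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative only: names are hypothetical and not taken from the plugin.
final class CosineSimilaritySketch {

    // cos(u, p) = (u . p) / (|u| * |p|); returns 0 when either vector has zero length.
    static double cosine(Map<String, Double> user, Map<String, Double> product) {
        double dot = 0.0;
        double userNormSq = 0.0;
        double productNormSq = 0.0;
        for (Map.Entry<String, Double> entry : user.entrySet()) {
            double u = entry.getValue();
            userNormSq += u * u;
            Double p = product.get(entry.getKey());
            if (p != null) {
                dot += u * p; // only dimensions present in both vectors contribute
            }
        }
        for (double p : product.values()) {
            productNormSq += p * p;
        }
        if (userNormSq == 0.0 || productNormSq == 0.0) {
            return 0.0;
        }
        return dot / (Math.sqrt(userNormSq) * Math.sqrt(productNormSq));
    }

    public static void main(String[] args) {
        Map<String, Double> user = new HashMap<>();
        user.put("price", 1.0);
        user.put("brand", 2.0);

        Map<String, Double> similarProduct = new HashMap<>();
        similarProduct.put("price", 2.0);
        similarProduct.put("brand", 4.0);

        Map<String, Double> unrelatedProduct = new HashMap<>();
        unrelatedProduct.put("color", 3.0);

        // Vectors pointing the same way score close to 1; vectors with no shared dimension score 0.
        System.out.println(cosine(user, similarProduct));
        System.out.println(cosine(user, unrelatedProduct));
    }
}
```

      Because the cosine depends only on direction, scaling a user vector or a product vector does not change the ranking.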

    Reading the plugin source

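      According to the official documentation, a script must be run through a ScriptEngine. To develop a custom plugin, we implement the ScriptEngine interface and return our engine from the getScriptEngine() method of ScriptPlugin so that Elasticsearch loads it. The ScriptEngine interface is described in reference [1]. Below is a concrete example from the official documentation: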

      private static class MyExpertScriptEngine implements ScriptEngine {
      // The name used in the script API to refer to this script backend.
        @Override
        public String getType() {
            return "expert_scripts";
        }
    

     

      // Core method; the factory below is implemented with a Java lambda expression
        @Override
        public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
            if (context.equals(SearchScript.CONTEXT) == false) {
                throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
            }
            // we use the script "source" as the script identifier
            if ("pure_df".equals(scriptSource)) {
            // p gives access to the values in params; lookup gives access to the document's values
                SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                    final String field;
                    final String term;
                    {
                        if (p.containsKey("field") == false) {
                            throw new IllegalArgumentException("Missing parameter [field]");
                        }
                        if (p.containsKey("term") == false) {
                            throw new IllegalArgumentException("Missing parameter [term]");
                        }
                        field = p.get("field").toString();
                        term = p.get("term").toString();
                    }
    
                    @Override
                    public SearchScript newInstance(LeafReaderContext context) throws IOException {
                        PostingsEnum postings = context.reader().postings(new Term(field, term));
                        if (postings == null) {
                            // the field and/or term don't exist in this segment, so always return 0
                            return new SearchScript(p, lookup, context) {
                                @Override
                                public double runAsDouble() {
                                    return 0.0d;
                                }
                            };
                        }
                        return new SearchScript(p, lookup, context) {
                            int currentDocid = -1;
                            @Override
                            public void setDocument(int docid) {
                                // advance has undefined behavior calling with a docid <= its current docid
                                if (postings.docID() < docid) {
                                    try {
                                        postings.advance(docid);
                                    } catch (IOException e) {
                                        throw new UncheckedIOException(e);
                                    }
                                }
                                currentDocid = docid;
                            }
                            @Override
                            public double runAsDouble() {
                                if (postings.docID() != currentDocid) {
                                    // advance moved past the current doc, so this doc has no occurrences of the term
                                    return 0.0d;
                                }
                                try {
                                    return postings.freq();
                                } catch (IOException e) {
                                    throw new UncheckedIOException(e);
                                }
                            }
                        };
                    }
    
                    @Override
                    public boolean needs_score() {
                        return false;
                    }
                };
                return context.factoryClazz.cast(factory);
            }
            throw new IllegalArgumentException("Unknown script name " + scriptSource);
        }
    
        @Override
        public void close() {
            // optionally close resources
        }
    }
    

    Based on the code above and our business requirements, we wrote the following scripts:

    Script 1

        package com;
        
        import org.apache.logging.log4j.LogManager;
        import org.apache.logging.log4j.Logger;
        import org.apache.lucene.index.LeafReaderContext;
        import org.elasticsearch.script.ScriptContext;
        import org.elasticsearch.script.ScriptEngine;
        import org.elasticsearch.script.SearchScript;
        
        import java.io.IOException;
        import java.util.*;
        
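        /**
         * Created with IntelliJ IDEA.
         * User: 0.0
         * Date: 18-8-9
         * Time: 2:32 PM
         * Description: To personalize recommendation search results, compute the similarity between
         *          the user vector and each product's feature vector. The higher the similarity,
         *          the higher the final score and the higher the product ranks.
         */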
    
        public class FeatureVectorScoreSearchScript implements ScriptEngine {
            private final static Logger logger = LogManager.getLogger(FeatureVectorScoreSearchScript.class);
            @Override
            public String getType() {
                return "feature_vector_scoring_script";
            }
        @Override
        public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
            logger.info("The feature_vector_scoring_script is calculating the similarity of users and commodities");
            if (!context.equals(SearchScript.CONTEXT)) {
                throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
            }
            if("whb_fvs".equals(scriptSource)) {
                SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                    // Validate the input parameters
                    final Map<String, Object> inputFeatureVector;
                    final String field;
                    {
                        if (p.containsKey("field") == false) {
                            throw new IllegalArgumentException("Missing parameter [field]");
                        }
                        if(p.containsKey("inputFeatureVector") == false){
                            throw new IllegalArgumentException("Missing parameter [inputFeatureVector]");
                        }
                        field = p.get("field").toString();
                        inputFeatureVector = (Map<String,Object>) p.get("inputFeatureVector");
    
                    }
                    @Override
                    public SearchScript newInstance(LeafReaderContext context) throws IOException {
                        return new SearchScript(p, lookup, context) {
                            @Override
                            public double runAsDouble() {
                                final Object value = lookup.source().get(field);
                                if (value instanceof Map) {
                                    final Map<String, Double> productFeatureVector = (Map<String, Double>) value;
                                    return calculateVectorSimilarity(inputFeatureVector, productFeatureVector);
                                } else {
                                    // Logged at debug level because this runs once per scored document
                                    logger.debug("Field [{}] does not exist in the document", field);
                                    return 0.0D;
                                }
                            }
                        };
                    }
    
                    @Override
                    public boolean needs_score() {
                        return false;
                    }
                };
                return context.factoryClazz.cast(factory);
            }
            throw new IllegalArgumentException("Unknown script name " + scriptSource);
    
        }
    
        @Override
        public void close() {
        }
    
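        // Cosine similarity between the query (user) vector and the product vector,
        // computed over the dimensions of the query vector.
        public double calculateVectorSimilarity(Map<String, Object> inputFeatureVector, Map<String, ?> productFeatureVector) {
            double dotProduct = 0.0D;
            double userNormSq = 0.0D;
            double productNormSq = 0.0D;
            if (inputFeatureVector == null || productFeatureVector == null) {
                return 0.0D;
            }
            for (Map.Entry<String, Object> entry : inputFeatureVector.entrySet()) {
                String dimName = entry.getKey();
                double dimScore = Double.parseDouble(entry.getValue().toString());
                // Values parsed from JSON may be Integer, Long, or Double; a missing dimension counts as 0
                // (the original unboxed the result of get() directly and failed on a missing key)
                Object itemValue = productFeatureVector.get(dimName);
                double itemDimScore = itemValue instanceof Number ? ((Number) itemValue).doubleValue() : 0.0D;
                userNormSq += dimScore * dimScore;
                productNormSq += itemDimScore * itemDimScore;
                dotProduct += dimScore * itemDimScore;
            }
            if (userNormSq == 0.0D || productNormSq == 0.0D) {
                return 0.0D;
            }
            return dotProduct / (Math.sqrt(userNormSq) * Math.sqrt(productNormSq));
        }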
    
        }
    
    

    Script 2 (fast-vector-distance)

    
    package com;

    import org.apache.logging.log4j.LogManager;
    import org.apache.logging.log4j.Logger;
    import org.apache.lucene.index.LeafReaderContext;
    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.plugins.Plugin;
    import org.elasticsearch.plugins.ScriptPlugin;
    import org.elasticsearch.script.ScriptContext;
    import org.elasticsearch.script.ScriptEngine;
    import org.elasticsearch.script.SearchScript;
    import org.apache.lucene.index.BinaryDocValues;
    import org.apache.lucene.store.ByteArrayDataInput;
    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.nio.DoubleBuffer;
    import java.util.*;

    /**
     * Created with IntelliJ IDEA.
     * User: 王火斌
     * Date: 18-8-9
     * Time: 2:32 PM
     * Description: To personalize recommendation search results, compute the similarity between
     *          the user vector and each product's feature vector. The higher the similarity,
     *          the higher the final score and the higher the product ranks.
     *
     * This class is instantiated when Elasticsearch loads the plugin for the
     * first time. If you change the name of this plugin, make sure to update
     * the plugin-descriptor.properties file under src/main/resources, which points to this class.
     */
    public final class FastVectorDistance extends Plugin implements ScriptPlugin {
    
        @Override
        public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
            return new FastVectorDistanceEngine();
        }
    
        private static class FastVectorDistanceEngine implements ScriptEngine {
            private final static Logger logger = LogManager.getLogger(FastVectorDistance.class);
            private static final int DOUBLE_SIZE = 8;
    
            @Override
            public String getType() {
                return "feature_vector_scoring_script";
            }
    
            @Override
            public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
                logger.info("The feature_vector_scoring_script is calculating the similarity of users and commodities");
                if (!context.equals(SearchScript.CONTEXT)) {
                    throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
                }
                if ("whb_fvd".equals(scriptSource)) {
                    SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                        // The field to compare against
                        final String field;
                        // Whether this search should be cosine or dot product
                        final Boolean cosine;
                        // The query vector as passed in the "vector" parameter (may be null)
                        final Object vector;
                        Boolean exclude;
                        // The query vector as a double array
                        double[] inputVector;
                        // Squared norm of the query vector. Kept per query here instead of on the
                        // engine, which is shared by all concurrent queries.
                        double queryVectorNorm;
    
                        {
                            if (p.containsKey("field") == false) {
                                throw new IllegalArgumentException("Missing parameter [field]");
                            }
    
                            //Determine if cosine
                            final Object cosineBool = p.get("cosine");
                            cosine = cosineBool != null ? (boolean) cosineBool : true;
    
                            //Get the field value from the query
                            field = p.get("field").toString();
    
                            final Object excludeBool = p.get("exclude");
                            exclude = excludeBool != null ? (boolean) excludeBool : true;
    
                            //Get the query vector embedding
                            vector = p.get("vector");
    
                            //Determine whether a raw list of doubles ("vector") or a Base64 string ("encoded_vector") was passed
                            if (vector != null) {
                                final ArrayList<Double> tmp = (ArrayList<Double>) vector;
                                inputVector = new double[tmp.size()];
                                for (int i = 0; i < inputVector.length; i++) {
                                    inputVector[i] = tmp.get(i);
                                }
                            } else {
                                final Object encodedVector = p.get("encoded_vector");
                                if (encodedVector == null) {
                                    throw new IllegalArgumentException("Must have 'vector' or 'encoded_vector' as a parameter");
                                }
                                inputVector = Util.convertBase64ToArray((String) encodedVector);
                            }
    
                            //If cosine calculate the query vec norm
                            if (cosine) {
                                queryVectorNorm = 0d;
                                // compute query inputVector norm once
                                for (double v : inputVector) {
                                    queryVectorNorm += Math.pow(v, 2.0);
                                }
                            }
                        }
    
                        @Override
                        public SearchScript newInstance(LeafReaderContext context) throws IOException {
    
                            return new SearchScript(p, lookup, context) {
                                Boolean is_value = false;
    
                                // Use the Lucene LeafReaderContext to access binary doc values directly.
                                // This is null when no document in the segment has the field.
                                BinaryDocValues accessor = context.reader().getBinaryDocValues(field);
    
                                @Override
                                public void setDocument(int docId) {
                                    try {
                                        // advanceExact returns false when this document has no value for the field
                                        is_value = accessor != null && accessor.advanceExact(docId);
                                    } catch (IOException e) {
                                        is_value = false;
                                    }
                                }
    
    
                                @Override
                                public double runAsDouble() {
    
                                    //If there is no field value return 0 rather than fail.
                                    if (!is_value) return 0.0d;
    
                                    final int inputVectorSize = inputVector.length;
                                    final byte[] bytes;
    
                                    try {
                                        bytes = accessor.binaryValue().bytes;
                                    } catch (IOException e) {
                                        return 0d;
                                    }
    
    
                                    final ByteArrayDataInput byteDocVector = new ByteArrayDataInput(bytes);
    
                                    // The first VInt is read and discarded (presumably the number of values stored for the document)
                                    byteDocVector.readVInt();
    
                                    final int docVectorLength = byteDocVector.readVInt(); // returns the number of bytes to read
    
                                    if (docVectorLength != inputVectorSize * DOUBLE_SIZE) {
                                        return 0d;
                                    }
    
                                    final int position = byteDocVector.getPosition();
    
                                    final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, docVectorLength).asDoubleBuffer();
    
                                    final double[] docVector = new double[inputVectorSize];
    
                                    doubleBuffer.get(docVector);
    
                                    double docVectorNorm = 0d;
                                    double score = 0d;
    
                                    //calculate dot product of document vector and query vector
                                    for (int i = 0; i < inputVectorSize; i++) {
    
                                        score += docVector[i] * inputVector[i];
    
                                        if (cosine) {
                                            docVectorNorm += Math.pow(docVector[i], 2.0);
                                        }
                                    }
    
                                    //If cosine, calculate the cosine score
                                    if (cosine) {
    
                                        if (docVectorNorm == 0 || queryVectorNorm == 0) return 0d;
    
                                        score = score / (Math.sqrt(docVectorNorm) * Math.sqrt(queryVectorNorm));
                                    }
    
                                    return score;
                                }
                            };
                        }
    
                        @Override
                        public boolean needs_score() {
                            return false;
                        }
                    };
                    return context.factoryClazz.cast(factory);
                }
                throw new IllegalArgumentException("Unknown script name " + scriptSource);
            }
    
            @Override
            public void close() {}
        }
    }
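
    Script 2 reads each document vector from a byte array laid out as two variable-length integers followed by big-endian doubles. The sketch below reproduces that layout in plain Java so the parsing steps in runAsDouble() can be tried without a cluster. It assumes the first VInt is the number of values stored for the document (the script discards it) and the second is the payload length in bytes. The class and method names are hypothetical.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;

// Illustrative only: mimics the layout that runAsDouble() appears to expect,
// [VInt value count][VInt byte length][big-endian doubles].
final class DocVectorLayoutSketch {

    // Writes an int in variable-length form: 7 bits per byte, low bits first,
    // with the high bit set on every byte except the last.
    static int writeVInt(byte[] out, int pos, int value) {
        while ((value & ~0x7F) != 0) {
            out[pos++] = (byte) ((value & 0x7F) | 0x80);
            value >>>= 7;
        }
        out[pos++] = (byte) value;
        return pos;
    }

    // Reads a variable-length int and advances cursor[0] past it.
    static int readVInt(byte[] in, int[] cursor) {
        int value = 0;
        int shift = 0;
        byte b;
        do {
            b = in[cursor[0]++];
            value |= (b & 0x7F) << shift;
            shift += 7;
        } while ((b & 0x80) != 0);
        return value;
    }

    static byte[] encode(double[] vector) {
        int payloadLength = vector.length * Double.BYTES;
        byte[] buffer = new byte[10 + payloadLength]; // a VInt takes at most 5 bytes
        int pos = writeVInt(buffer, 0, 1);            // assumed: one value per document
        pos = writeVInt(buffer, pos, payloadLength);
        ByteBuffer.wrap(buffer, pos, payloadLength).asDoubleBuffer().put(vector);
        return Arrays.copyOf(buffer, pos + payloadLength);
    }

    static double[] decode(byte[] bytes) {
        int[] cursor = {0};
        readVInt(bytes, cursor);              // skip the value count, as the script does
        int length = readVInt(bytes, cursor); // payload length in bytes
        double[] vector = new double[length / Double.BYTES];
        ByteBuffer.wrap(bytes, cursor[0], length).asDoubleBuffer().get(vector);
        return vector;
    }

    public static void main(String[] args) {
        double[] original = {0.5, -1.25, 3.0};
        double[] decoded = decode(encode(original));
        System.out.println(Arrays.equals(original, decoded));
    }
}
```

    The length check in the script (payload length must equal the query vector length times 8) corresponds to the second VInt here, so a document whose vector has a different dimension than the query scores 0 instead of failing.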
    
    

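    Deployment

    The plugin is built and packaged with Maven. The steps are as follows:

    1. Configure the pom file
      Declare the dependencies and set the output directory for the build.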


      <modelVersion>4.0.0</modelVersion>
      <groupId>es-plugin</groupId>
      <artifactId>elasticsearch-plugin</artifactId>
      <version>1.0-SNAPSHOT</version>

       <dependencies>
           <dependency>
               <groupId>org.elasticsearch</groupId>
               <artifactId>elasticsearch</artifactId>
               <version>6.1.1</version>
           </dependency>
           <dependency>
               <groupId>junit</groupId>
               <artifactId>junit</artifactId>
               <version>4.12</version>
               <scope>test</scope>
           </dependency>
       </dependencies>
       <build>
           <plugins>
               <plugin>
                   <artifactId>maven-assembly-plugin</artifactId>
                   <version>2.3</version>
                   <configuration>
                       <appendAssemblyId>false</appendAssemblyId>
                       <outputDirectory>${project.build.directory}/releases/</outputDirectory>
                       <descriptors>
                           <descriptor>${basedir}/src/assembly/plugin.xml</descriptor>
                       </descriptors>
                   </configuration>
                   <executions>
                       <execution>
                           <phase>package</phase>
                           <goals>
                               <goal>single</goal>
                           </goals>
                       </execution>
                   </executions>
               </plugin>
               <plugin>
                   <groupId>org.apache.maven.plugins</groupId>
                   <artifactId>maven-compiler-plugin</artifactId>
                   <configuration>
                       <source>1.8</source>
                       <target>1.8</target>
                   </configuration>
               </plugin>
           </plugins>
       </build>
      

    2. Create the assembly XML file (src/assembly/plugin.xml, as referenced in the pom)

    <?xml version="1.0"?>
    <assembly>
        <id>plugin</id>
        <formats>
            <format>zip</format>
        </formats>
        <includeBaseDirectory>false</includeBaseDirectory>
        <fileSets>
            <fileSet>
                <directory>${project.basedir}/src/main/resources</directory>
                <outputDirectory>feature-vector-score</outputDirectory>
            </fileSet>
        </fileSets>
        <dependencySets>
            <dependencySet>
                <outputDirectory>feature-vector-score</outputDirectory>
                <useProjectArtifact>true</useProjectArtifact>
                <useTransitiveFiltering>true</useTransitiveFiltering>
                <excludes>
                    <exclude>org.elasticsearch:elasticsearch</exclude>
                    <exclude>org.apache.logging.log4j:log4j-api</exclude>
                </excludes>
            </dependencySet>
        </dependencySets>
    </assembly>
    

    3. Create the plugin-descriptor.properties file

    description=feature-vector-similarity
    version=1.0
    name=feature-vector-score
    site=${elasticsearch.plugin.site}
    jvm=true
    classname=com.FeatureVectorScoreSearchPlugin
    java.version=1.8
    elasticsearch.version=6.1.1
    isolated=${elasticsearch.plugin.isolated}
    

    description:simple summary of the plugin
    version(String):plugin’s version
    name(String):the plugin name
    classname(String):the name of the class to load, fully-qualified.
    java.version(String):version of java the code is built against. Use the system property java.specification.version. Version string must be a sequence of nonnegative decimal integers separated by "."'s and may have leading zeros.
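
    A missing descriptor key is easy to overlook until the plugin fails to install. The following is a small sketch in plain Java that reads descriptor text and reports which keys are absent or empty. The list of keys is an assumption: it is the set described above plus elasticsearch.version from the file shown earlier. The class and method names are hypothetical.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

// Illustrative only: a pre-packaging sanity check, not part of Elasticsearch.
final class DescriptorCheckSketch {

    // Assumed set of keys to check; adjust to the Elasticsearch version in use.
    static final String[] EXPECTED_KEYS = {
        "description", "version", "name", "classname", "java.version", "elasticsearch.version"
    };

    // Returns the expected keys that are missing or blank in the descriptor text.
    static List<String> missingKeys(String descriptorText) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(descriptorText));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        List<String> missing = new ArrayList<>();
        for (String key : EXPECTED_KEYS) {
            String value = props.getProperty(key);
            if (value == null || value.trim().isEmpty()) {
                missing.add(key);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        String descriptor = "description=feature-vector-similarity\n"
                + "version=1.0\n"
                + "name=feature-vector-score\n";
        // Prints the keys this partial descriptor still lacks.
        System.out.println(missingKeys(descriptor));
    }
}
```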

    Testing

    Create the index

    create_index = {
        "settings": {
            "analysis": {
                "analyzer": {
                    # this configures the custom analyzer we need to parse vectors such that the scoring
                    # plugin will work correctly
                    "payload_analyzer": {
                        "type": "custom",
                        "tokenizer":"whitespace",
                        "filter":"delimited_payload_filter"
                    }
                }
            }
        },
        "mappings": {
               "movies": {
                # this mapping definition sets up the metadata fields for the movies
                "properties": {
                    "movieId": {
                        "type": "integer"
                    },
                    "tmdbId": {
                        "type": "keyword"
                    },
                    "genres": {
                        "type": "keyword"
                    },
                    "release_date": {
                        "type": "date",
                        "format": "year"
                    },
                    "@model": {
                        # this mapping definition sets up the fields for movie factor vectors of our model
                        "properties": {
                            "factor": {
                                "type": "binary",
                                "doc_values": True
                            },
                            "version": {
                                "type": "keyword"
                            },
                            "timestamp": {
                                "type": "date"
                            }
                        }
                    }
                }
            }
        }
    }
    

    Query

    You can execute the script by specifying its lang as feature_vector_scoring_script and the name of the script (whb_fvd) as the script source:

    {
      "query": {
         
         "function_score": {
          "query": {
            "match_all": {  
            }
          },
            "functions": [
              {
                "script_score": {
                  "script": {
                      "source": "whb_fvd",
                      "lang" : "feature_vector_scoring_script",
                      "params": {
                          "field": "@model.factor",
                          "cosine": true,
                          "encoded_vector" :"v9EUmGAAAAC/6f9VAAAAAL/j+OOgAAAAv+m6+oAAAAA/lTSDIAAAAL/FdkTAAAAAv7rKHKAAAAA/0iyEYAAAAD/ZUY6gAAAAP7TzYoAAAAA/1K4IAAAAAD+yH9XgAAAAv6QRBSAAAAA/vRiiwAAAAL/mRhzgAAAAv9WxpiAAAAC/8YD+QAAAAL/jpbtgAAAAv+zmD+AAAAC/1eqtIAAAAA==" 
                      }
                  }
                }
              }
            ]
        }
      }
    }
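
    The encoded_vector parameter in the query is a Base64 string. Script 2 decodes it with Util.convertBase64ToArray, which is not shown in this post. The sketch below is an assumption about what such a helper might do, chosen to match the big-endian double layout that Script 2 uses when it reads document vectors. The class and method names are hypothetical.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Base64;

// Illustrative only: an assumed equivalent of the Util helper, not the author's code.
final class EncodedVectorSketch {

    // Packs the doubles as big-endian 8-byte values (ByteBuffer's default order) and Base64-encodes them.
    static String encode(double[] vector) {
        ByteBuffer buffer = ByteBuffer.allocate(vector.length * Double.BYTES);
        buffer.asDoubleBuffer().put(vector);
        return Base64.getEncoder().encodeToString(buffer.array());
    }

    // Reverses encode(): Base64-decodes and reads big-endian doubles.
    static double[] decode(String base64) {
        byte[] bytes = Base64.getDecoder().decode(base64);
        double[] vector = new double[bytes.length / Double.BYTES];
        ByteBuffer.wrap(bytes).asDoubleBuffer().get(vector);
        return vector;
    }

    public static void main(String[] args) {
        double[] queryVector = {0.25, -0.5, 1.0};
        String encoded = encode(queryVector);
        System.out.println(encoded);
        System.out.println(Arrays.equals(queryVector, decode(encoded)));
    }
}
```

    The client and the plugin must agree on byte order and element width. If the vectors were indexed as 4-byte floats or in little-endian order, both sides of this sketch would need to change accordingly.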
    

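    Version notes

    Elasticsearch has released new versions quickly over the past year. The plugin above relies mainly on the SearchScript class, which applies to v5.4 through v6.4. Versions before 5.4 mainly use the ExecutableScript class. Versions after 6.4 introduce a new class, ScoreScript, for implementing custom scoring scripts.

    The full project is on GitHub: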

    https://github.com/SnailWhb/elasticsearch_pulgine_fast-vector-distance

    References

    [1]https://static.javadoc.io/org.elasticsearch/elasticsearch/6.0.1/org/elasticsearch/script/ScriptEngine.html
    [2]https://www.elastic.co/guide/en/elasticsearch/reference/current/modules-scripting-engine.html
    [3]https://github.com/jiashiwen/elasticsearchpluginsample
    [4]https://www.elastic.co/guide/en/elasticsearch/plugins/6.3/plugin-authors.html

  • Original article: https://www.cnblogs.com/whb-20160329/p/10472717.html