zoukankan      html  css  js  c++  java
  • Elasticsearch 工具类封装（基于 ElasticsearchTemplate）

    1. 抽象类定义

      1 public abstract class SearchQueryEngine<T> {
      2 
      3     @Autowired
      4     protected ElasticsearchTemplate elasticsearchTemplate;
      5 
      6     public abstract int saveOrUpdate(List<T> list);
      7 
      8     public abstract <R> List<R> aggregation(T query, Class<R> clazz);
      9 
     10     public abstract <R> Page<R> scroll(T query, Class<R> clazz, Pageable pageable, ScrollId scrollId);
     11 
     12     public abstract <R> List<R> find(T query, Class<R> clazz, int size);
     13 
     14     public abstract <R> Page<R> find(T query, Class<R> clazz, Pageable pageable);
     15 
     16     public abstract <R> R sum(T query, Class<R> clazz);
     17 
     18     protected Document getDocument(T t) {
     19         Document annotation = t.getClass().getAnnotation(Document.class);
     20         if (annotation == null) {
     21             throw new SearchQueryBuildException("Can't find annotation @Document on " + t.getClass().getName());
     22         }
     23         return annotation;
     24     }
     25 
     26     /**
     27      * 获取字段名,若设置column则返回该值
     28      *
     29      * @param field
     30      * @param column
     31      * @return
     32      */
     33     protected String getFieldName(Field field, String column) {
     34         return StringUtils.isNotBlank(column) ? column : field.getName();
     35     }
     36 
     37     /**
     38      * 设置属性值
     39      *
     40      * @param field
     41      * @param obj
     42      * @param value
     43      */
     44     protected void setFieldValue(Field field, Object obj, Object value) {
     45         boolean isAccessible = field.isAccessible();
     46         field.setAccessible(true);
     47         try {
     48             switch (field.getType().getSimpleName()) {
     49                 case "BigDecimal":
     50                     field.set(obj, new BigDecimal(value.toString()).setScale(5, BigDecimal.ROUND_HALF_UP));
     51                     break;
     52                 case "Long":
     53                     field.set(obj, new Long(value.toString()));
     54                     break;
     55                 case "Integer":
     56                     field.set(obj, new Integer(value.toString()));
     57                     break;
     58                 case "Date":
     59                     field.set(obj, new Date(Long.valueOf(value.toString())));
     60                     break;
     61                 default:
     62                     field.set(obj, value);
     63             }
     64         } catch (IllegalAccessException e) {
     65             throw new SearchQueryBuildException(e);
     66         } finally {
     67             field.setAccessible(isAccessible);
     68         }
     69     }
     70 
     71     /**
     72      * 获取字段值
     73      *
     74      * @param field
     75      * @param obj
     76      * @return
     77      */
     78     protected Object getFieldValue(Field field, Object obj) {
     79         boolean isAccessible = field.isAccessible();
     80         field.setAccessible(true);
     81         try {
     82             return field.get(obj);
     83         } catch (IllegalAccessException e) {
     84             throw new SearchQueryBuildException(e);
     85         } finally {
     86             field.setAccessible(isAccessible);
     87         }
     88     }
     89 
     90     /**
     91      * 转换为es识别的value值
     92      *
     93      * @param value
     94      * @return
     95      */
     96     protected Object formatValue(Object value) {
     97         if (value instanceof Date) {
     98             return ((Date) value).getTime();
     99         } else {
    100             return value;
    101         }
    102     }
    103 
    104     /**
    105      * 获取索引分区数
    106      *
    107      * @param t
    108      * @return
    109      */
    110     protected int getNumberOfShards(T t) {
    111         return Integer.parseInt(elasticsearchTemplate.getSetting(getDocument(t).index()).get(IndexMetaData.SETTING_NUMBER_OF_SHARDS).toString());
    112     }
    113 }

    2.接口实现

      1 @Component
      2 @ComponentScan
      3 public class SimpleSearchQueryEngine<T> extends SearchQueryEngine<T> {
      4 
      5     private int numberOfRowsPerScan = 10;
      6 
      7     @Override
      8     public int saveOrUpdate(List<T> list) {
      9         if (CollectionUtils.isEmpty(list)) {
     10             return 0;
     11         }
     12 
     13         T base = list.get(0);
     14         Field id = null;
     15         for (Field field : base.getClass().getDeclaredFields()) {
     16             BusinessID businessID = field.getAnnotation(BusinessID.class);
     17             if (businessID != null) {
     18                 id = field;
     19                 break;
     20             }
     21         }
     22         if (id == null) {
     23             throw new SearchQueryBuildException("Can't find @BusinessID on " + base.getClass().getName());
     24         }
     25 
     26         Document document = getDocument(base);
     27         List<UpdateQuery> bulkIndex = new ArrayList<>();
     28         for (T t : list) {
     29             UpdateQuery updateQuery = new UpdateQuery();
     30             updateQuery.setIndexName(document.index());
     31             updateQuery.setType(document.type());
     32             updateQuery.setId(getFieldValue(id, t).toString());
     33             updateQuery.setUpdateRequest(new UpdateRequest(updateQuery.getIndexName(), updateQuery.getType(), updateQuery.getId()).doc(JSONObject.toJSONString(t, SerializerFeature.WriteMapNullValue)));
     34             updateQuery.setDoUpsert(true);
     35             updateQuery.setClazz(t.getClass());
     36             bulkIndex.add(updateQuery);
     37         }
     38         elasticsearchTemplate.bulkUpdate(bulkIndex);
     39         return list.size();
     40     }
     41 
     42     @Override
     43     public <R> List<R> aggregation(T query, Class<R> clazz) {
     44         NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query);
     45         nativeSearchQueryBuilder.addAggregation(buildGroupBy(query));
     46         Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor());
     47         try {
     48             return transformList(null, aggregations, clazz.newInstance(), new ArrayList());
     49         } catch (Exception e) {
     50             throw new SearchResultBuildException(e);
     51         }
     52     }
     53 
     54     /**
     55      * 将Aggregations转为List
     56      *
     57      * @param terms
     58      * @param aggregations
     59      * @param baseObj
     60      * @param resultList
     61      * @param <R>
     62      * @return
     63      * @throws NoSuchFieldException
     64      * @throws IllegalAccessException
     65      * @throws InstantiationException
     66      */
     67     private <R> List<R> transformList(Aggregation terms, Aggregations aggregations, R baseObj, List<R> resultList) throws NoSuchFieldException, IllegalAccessException, InstantiationException {
     68         for (String column : aggregations.asMap().keySet()) {
     69             Aggregation childAggregation = aggregations.get(column);
     70             if (childAggregation instanceof InternalSum) {
     71                 // 使用@Sum
     72                 if (!(terms instanceof InternalSum)) {
     73                     R targetObj = (R) baseObj.getClass().newInstance();
     74                     BeanUtils.copyProperties(baseObj, targetObj);
     75                     resultList.add(targetObj);
     76                 }
     77                 setFieldValue(baseObj.getClass().getDeclaredField(column), resultList.get(resultList.size() - 1), ((InternalSum) childAggregation).getValue());
     78                 terms = childAggregation;
     79             } else {
     80                 Terms childTerms = (Terms) childAggregation;
     81                 for (Terms.Bucket bucket : childTerms.getBuckets()) {
     82                     if (CollectionUtils.isEmpty(bucket.getAggregations().asList())) {
     83                         // 未使用@Sum
     84                         R targetObj = (R) baseObj.getClass().newInstance();
     85                         BeanUtils.copyProperties(baseObj, targetObj);
     86                         setFieldValue(targetObj.getClass().getDeclaredField(column), targetObj, bucket.getKey());
     87                         resultList.add(targetObj);
     88                     } else {
     89                         setFieldValue(baseObj.getClass().getDeclaredField(column), baseObj, bucket.getKey());
     90                         transformList(childTerms, bucket.getAggregations(), baseObj, resultList);
     91                     }
     92                 }
     93             }
     94         }
     95         return resultList;
     96     }
     97 
     98     @Override
     99     public <R> Page<R> scroll(T query, Class<R> clazz, Pageable pageable, ScrollId scrollId) {
    100         if (pageable.getPageSize() % numberOfRowsPerScan > 0) {
    101             throw new SearchQueryBuildException("Page size must be an integral multiple of " + numberOfRowsPerScan);
    102         }
    103         SearchQuery searchQuery = buildNativeSearchQueryBuilder(query).withPageable(new PageRequest(pageable.getPageNumber(), numberOfRowsPerScan / getNumberOfShards(query), pageable.getSort())).build();
    104         if (StringUtils.isEmpty(scrollId.getValue())) {
    105             scrollId.setValue(elasticsearchTemplate.scan(searchQuery, 10000l, false));
    106         }
    107         Page<R> page = elasticsearchTemplate.scroll(scrollId.getValue(), 10000l, clazz);
    108         if (page == null || page.getContent().size() == 0) {
    109             elasticsearchTemplate.clearScroll(scrollId.getValue());
    110         }
    111         return page;
    112     }
    113 
    114     @Override
    115     public <R> List<R> find(T query, Class<R> clazz, int size) {
    116         // Caused by: QueryPhaseExecutionException[Result window is too large, from + size must be less than or equal to: [10000] but was [2147483647].
    117         // See the scroll api for a more efficient way to request large data sets. This limit can be set by changing the [index.max_result_window] index level parameter.]
    118         if (size % numberOfRowsPerScan > 0) {
    119             throw new SearchQueryBuildException("Parameter 'size' must be an integral multiple of " + numberOfRowsPerScan);
    120         }
    121         int pageNum = 0;
    122         List<R> result = new ArrayList<>();
    123         ScrollId scrollId = new ScrollId();
    124         while (true) {
    125             Page<R> page = scroll(query, clazz, new PageRequest(pageNum, numberOfRowsPerScan), scrollId);
    126             if (page != null && page.getContent().size() > 0) {
    127                 result.addAll(page.getContent());
    128             } else {
    129                 break;
    130             }
    131             if (result.size() >= size) {
    132                 break;
    133             } else {
    134                 pageNum++;
    135             }
    136         }
    137         elasticsearchTemplate.clearScroll(scrollId.getValue());
    138         return result;
    139     }
    140 
    141     @Override
    142     public <R> Page<R> find(T query, Class<R> clazz, Pageable pageable) {
    143         NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query).withPageable(pageable);
    144         return elasticsearchTemplate.queryForPage(nativeSearchQueryBuilder.build(), clazz);
    145     }
    146 
    147     @Override
    148     public <R> R sum(T query, Class<R> clazz) {
    149         NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query);
    150         for (SumBuilder sumBuilder : getSumBuilderList(query)) {
    151             nativeSearchQueryBuilder.addAggregation(sumBuilder);
    152         }
    153         Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor());
    154         try {
    155             return transformSumResult(aggregations, clazz);
    156         } catch (Exception e) {
    157             throw new SearchResultBuildException(e);
    158         }
    159     }
    160 
    161     private <R> R transformSumResult(Aggregations aggregations, Class<R> clazz) throws IllegalAccessException, InstantiationException, NoSuchFieldException {
    162         R targetObj = clazz.newInstance();
    163         for (Aggregation sum : aggregations.asList()) {
    164             if (sum instanceof InternalSum) {
    165                 setFieldValue(targetObj.getClass().getDeclaredField(sum.getName()), targetObj, ((InternalSum) sum).getValue());
    166             }
    167         }
    168         return targetObj;
    169     }
    170 
    171     private NativeSearchQueryBuilder buildNativeSearchQueryBuilder(T query) {
    172         Document document = getDocument(query);
    173         NativeSearchQueryBuilder nativeSearchQueryBuilder = new NativeSearchQueryBuilder()
    174                 .withIndices(document.index())
    175                 .withTypes(document.type());
    176 
    177         QueryBuilder whereBuilder = buildBoolQuery(query);
    178         if (whereBuilder != null) {
    179             nativeSearchQueryBuilder.withQuery(whereBuilder);
    180         }
    181 
    182         return nativeSearchQueryBuilder;
    183     }
    184 
    185     /**
    186      * 布尔查询构建
    187      *
    188      * @param query
    189      * @return
    190      */
    191     private BoolQueryBuilder buildBoolQuery(T query) {
    192         BoolQueryBuilder boolQueryBuilder = boolQuery();
    193         buildMatchQuery(boolQueryBuilder, query);
    194         buildRangeQuery(boolQueryBuilder, query);
    195         BoolQueryBuilder queryBuilder = boolQuery().must(boolQueryBuilder);
    196         return queryBuilder;
    197     }
    198 
    199     /**
    200      * and or 查询构建
    201      *
    202      * @param boolQueryBuilder
    203      * @param query
    204      */
    205     private void buildMatchQuery(BoolQueryBuilder boolQueryBuilder, T query) {
    206         Class clazz = query.getClass();
    207         for (Field field : clazz.getDeclaredFields()) {
    208             MatchQuery annotation = field.getAnnotation(MatchQuery.class);
    209             Object value = getFieldValue(field, query);
    210             if (annotation == null || value == null) {
    211                 continue;
    212             }
    213             if (Container.must.equals(annotation.container())) {
    214                 boolQueryBuilder.must(matchQuery(getFieldName(field, annotation.column()), formatValue(value)));
    215             } else if (should.equals(annotation.container())) {
    216                 if (value instanceof Collection) {
    217                     BoolQueryBuilder shouldQueryBuilder = boolQuery();
    218                     Collection tmp = (Collection) value;
    219                     for (Object obj : tmp) {
    220                         shouldQueryBuilder.should(matchQuery(getFieldName(field, annotation.column()), formatValue(obj)));
    221                     }
    222                     boolQueryBuilder.must(shouldQueryBuilder);
    223                 } else {
    224                     boolQueryBuilder.must(boolQuery().should(matchQuery(getFieldName(field, annotation.column()), formatValue(value))));
    225                 }
    226             }
    227         }
    228     }
    229 
    230     /**
    231      * 范围查询构建
    232      *
    233      * @param boolQueryBuilder
    234      * @param query
    235      */
    236     private void buildRangeQuery(BoolQueryBuilder boolQueryBuilder, T query) {
    237         Class clazz = query.getClass();
    238         for (Field field : clazz.getDeclaredFields()) {
    239             RangeQuery annotation = field.getAnnotation(RangeQuery.class);
    240             Object value = getFieldValue(field, query);
    241             if (annotation == null || value == null) {
    242                 continue;
    243             }
    244             if (Operator.gt.equals(annotation.operator())) {
    245                 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).gt(formatValue(value)));
    246             } else if (Operator.gte.equals(annotation.operator())) {
    247                 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).gte(formatValue(value)));
    248             } else if (Operator.lt.equals(annotation.operator())) {
    249                 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).lt(formatValue(value)));
    250             } else if (Operator.lte.equals(annotation.operator())) {
    251                 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).lte(formatValue(value)));
    252             }
    253         }
    254     }
    255 
    256     /**
    257      * Sum构建
    258      *
    259      * @param query
    260      * @return
    261      */
    262     private List<SumBuilder> getSumBuilderList(T query) {
    263         List<SumBuilder> list = new ArrayList<>();
    264         Class clazz = query.getClass();
    265         for (Field field : clazz.getDeclaredFields()) {
    266             Sum annotation = field.getAnnotation(Sum.class);
    267             if (annotation == null) {
    268                 continue;
    269             }
    270             list.add(AggregationBuilders.sum(field.getName()).field(field.getName()));
    271         }
    272         if (CollectionUtils.isEmpty(list)) {
    273             throw new SearchQueryBuildException("Can't find @Sum on " + clazz.getName());
    274         }
    275         return list;
    276     }
    277 
    278 
    279     /**
    280      * GroupBy构建
    281      *
    282      * @param query
    283      * @return
    284      */
    285     private TermsBuilder buildGroupBy(T query) {
    286         List<Field> sumList = new ArrayList<>();
    287         Object groupByCollection = null;
    288         Class clazz = query.getClass();
    289         for (Field field : clazz.getDeclaredFields()) {
    290             Sum sumAnnotation = field.getAnnotation(Sum.class);
    291             if (sumAnnotation != null) {
    292                 sumList.add(field);
    293             }
    294             GroupBy groupByannotation = field.getAnnotation(GroupBy.class);
    295             Object value = getFieldValue(field, query);
    296             if (groupByannotation == null || value == null) {
    297                 continue;
    298             } else if (!(value instanceof Collection)) {
    299                 throw new SearchQueryBuildException("GroupBy filed must be collection");
    300             } else if (CollectionUtils.isEmpty((Collection<String>) value)) {
    301                 continue;
    302             } else if (groupByCollection != null) {
    303                 throw new SearchQueryBuildException("Only one @GroupBy is allowed");
    304             } else {
    305                 groupByCollection = value;
    306             }
    307         }
    308         Iterator<String> iterator = ((Collection<String>) groupByCollection).iterator();
    309         TermsBuilder termsBuilder = recursiveAddAggregation(iterator, sumList);
    310         return termsBuilder;
    311     }
    312 
    313     /**
    314      * 添加Aggregation
    315      *
    316      * @param iterator
    317      * @return
    318      */
    319     private TermsBuilder recursiveAddAggregation(Iterator<String> iterator, List<Field> sumList) {
    320         String groupBy = iterator.next();
    321         TermsBuilder termsBuilder = AggregationBuilders.terms(groupBy).field(groupBy).size(0);
    322         if (iterator.hasNext()) {
    323             termsBuilder.subAggregation(recursiveAddAggregation(iterator, sumList));
    324         } else {
    325             for (Field field : sumList) {
    326                 termsBuilder.subAggregation(AggregationBuilders.sum(field.getName()).field(field.getName()));
    327             }
    328             sumList.clear();
    329         }
    330         return termsBuilder.order(Terms.Order.term(true));
    331     }

    3.存储scrollId值对象

    import lombok.Data;
    
    /**
     * Mutable holder for a scroll id, passed between successive scroll calls.
     * <p>
     * Hand-written equivalent of Lombok {@code @Data}: getter, setter, {@code equals}/{@code canEqual},
     * {@code hashCode} and {@code toString}.
     */
    public class ScrollId {

        /** Current scroll id; {@code null} until one is assigned. */
        private String value;

        public String getValue() {
            return this.value;
        }

        public void setValue(String value) {
            this.value = value;
        }

        @Override
        public boolean equals(Object other) {
            if (other == this) {
                return true;
            }
            if (!(other instanceof ScrollId)) {
                return false;
            }
            ScrollId that = (ScrollId) other;
            if (!that.canEqual(this)) {
                return false;
            }
            String mine = this.getValue();
            String theirs = that.getValue();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        /** Symmetry hook for subclasses, as generated by Lombok. */
        protected boolean canEqual(Object other) {
            return other instanceof ScrollId;
        }

        @Override
        public int hashCode() {
            String current = this.getValue();
            return 59 + (current == null ? 43 : current.hashCode());
        }

        @Override
        public String toString() {
            return "ScrollId(value=" + this.getValue() + ")";
        }
    }

    4.用于判断查询操作的枚举类

    /**
     * Comparison applied by a {@code @RangeQuery} field when it is turned into a range clause.
     */
    public enum Operator {
        /** Strictly greater than the field value. */
        gt,
        /** Greater than or equal to the field value. */
        gte,
        /** Strictly less than the field value. */
        lt,
        /** Less than or equal to the field value. */
        lte;
    }
    /**
     * Bool-query placement used by a {@code @MatchQuery} field.
     */
    public enum Container {
        /** Added directly as a must-match clause. */
        must,
        /** Added as should-match clauses inside a nested bool query (one per element for collections). */
        should;
    }
  • 相关阅读:
    puppet之模板和类
    puppet之资源
    puppet自动化安装服务
    puppet自动化搭建lnmp架构
    puppet工简介一
    CDN杂谈
    cdn工作原理
    mysql之innodb存储引擎
    Android应用开发基础篇(11)-----ViewFlipper
    Android应用开发基础篇(10)-----Menu(菜单)
  • 原文地址:https://www.cnblogs.com/xiaochangwei/p/10280672.html
Copyright © 2011-2022 走看看