1. 抽象基类定义（SearchQueryEngine）
1 public abstract class SearchQueryEngine<T> { 2 3 @Autowired 4 protected ElasticsearchTemplate elasticsearchTemplate; 5 6 public abstract int saveOrUpdate(List<T> list); 7 8 public abstract <R> List<R> aggregation(T query, Class<R> clazz); 9 10 public abstract <R> Page<R> scroll(T query, Class<R> clazz, Pageable pageable, ScrollId scrollId); 11 12 public abstract <R> List<R> find(T query, Class<R> clazz, int size); 13 14 public abstract <R> Page<R> find(T query, Class<R> clazz, Pageable pageable); 15 16 public abstract <R> R sum(T query, Class<R> clazz); 17 18 protected Document getDocument(T t) { 19 Document annotation = t.getClass().getAnnotation(Document.class); 20 if (annotation == null) { 21 throw new SearchQueryBuildException("Can't find annotation @Document on " + t.getClass().getName()); 22 } 23 return annotation; 24 } 25 26 /** 27 * 获取字段名,若设置column则返回该值 28 * 29 * @param field 30 * @param column 31 * @return 32 */ 33 protected String getFieldName(Field field, String column) { 34 return StringUtils.isNotBlank(column) ? 
column : field.getName(); 35 } 36 37 /** 38 * 设置属性值 39 * 40 * @param field 41 * @param obj 42 * @param value 43 */ 44 protected void setFieldValue(Field field, Object obj, Object value) { 45 boolean isAccessible = field.isAccessible(); 46 field.setAccessible(true); 47 try { 48 switch (field.getType().getSimpleName()) { 49 case "BigDecimal": 50 field.set(obj, new BigDecimal(value.toString()).setScale(5, BigDecimal.ROUND_HALF_UP)); 51 break; 52 case "Long": 53 field.set(obj, new Long(value.toString())); 54 break; 55 case "Integer": 56 field.set(obj, new Integer(value.toString())); 57 break; 58 case "Date": 59 field.set(obj, new Date(Long.valueOf(value.toString()))); 60 break; 61 default: 62 field.set(obj, value); 63 } 64 } catch (IllegalAccessException e) { 65 throw new SearchQueryBuildException(e); 66 } finally { 67 field.setAccessible(isAccessible); 68 } 69 } 70 71 /** 72 * 获取字段值 73 * 74 * @param field 75 * @param obj 76 * @return 77 */ 78 protected Object getFieldValue(Field field, Object obj) { 79 boolean isAccessible = field.isAccessible(); 80 field.setAccessible(true); 81 try { 82 return field.get(obj); 83 } catch (IllegalAccessException e) { 84 throw new SearchQueryBuildException(e); 85 } finally { 86 field.setAccessible(isAccessible); 87 } 88 } 89 90 /** 91 * 转换为es识别的value值 92 * 93 * @param value 94 * @return 95 */ 96 protected Object formatValue(Object value) { 97 if (value instanceof Date) { 98 return ((Date) value).getTime(); 99 } else { 100 return value; 101 } 102 } 103 104 /** 105 * 获取索引分区数 106 * 107 * @param t 108 * @return 109 */ 110 protected int getNumberOfShards(T t) { 111 return Integer.parseInt(elasticsearchTemplate.getSetting(getDocument(t).index()).get(IndexMetaData.SETTING_NUMBER_OF_SHARDS).toString()); 112 } 113 }
2. 默认实现（SimpleSearchQueryEngine，基于 ElasticsearchTemplate）
1 @Component 2 @ComponentScan 3 public class SimpleSearchQueryEngine<T> extends SearchQueryEngine<T> { 4 5 private int numberOfRowsPerScan = 10; 6 7 @Override 8 public int saveOrUpdate(List<T> list) { 9 if (CollectionUtils.isEmpty(list)) { 10 return 0; 11 } 12 13 T base = list.get(0); 14 Field id = null; 15 for (Field field : base.getClass().getDeclaredFields()) { 16 BusinessID businessID = field.getAnnotation(BusinessID.class); 17 if (businessID != null) { 18 id = field; 19 break; 20 } 21 } 22 if (id == null) { 23 throw new SearchQueryBuildException("Can't find @BusinessID on " + base.getClass().getName()); 24 } 25 26 Document document = getDocument(base); 27 List<UpdateQuery> bulkIndex = new ArrayList<>(); 28 for (T t : list) { 29 UpdateQuery updateQuery = new UpdateQuery(); 30 updateQuery.setIndexName(document.index()); 31 updateQuery.setType(document.type()); 32 updateQuery.setId(getFieldValue(id, t).toString()); 33 updateQuery.setUpdateRequest(new UpdateRequest(updateQuery.getIndexName(), updateQuery.getType(), updateQuery.getId()).doc(JSONObject.toJSONString(t, SerializerFeature.WriteMapNullValue))); 34 updateQuery.setDoUpsert(true); 35 updateQuery.setClazz(t.getClass()); 36 bulkIndex.add(updateQuery); 37 } 38 elasticsearchTemplate.bulkUpdate(bulkIndex); 39 return list.size(); 40 } 41 42 @Override 43 public <R> List<R> aggregation(T query, Class<R> clazz) { 44 NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query); 45 nativeSearchQueryBuilder.addAggregation(buildGroupBy(query)); 46 Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor()); 47 try { 48 return transformList(null, aggregations, clazz.newInstance(), new ArrayList()); 49 } catch (Exception e) { 50 throw new SearchResultBuildException(e); 51 } 52 } 53 54 /** 55 * 将Aggregations转为List 56 * 57 * @param terms 58 * @param aggregations 59 * @param baseObj 60 * @param resultList 61 * @param <R> 62 * 
@return 63 * @throws NoSuchFieldException 64 * @throws IllegalAccessException 65 * @throws InstantiationException 66 */ 67 private <R> List<R> transformList(Aggregation terms, Aggregations aggregations, R baseObj, List<R> resultList) throws NoSuchFieldException, IllegalAccessException, InstantiationException { 68 for (String column : aggregations.asMap().keySet()) { 69 Aggregation childAggregation = aggregations.get(column); 70 if (childAggregation instanceof InternalSum) { 71 // 使用@Sum 72 if (!(terms instanceof InternalSum)) { 73 R targetObj = (R) baseObj.getClass().newInstance(); 74 BeanUtils.copyProperties(baseObj, targetObj); 75 resultList.add(targetObj); 76 } 77 setFieldValue(baseObj.getClass().getDeclaredField(column), resultList.get(resultList.size() - 1), ((InternalSum) childAggregation).getValue()); 78 terms = childAggregation; 79 } else { 80 Terms childTerms = (Terms) childAggregation; 81 for (Terms.Bucket bucket : childTerms.getBuckets()) { 82 if (CollectionUtils.isEmpty(bucket.getAggregations().asList())) { 83 // 未使用@Sum 84 R targetObj = (R) baseObj.getClass().newInstance(); 85 BeanUtils.copyProperties(baseObj, targetObj); 86 setFieldValue(targetObj.getClass().getDeclaredField(column), targetObj, bucket.getKey()); 87 resultList.add(targetObj); 88 } else { 89 setFieldValue(baseObj.getClass().getDeclaredField(column), baseObj, bucket.getKey()); 90 transformList(childTerms, bucket.getAggregations(), baseObj, resultList); 91 } 92 } 93 } 94 } 95 return resultList; 96 } 97 98 @Override 99 public <R> Page<R> scroll(T query, Class<R> clazz, Pageable pageable, ScrollId scrollId) { 100 if (pageable.getPageSize() % numberOfRowsPerScan > 0) { 101 throw new SearchQueryBuildException("Page size must be an integral multiple of " + numberOfRowsPerScan); 102 } 103 SearchQuery searchQuery = buildNativeSearchQueryBuilder(query).withPageable(new PageRequest(pageable.getPageNumber(), numberOfRowsPerScan / getNumberOfShards(query), pageable.getSort())).build(); 104 if 
(StringUtils.isEmpty(scrollId.getValue())) { 105 scrollId.setValue(elasticsearchTemplate.scan(searchQuery, 10000l, false)); 106 } 107 Page<R> page = elasticsearchTemplate.scroll(scrollId.getValue(), 10000l, clazz); 108 if (page == null || page.getContent().size() == 0) { 109 elasticsearchTemplate.clearScroll(scrollId.getValue()); 110 } 111 return page; 112 } 113 114 @Override 115 public <R> List<R> find(T query, Class<R> clazz, int size) { 116 // Caused by: QueryPhaseExecutionException[Result window is too large, from + size must be less than or equal to: [10000] but was [2147483647]. 117 // See the scroll api for a more efficient way to request large data sets. This limit can be set by changing the [index.max_result_window] index level parameter.] 118 if (size % numberOfRowsPerScan > 0) { 119 throw new SearchQueryBuildException("Parameter 'size' must be an integral multiple of " + numberOfRowsPerScan); 120 } 121 int pageNum = 0; 122 List<R> result = new ArrayList<>(); 123 ScrollId scrollId = new ScrollId(); 124 while (true) { 125 Page<R> page = scroll(query, clazz, new PageRequest(pageNum, numberOfRowsPerScan), scrollId); 126 if (page != null && page.getContent().size() > 0) { 127 result.addAll(page.getContent()); 128 } else { 129 break; 130 } 131 if (result.size() >= size) { 132 break; 133 } else { 134 pageNum++; 135 } 136 } 137 elasticsearchTemplate.clearScroll(scrollId.getValue()); 138 return result; 139 } 140 141 @Override 142 public <R> Page<R> find(T query, Class<R> clazz, Pageable pageable) { 143 NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query).withPageable(pageable); 144 return elasticsearchTemplate.queryForPage(nativeSearchQueryBuilder.build(), clazz); 145 } 146 147 @Override 148 public <R> R sum(T query, Class<R> clazz) { 149 NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query); 150 for (SumBuilder sumBuilder : getSumBuilderList(query)) { 151 
nativeSearchQueryBuilder.addAggregation(sumBuilder); 152 } 153 Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor()); 154 try { 155 return transformSumResult(aggregations, clazz); 156 } catch (Exception e) { 157 throw new SearchResultBuildException(e); 158 } 159 } 160 161 private <R> R transformSumResult(Aggregations aggregations, Class<R> clazz) throws IllegalAccessException, InstantiationException, NoSuchFieldException { 162 R targetObj = clazz.newInstance(); 163 for (Aggregation sum : aggregations.asList()) { 164 if (sum instanceof InternalSum) { 165 setFieldValue(targetObj.getClass().getDeclaredField(sum.getName()), targetObj, ((InternalSum) sum).getValue()); 166 } 167 } 168 return targetObj; 169 } 170 171 private NativeSearchQueryBuilder buildNativeSearchQueryBuilder(T query) { 172 Document document = getDocument(query); 173 NativeSearchQueryBuilder nativeSearchQueryBuilder = new NativeSearchQueryBuilder() 174 .withIndices(document.index()) 175 .withTypes(document.type()); 176 177 QueryBuilder whereBuilder = buildBoolQuery(query); 178 if (whereBuilder != null) { 179 nativeSearchQueryBuilder.withQuery(whereBuilder); 180 } 181 182 return nativeSearchQueryBuilder; 183 } 184 185 /** 186 * 布尔查询构建 187 * 188 * @param query 189 * @return 190 */ 191 private BoolQueryBuilder buildBoolQuery(T query) { 192 BoolQueryBuilder boolQueryBuilder = boolQuery(); 193 buildMatchQuery(boolQueryBuilder, query); 194 buildRangeQuery(boolQueryBuilder, query); 195 BoolQueryBuilder queryBuilder = boolQuery().must(boolQueryBuilder); 196 return queryBuilder; 197 } 198 199 /** 200 * and or 查询构建 201 * 202 * @param boolQueryBuilder 203 * @param query 204 */ 205 private void buildMatchQuery(BoolQueryBuilder boolQueryBuilder, T query) { 206 Class clazz = query.getClass(); 207 for (Field field : clazz.getDeclaredFields()) { 208 MatchQuery annotation = field.getAnnotation(MatchQuery.class); 209 Object value = 
getFieldValue(field, query); 210 if (annotation == null || value == null) { 211 continue; 212 } 213 if (Container.must.equals(annotation.container())) { 214 boolQueryBuilder.must(matchQuery(getFieldName(field, annotation.column()), formatValue(value))); 215 } else if (should.equals(annotation.container())) { 216 if (value instanceof Collection) { 217 BoolQueryBuilder shouldQueryBuilder = boolQuery(); 218 Collection tmp = (Collection) value; 219 for (Object obj : tmp) { 220 shouldQueryBuilder.should(matchQuery(getFieldName(field, annotation.column()), formatValue(obj))); 221 } 222 boolQueryBuilder.must(shouldQueryBuilder); 223 } else { 224 boolQueryBuilder.must(boolQuery().should(matchQuery(getFieldName(field, annotation.column()), formatValue(value)))); 225 } 226 } 227 } 228 } 229 230 /** 231 * 范围查询构建 232 * 233 * @param boolQueryBuilder 234 * @param query 235 */ 236 private void buildRangeQuery(BoolQueryBuilder boolQueryBuilder, T query) { 237 Class clazz = query.getClass(); 238 for (Field field : clazz.getDeclaredFields()) { 239 RangeQuery annotation = field.getAnnotation(RangeQuery.class); 240 Object value = getFieldValue(field, query); 241 if (annotation == null || value == null) { 242 continue; 243 } 244 if (Operator.gt.equals(annotation.operator())) { 245 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).gt(formatValue(value))); 246 } else if (Operator.gte.equals(annotation.operator())) { 247 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).gte(formatValue(value))); 248 } else if (Operator.lt.equals(annotation.operator())) { 249 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).lt(formatValue(value))); 250 } else if (Operator.lte.equals(annotation.operator())) { 251 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).lte(formatValue(value))); 252 } 253 } 254 } 255 256 /** 257 * Sum构建 258 * 259 * @param query 260 * @return 261 */ 262 private List<SumBuilder> 
getSumBuilderList(T query) { 263 List<SumBuilder> list = new ArrayList<>(); 264 Class clazz = query.getClass(); 265 for (Field field : clazz.getDeclaredFields()) { 266 Sum annotation = field.getAnnotation(Sum.class); 267 if (annotation == null) { 268 continue; 269 } 270 list.add(AggregationBuilders.sum(field.getName()).field(field.getName())); 271 } 272 if (CollectionUtils.isEmpty(list)) { 273 throw new SearchQueryBuildException("Can't find @Sum on " + clazz.getName()); 274 } 275 return list; 276 } 277 278 279 /** 280 * GroupBy构建 281 * 282 * @param query 283 * @return 284 */ 285 private TermsBuilder buildGroupBy(T query) { 286 List<Field> sumList = new ArrayList<>(); 287 Object groupByCollection = null; 288 Class clazz = query.getClass(); 289 for (Field field : clazz.getDeclaredFields()) { 290 Sum sumAnnotation = field.getAnnotation(Sum.class); 291 if (sumAnnotation != null) { 292 sumList.add(field); 293 } 294 GroupBy groupByannotation = field.getAnnotation(GroupBy.class); 295 Object value = getFieldValue(field, query); 296 if (groupByannotation == null || value == null) { 297 continue; 298 } else if (!(value instanceof Collection)) { 299 throw new SearchQueryBuildException("GroupBy filed must be collection"); 300 } else if (CollectionUtils.isEmpty((Collection<String>) value)) { 301 continue; 302 } else if (groupByCollection != null) { 303 throw new SearchQueryBuildException("Only one @GroupBy is allowed"); 304 } else { 305 groupByCollection = value; 306 } 307 } 308 Iterator<String> iterator = ((Collection<String>) groupByCollection).iterator(); 309 TermsBuilder termsBuilder = recursiveAddAggregation(iterator, sumList); 310 return termsBuilder; 311 } 312 313 /** 314 * 添加Aggregation 315 * 316 * @param iterator 317 * @return 318 */ 319 private TermsBuilder recursiveAddAggregation(Iterator<String> iterator, List<Field> sumList) { 320 String groupBy = iterator.next(); 321 TermsBuilder termsBuilder = AggregationBuilders.terms(groupBy).field(groupBy).size(0); 322 if 
(iterator.hasNext()) { 323 termsBuilder.subAggregation(recursiveAddAggregation(iterator, sumList)); 324 } else { 325 for (Field field : sumList) { 326 termsBuilder.subAggregation(AggregationBuilders.sum(field.getName()).field(field.getName())); 327 } 328 sumList.clear(); 329 } 330 return termsBuilder.order(Terms.Order.term(true)); 331 }
3. 存储 scrollId 的值对象（ScrollId）
import lombok.Data; @Data public class ScrollId { private String value; }
4. 构建查询条件用到的枚举类：范围操作符 Operator 与布尔子句容器 Container
/**
 * Comparison applied by a range-query field: the field must lie beyond the query value
 * in the given direction.
 */
public enum Operator {
    /** Greater than the query value. */
    gt,
    /** Greater than or equal to the query value. */
    gte,
    /** Less than the query value. */
    lt,
    /** Less than or equal to the query value. */
    lte;
}
/**
 * Bool-query container that a match-query field is placed in.
 */
public enum Container {
    /** The value is a required match. */
    must,
    /** Any one of the values (single value or collection elements) must match. */
    should;
}