1、遇到 Oracle NOT IN 查不出数据的问题:原因是 NOT IN 子查询的结果中包含 NULL 值,此时 `x NOT IN (...)` 永远不会为 TRUE,任何行都不会被返回;需要在子查询中把 NULL 过滤掉(例如加上 `WHERE col IS NOT NULL`)。
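2、Oracle 的 IN 列表最多只能包含 1000 个元素(超过会报 ORA-01795),为此参照 MyBatis-Plus 的租户拦截器(TenantLineInnerInterceptor)自己改写了一个拦截器:
1)、重点是理解表达式树。恰好旁边有位学历很高的同事,给我讲解了用二叉树表示并计算数学公式的原理。
2)、需要写递归:对 OR、AND 这类有左右子节点的表达式,拆分左右节点分别递归;对括号表达式,拆包后递归;最终处理 IN 表达式;其他类型的表达式直接原样返回。
3)、MyBatis-Plus 版本是 3.4.2。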
import com.baomidou.mybatisplus.core.parser.SqlParserHelper; import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper; import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; import com.baomidou.mybatisplus.core.toolkit.PluginUtils; import com.baomidou.mybatisplus.core.toolkit.StringPool; import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport; import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import lombok.ToString; import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.NotExpression; import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.relational.ExistsExpression; import net.sf.jsqlparser.expression.operators.relational.ExpressionList; import net.sf.jsqlparser.expression.operators.relational.InExpression; import net.sf.jsqlparser.expression.operators.relational.ItemsList; import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.statement.delete.Delete; import net.sf.jsqlparser.statement.insert.Insert; import net.sf.jsqlparser.statement.select.*; import net.sf.jsqlparser.statement.update.Update; import org.apache.commons.lang3.StringUtils; import org.apache.ibatis.executor.Executor; import org.apache.ibatis.executor.statement.StatementHandler; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.SqlCommandType; import org.apache.ibatis.session.ResultHandler; import org.apache.ibatis.session.RowBounds; import java.sql.Connection; import java.util.List; import java.util.stream.Collectors; /** * @author linjiabin * @since 1.0.0 */ @Data @NoArgsConstructor @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) public class 
OracleLimit1000InnerInterceptor extends JsqlParserSupport implements InnerInterceptor { @Override public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) { if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return; if (SqlParserHelper.getSqlParserInfo(ms)) return; PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql); try { String sql = mpBs.sql(); if (needRebuild(sql)) { mpBs.sql(parserSingle(sql, null)); } } catch (Exception e) { logger.error(e.getMessage(), e); } } private boolean needRebuild(String sql) { return StringUtils.isNotBlank(sql) && sql.toUpperCase().contains(" IN "); } @Override public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) { PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh); MappedStatement ms = mpSh.mappedStatement(); SqlCommandType sct = ms.getSqlCommandType(); if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) { if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return; if (SqlParserHelper.getSqlParserInfo(ms)) return; PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql(); try { String sql = mpBs.sql(); if (needRebuild(sql)) { String parserMulti = parserMulti(sql, null); if (sql.endsWith(StringPool.SEMICOLON) && !parserMulti.endsWith(StringPool.SEMICOLON)) { parserMulti += StringPool.SEMICOLON; } mpBs.sql(parserMulti); } } catch (Exception e) { logger.error(e.getMessage(), e); } } } @Override protected void processSelect(Select select, int index, String sql, Object obj) { processSelectBody(select.getSelectBody()); List<WithItem> withItemsList = select.getWithItemsList(); if (!CollectionUtils.isEmpty(withItemsList)) { withItemsList.forEach(this::processSelectBody); } } protected void processSelectBody(SelectBody selectBody) { if (selectBody == null) { return; } if (selectBody instanceof PlainSelect) { processPlainSelect((PlainSelect) 
selectBody); } else if (selectBody instanceof WithItem) { WithItem withItem = (WithItem) selectBody; processSelectBody(withItem.getSelectBody()); } else { SetOperationList operationList = (SetOperationList) selectBody; if (operationList.getSelects() != null && !operationList.getSelects().isEmpty()) { operationList.getSelects().forEach(this::processSelectBody); } } } @Override protected void processInsert(Insert insert, int index, String sql, Object obj) { // no do anything at insert } /** * update 语句处理 */ @Override protected void processUpdate(Update update, int index, String sql, Object obj) { update.setWhere(this.andExpression(update.getWhere())); } /** * delete 语句处理 */ @Override protected void processDelete(Delete delete, int index, String sql, Object obj) { delete.setWhere(this.andExpression(delete.getWhere())); } /** * delete update select 语句 where 处理 */ protected Expression andExpression(Expression where) { // 遇到左右表达式类型的,继续递归 if (where instanceof BinaryExpression) { BinaryExpression binaryExpression = (BinaryExpression) where; Expression rightExpression = binaryExpression.getRightExpression(); binaryExpression.setRightExpression(andExpression(rightExpression)); Expression leftExpression = binaryExpression.getLeftExpression(); binaryExpression.setLeftExpression(andExpression(leftExpression)); } // 遇到括号类型的,拆包递归 if (where instanceof Parenthesis) { Parenthesis parenthesis = (Parenthesis) where; return new Parenthesis(andExpression(parenthesis.getExpression())); } // 遇到in表达式的时候,尝试拆分 if (where instanceof InExpression) { return builderExpression((InExpression) where); } // 其他表达式直接返回 return where; } /** * 处理 PlainSelect */ protected void processPlainSelect(PlainSelect plainSelect) { FromItem fromItem = plainSelect.getFromItem(); Expression where = plainSelect.getWhere(); processWhereSubSelect(where); if (fromItem instanceof Table) { plainSelect.setWhere(builderExpression(where)); } else { processFromItem(fromItem); } List<Join> joins = plainSelect.getJoins(); if 
(joins != null && !joins.isEmpty()) { joins.forEach(j -> { processJoin(j); processFromItem(j.getRightItem()); }); } } /** * 处理where条件内的子查询 * <p> * 支持如下: * 1. in * 2. = * 3. > * 4. < * 5. >= * 6. <= * 7. <> * 8. EXISTS * 9. NOT EXISTS * <p> * 前提条件: * 1. 子查询必须放在小括号中 * 2. 子查询一般放在比较操作符的右边 * * @param where where 条件 */ protected void processWhereSubSelect(Expression where) { if (where == null) { return; } if (where instanceof FromItem) { processFromItem((FromItem) where); return; } if (where.toString().contains("SELECT")) { // 有子查询 if (where instanceof BinaryExpression) { // 比较符号 , and , or , 等等 BinaryExpression expression = (BinaryExpression) where; processWhereSubSelect(expression.getLeftExpression()); processWhereSubSelect(expression.getRightExpression()); } else if (where instanceof InExpression) { // in InExpression expression = (InExpression) where; ItemsList itemsList = expression.getRightItemsList(); if (itemsList instanceof SubSelect) { processSelectBody(((SubSelect) itemsList).getSelectBody()); } } else if (where instanceof ExistsExpression) { // exists ExistsExpression expression = (ExistsExpression) where; processWhereSubSelect(expression.getRightExpression()); } else if (where instanceof NotExpression) { // not exists NotExpression expression = (NotExpression) where; processWhereSubSelect(expression.getExpression()); } else if (where instanceof Parenthesis) { Parenthesis expression = (Parenthesis) where; processWhereSubSelect(expression.getExpression()); } } } /** * 处理子查询等 */ protected void processFromItem(FromItem fromItem) { if (fromItem instanceof SubJoin) { SubJoin subJoin = (SubJoin) fromItem; if (subJoin.getJoinList() != null) { subJoin.getJoinList().forEach(this::processJoin); } if (subJoin.getLeft() != null) { processFromItem(subJoin.getLeft()); } } else if (fromItem instanceof SubSelect) { SubSelect subSelect = (SubSelect) fromItem; if (subSelect.getSelectBody() != null) { processSelectBody(subSelect.getSelectBody()); } } else if (fromItem 
instanceof ValuesList) { logger.debug("Perform a subquery, if you do not give us feedback"); } else if (fromItem instanceof LateralSubSelect) { LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem; if (lateralSubSelect.getSubSelect() != null) { SubSelect subSelect = lateralSubSelect.getSubSelect(); if (subSelect.getSelectBody() != null) { processSelectBody(subSelect.getSelectBody()); } } } } /** * 处理联接语句 */ protected void processJoin(Join join) { if (join.getRightItem() instanceof Table) { join.setOnExpression(builderExpression(join.getOnExpression())); } } /** * 处理条件 */ protected Expression builderExpression(Expression currentExpression) { if (currentExpression == null) { return null; } return andExpression(currentExpression); } protected Expression builderExpression(InExpression inExpression) { Expression leftExpression = inExpression.getLeftExpression(); ItemsList rightItemsList = inExpression.getRightItemsList(); if (rightItemsList instanceof ExpressionList) { ExpressionList expressionList = (ExpressionList) rightItemsList; List<Expression> expressions = expressionList.getExpressions(); int size = expressions.size(); int limit = 1000; if (size > limit) { OrExpression root = new OrExpression(); int step = size / limit + 1; root.setLeftExpression(new InExpression(leftExpression, new ExpressionList(expressions.subList(0, limit)))); if (step == 2) { int toIndex = getToIndex(size, limit); root.setRightExpression(new InExpression(leftExpression, new ExpressionList(expressions.subList(limit, toIndex)))); return root; } OrExpression orExpression = new OrExpression(); root.setRightExpression(orExpression); for (int i = 1; i < step; i++) { List<Expression> segment = expressions.stream().skip((long) i * limit) .limit(limit).collect(Collectors.toList()); if (i == step - 2) { orExpression.setLeftExpression(new InExpression(leftExpression, new ExpressionList(segment))); List<Expression> last = expressions.stream().skip((long) (i + 1) * 
limit).collect(Collectors.toList()); orExpression.setRightExpression(new InExpression(leftExpression, new ExpressionList(last))); break; } else { OrExpression orExpression1 = new OrExpression(); orExpression.setLeftExpression(new InExpression(leftExpression, new ExpressionList(segment))); orExpression.setRightExpression(orExpression1); orExpression = orExpression1; } } return new Parenthesis(root); } } return inExpression; } private int getToIndex(int size, int limit) { int toIndex = limit * 2; if (toIndex > size) { toIndex = size; } return toIndex; } }
由于解析重组 UPDATE 语句后会自动去掉末尾的分号,追加了一个判断:原 SQL 以分号结尾而重组后的 SQL 没有分号时,把分号补回去。
并且将重组 SQL 的逻辑用 try-catch 包裹起来:重组出错时只记录日志并继续使用原 SQL,避免因为重组失败导致语句无法执行。
另外简单地过滤掉不需要解析的场景:目前仅判断 SQL 中是否包含 IN 关键字,不包含时直接跳过解析。