  • Developing a MyBatis Multi-Tenant Isolation Plugin

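    With the rise of SaaS, there are presumably still many scenarios in which a single database serves multiple tenants. In that setup the tenants usually share one set of tables, and their resources are isolated at the SQL statement level: each table gets a tenant identifier column, and every query appends a filter such as TenantId=xxx. This is a cheap and simple way to provide multi-tenancy, but the filter has to be attached to every SQL statement that is executed; otherwise data from different tenants gets mixed up.

    This isolation should happen automatically, so today I use the MyBatis plugin mechanism to build a multi-tenant SQL isolation plugin.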

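    I. Design Requirements

    1. First, we need a way to identify which tables require multi-tenant isolation, and to determine the name of the tenant isolation column.

    2. Next, we intercept the prepare method in the MyBatis execution flow, rewrite the SQL to include the tenant isolation condition, and replace the original SQL with the new one.

    3. Finally, we need a way to add the condition to every identified table at every nesting level, since any CRUD statement can contain subqueries, without losing the original where conditions.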

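    II. Design Approach

    For requirement 1, we can define a condition-field decision maker that decides whether a given table needs the multi-tenant filter condition, for example through an interface: ITableFieldConditionDecision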

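    /**
     * Table field condition decision maker.
     * Decides whether a given table needs a filter condition on a given field.
     *
     * @author liushuishang@gmail.com
     * @date 2017/12/23 15:49
     **/
    public interface ITableFieldConditionDecision {
    
        /**
         * Whether the condition field allows a null value
         * @return true if a null filter value is acceptable
         */
        boolean isAllowNullValue();
        /**
         * Decides whether a table needs a filter on the given field
         *
         * @param tableName   table name
         * @param fieldName   field name
         * @return true if the filter condition should be added
         */
        boolean adjudge(String tableName, String fieldName);
    
    }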
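As an illustration of what an implementation of this interface could look like, here is a minimal sketch that combines a regular expression with an explicit set of table names. The class names `PatternOrSetDecision` and `DecisionDemo` are hypothetical and not part of the original project; the interface is repeated only so that the sketch compiles on its own.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.regex.Pattern;

// Copy of the interface above, repeated so the sketch compiles on its own.
interface ITableFieldConditionDecision {
    boolean isAllowNullValue();
    boolean adjudge(String tableName, String fieldName);
}

// Hypothetical implementation: a table is filtered when its name matches
// the regex or appears in the explicit set. Either rule may be null.
class PatternOrSetDecision implements ITableFieldConditionDecision {
    private final Pattern tablePattern;
    private final Set<String> tableSet;

    PatternOrSetDecision(Pattern tablePattern, Set<String> tableSet) {
        this.tablePattern = tablePattern;
        this.tableSet = tableSet;
    }

    @Override
    public boolean isAllowNullValue() {
        return false; // never filter on a null tenant id
    }

    @Override
    public boolean adjudge(String tableName, String fieldName) {
        if (tableName == null) return false;
        // matches() requires the whole table name to match the regex
        if (tablePattern != null && tablePattern.matcher(tableName).matches()) return true;
        return tableSet != null && tableSet.contains(tableName);
    }
}

class DecisionDemo {
    public static void main(String[] args) {
        Set<String> names = new HashSet<>(Arrays.asList("sys_log"));
        ITableFieldConditionDecision decision =
                new PatternOrSetDecision(Pattern.compile("uam_.*"), names);
        for (String table : new String[] {"uam_user", "sys_log", "orders"}) {
            System.out.println(table + " -> " + decision.adjudge(table, "domain"));
        }
    }
}
```

Note that the pattern is applied with `Pattern.matcher(...).matches()`, so a prefix rule has to be written as `uam_.*` rather than `uam_*`.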

    Then, where the plugin is configured, fill in the parameters needed to initialize the decision maker

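    <!-- Multi-tenant isolation plugin -->
                    <bean class="com.smartdata360.smartfx.dao.plugin.MultiTenantPlugin">
                        <property name="properties">
                            <value>
                                <!-- Dialect of the current database -->
                                dialect=postgresql
                                <!-- Name of the tenant isolation column -->
                                tenantIdField=domain
                                <!-- Java regular expression for the names of the tables to isolate -->
                                tableRegex=uam_.*
                                <!-- Names of the tables to isolate, comma-separated -->
                                tableNames=uam_user,uam_role
                            </value>
                        </property>
                    </bean>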
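The plugin receives this text as a java.util.Properties object. The sketch below shows, under that assumption, how the key=value lines could be parsed and turned into a compiled pattern and a set of table names. It uses the keys `tableRegex` and `tableNames` that the plugin's setProperties method reads in the reference code; the class `PluginPropertiesSketch` and its helper methods are hypothetical.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.HashSet;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;

// Hypothetical helper, not part of the original plugin.
class PluginPropertiesSketch {

    // Parses "key=value" lines using the standard java.util.Properties format.
    static Properties parse(String text) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(text));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return props;
    }

    // Returns the compiled table regex, or null when the key is missing or blank.
    static Pattern tablePattern(Properties props) {
        String regex = props.getProperty("tableRegex");
        if (regex == null || regex.trim().isEmpty()) return null;
        return Pattern.compile(regex.trim());
    }

    // Returns the comma-separated table names as a set, or null when the key is missing or blank.
    static Set<String> tableNames(Properties props) {
        String names = props.getProperty("tableNames");
        if (names == null || names.trim().isEmpty()) return null;
        Set<String> result = new HashSet<>();
        for (String name : names.split(",")) {
            String trimmed = name.trim();
            if (!trimmed.isEmpty()) result.add(trimmed);
        }
        return result;
    }

    public static void main(String[] args) {
        Properties props = parse(
                "dialect=postgresql\n"
              + "tenantIdField=domain\n"
              + "tableRegex=uam_.*\n"
              + "tableNames=uam_user, uam_role\n");
        System.out.println("pattern: " + tablePattern(props));
        System.out.println("tables:  " + tableNames(props));
    }
}
```

One detail worth keeping in mind: the Properties format treats a backslash as an escape character, so a regex such as `uam_\w+` would have to be written as `uam_\\w+` in the property text.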

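    For requirement 2, we write a MyBatis interceptor, MultiTenantPlugin. It extracts the SQL statement that is about to be prepared, rewrites it, and puts it back, so that what MyBatis finally executes is the rewritten SQL.

    /**
     * Multi-tenant data isolation plugin
     *
     * @author liushuishang@gmail.com
     * @date 2017/12/21 11:58
     **/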
    @Intercepts({
            @Signature(type = StatementHandler.class,
                    method = "prepare",
                    args = {Connection.class})})
    public class MultiTenantPlugin extends BasePlugin

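    For requirement 3, I use the SQL parser module of Alibaba's Druid to parse the SQL and attach the condition. The process is roughly as follows:

    (1) Parse the SQL into an AST, in which essentially every part of the statement has a corresponding object.

    (2) Traverse the AST to obtain the select, query, and SQLExpr nodes, extract the table names and aliases, and let the decision maker judge whether the multi-tenant isolation condition needs to be added. If so, extend the existing condition with the tenant filter; otherwise leave it untouched.

    (3) Convert the modified AST back into a SQL statement.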
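The recursion in step (2) can be illustrated without any parser library. The sketch below is a toy model only: `FromNode`, `TableRef`, `JoinRef`, `SubqueryRef`, `QueryBlock` and `TenantWalkSketch` are hypothetical stand-ins, not Druid classes, and the where clause is kept as a plain string. It shows how a plain table adds a condition to its enclosing query, a join visits both sides, and a subquery recurses into its own query block.

```java
import java.util.function.Predicate;

// Toy stand-ins for AST nodes; these are NOT Druid classes.
abstract class FromNode {}

class TableRef extends FromNode {
    final String name;
    final String alias; // may be null
    TableRef(String name, String alias) { this.name = name; this.alias = alias; }
}

class JoinRef extends FromNode {
    final FromNode left;
    final FromNode right;
    JoinRef(FromNode left, FromNode right) { this.left = left; this.right = right; }
}

class SubqueryRef extends FromNode {
    final QueryBlock query;
    SubqueryRef(QueryBlock query) { this.query = query; }
}

class QueryBlock {
    final FromNode from;
    String where; // null when the query has no where clause
    QueryBlock(FromNode from, String where) { this.from = from; this.where = where; }
}

class TenantWalkSketch {
    // Walks the from-clause tree and ANDs "alias.field = 'value'" onto the
    // where clause of the query block that directly contains each table.
    static void addCondition(QueryBlock query, FromNode from, String field, String value,
                             Predicate<String> needsFilter) {
        if (from instanceof TableRef) {
            TableRef table = (TableRef) from;
            if (!needsFilter.test(table.name)) return; // decision maker says no
            String column = table.alias == null ? field : table.alias + "." + field;
            String condition = column + " = '" + value.replace("'", "''") + "'";
            // parenthesize the original so an OR inside it keeps its meaning
            query.where = query.where == null ? condition : "(" + query.where + ") AND " + condition;
        } else if (from instanceof JoinRef) {
            JoinRef join = (JoinRef) from;
            addCondition(query, join.left, field, value, needsFilter);
            addCondition(query, join.right, field, value, needsFilter);
        } else if (from instanceof SubqueryRef) {
            // a subquery gets the condition on its own where clause
            QueryBlock inner = ((SubqueryRef) from).query;
            addCondition(inner, inner.from, field, value, needsFilter);
        }
    }

    public static void main(String[] args) {
        QueryBlock inner = new QueryBlock(
                new JoinRef(new TableRef("user_group", "g"), new TableRef("user_role", "r")), null);
        QueryBlock outer = new QueryBlock(
                new JoinRef(new TableRef("user", "u"), new SubqueryRef(inner)), "u.name = '123'");
        addCondition(outer, outer.from, "domain", "yay", name -> true);
        System.out.println("outer where: " + outer.where);
        System.out.println("inner where: " + inner.where);
    }
}
```

Working on a real AST instead of strings, as the plugin does with Druid, avoids the quoting and parenthesization concerns that this toy has to handle by hand.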

    Execution result: [image]

    III. Reference Code

    import com.alibaba.druid.sql.SQLUtils;
    import com.alibaba.druid.sql.ast.SQLStatement;
    import com.smartdata360.smartfx.dao.extension.MultiTenantContent;
    import com.smartdata360.smartfx.dao.sqlparser.ITableFieldConditionDecision;
    import com.smartdata360.smartfx.dao.sqlparser.SqlConditionHelper;
    import org.apache.commons.lang3.StringUtils;
    import org.apache.ibatis.executor.statement.StatementHandler;
    import org.apache.ibatis.mapping.BoundSql;
    import org.apache.ibatis.plugin.Intercepts;
    import org.apache.ibatis.plugin.Invocation;
    import org.apache.ibatis.plugin.Signature;
    import org.apache.ibatis.reflection.MetaObject;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    import java.sql.Connection;
    import java.util.*;
    import java.util.regex.Pattern;
    
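    /**
     * Multi-tenant data isolation plugin.
     * The signature below matches MyBatis versions before 3.4.0. From 3.4.0 on,
     * StatementHandler.prepare also takes an Integer transactionTimeout, so
     * args becomes {Connection.class, Integer.class}.
     *
     * @author liushuishang@gmail.com
     * @date 2017/12/21 11:58
     **/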
    @Intercepts({
            @Signature(type = StatementHandler.class,
                    method = "prepare",
                    args = {Connection.class})})
    public class MultiTenantPlugin extends BasePlugin {
    
        private final Logger logger = LoggerFactory.getLogger(MultiTenantPlugin.class);
    
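        /**
         * Dialect of the current database
         */
        private String dialect;
        /**
         * Name of the multi-tenant field
         */
        private String tenantIdField;
    
        /**
         * Regular expression for the names of tables that carry the tenant field
         */
        private Pattern tablePattern;
    
        /**
         * Explicit list of names of tables that carry the tenant field
         */
        private Set<String> tableSet;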
    
        private SqlConditionHelper conditionHelper;
    
    
        @Override
        public Object intercept(Invocation invocation) throws Throwable {
            String tenantId = MultiTenantContent.getCurrentTenantId();
            // Do nothing when the tenant id is blank
            if (StringUtils.isBlank(tenantId)) {
                return invocation.proceed();
            }
            StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
            BoundSql boundSql = statementHandler.getBoundSql();
            String newSql = addTenantCondition(boundSql.getSql(), tenantId);
            MetaObject boundSqlMeta = getMetaObject(boundSql);
            // Write the rewritten SQL back into boundSql
            boundSqlMeta.setValue("sql", newSql);
    
            return invocation.proceed();
        }
    
        @Override
        public void setProperties(Properties properties) {
            dialect = properties.getProperty("dialect");
            if (StringUtils.isBlank(dialect))
                throw new IllegalArgumentException("MultiTenantPlugin need dialect property value");
            tenantIdField = properties.getProperty("tenantIdField");
            if (StringUtils.isBlank(tenantIdField))
                throw new IllegalArgumentException("MultiTenantPlugin need tenantIdField property value");
    
            String tableRegex = properties.getProperty("tableRegex");
            if (!StringUtils.isBlank(tableRegex))
                tablePattern = Pattern.compile(tableRegex);
    
            String tableNames = properties.getProperty("tableNames");
            if (!StringUtils.isBlank(tableNames)) {
                // Split on commas; StringUtils.split(String) alone splits on whitespace
                tableSet = new HashSet<String>(Arrays.asList(StringUtils.split(tableNames, ",")));
            }
            // At least one of tableRegex and tableNames must be configured
            if (tablePattern == null && tableSet == null)
                throw new IllegalArgumentException("MultiTenantPlugin tableRegex and tableNames must have one");
    
            /**
             * Decision maker for the multi-tenant condition field
             */
            ITableFieldConditionDecision conditionDecision = new ITableFieldConditionDecision() {
                @Override
                public boolean isAllowNullValue() {
                    return false;
                }
                @Override
                public boolean adjudge(String tableName, String fieldName) {
                    // The whole table name must match the configured regex
                    if (tablePattern != null && tablePattern.matcher(tableName).matches()) return true;
                    if (tableSet != null && tableSet.contains(tableName)) return true;
                    return false;
                }
            };
            conditionHelper = new SqlConditionHelper(conditionDecision);
        }
    
    
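        /**
         * Adds the tenant id filter condition to the where clause of a SQL statement
         *
         * @param sql      the SQL statement to add the filter condition to
         * @param tenantId the current tenant id
         * @return the SQL statement with the condition added
         */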
        private String addTenantCondition(String sql, String tenantId) {
            if (StringUtils.isBlank(sql) || StringUtils.isBlank(tenantIdField)) return sql;
            List<SQLStatement> statementList = SQLUtils.parseStatements(sql, dialect);
            if (statementList == null || statementList.size() == 0) return sql;
    
            SQLStatement sqlStatement = statementList.get(0);
            conditionHelper.addStatementCondition(sqlStatement, tenantIdField, tenantId);
            return SQLUtils.toSQLString(statementList, dialect);
        }
    
    }
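The plugin reads the current tenant through MultiTenantContent.getCurrentTenantId(), a class this post does not show. A common way to provide such a value is a ThreadLocal holder that is set at the start of a request and cleaned up afterwards. The sketch below is a hypothetical stand-in under that assumption; the name `TenantContextSketch` and the `runAs` helper are illustrative, not the author's implementation.

```java
// Hypothetical stand-in for MultiTenantContent, which this post does not show.
final class TenantContextSketch {
    // One tenant id per thread, so concurrent requests do not see each other's value.
    private static final ThreadLocal<String> CURRENT = new ThreadLocal<>();

    private TenantContextSketch() {}

    // Returns the tenant id bound to the calling thread, or null if none is set.
    static String getCurrentTenantId() {
        return CURRENT.get();
    }

    // Runs the action with the given tenant id and restores the previous
    // state afterwards, even if the action throws.
    static void runAs(String tenantId, Runnable action) {
        String previous = CURRENT.get();
        CURRENT.set(tenantId);
        try {
            action.run();
        } finally {
            if (previous == null) {
                CURRENT.remove(); // avoid leaking the value into pooled threads
            } else {
                CURRENT.set(previous);
            }
        }
    }

    public static void main(String[] args) {
        runAs("yay", () -> System.out.println("inside: " + getCurrentTenantId()));
        System.out.println("after:  " + getCurrentTenantId());
    }
}
```

Restoring or removing the value in a finally block matters with thread pools: a worker thread that keeps a stale tenant id would silently apply it to the next request it serves.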
    import com.alibaba.druid.sql.SQLUtils;
    import com.alibaba.druid.sql.ast.SQLExpr;
    import com.alibaba.druid.sql.ast.SQLStatement;
    import com.alibaba.druid.sql.ast.expr.*;
    import com.alibaba.druid.sql.ast.statement.*;
    import com.alibaba.druid.util.JdbcConstants;
    import org.apache.commons.lang3.NotImplementedException;
    import org.apache.commons.lang3.StringUtils;
    
    import java.util.List;
    
    /**
     * Helper class for manipulating the where conditions of SQL statements
     *
     * @author liushuishang@gmail.com
     * @date 2017/12/21 15:05
     **/
    public class SqlConditionHelper {
    
        private ITableFieldConditionDecision conditionDecision;
    
        public SqlConditionHelper(ITableFieldConditionDecision conditionDecision) {
            this.conditionDecision = conditionDecision;
        }
    
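        /**
         * Adds the given where condition to a SQL statement
         *
         * @param sqlStatement the parsed statement to modify
         * @param fieldName    name of the filter field
         * @param fieldValue   value of the filter field
         */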
        public void addStatementCondition(SQLStatement sqlStatement, String fieldName, String fieldValue) {
            if (sqlStatement instanceof SQLSelectStatement) {
                SQLSelectQueryBlock queryObject = (SQLSelectQueryBlock) ((SQLSelectStatement) sqlStatement).getSelect().getQuery();
                addSelectStatementCondition(queryObject, queryObject.getFrom(), fieldName, fieldValue);
            } else if (sqlStatement instanceof SQLUpdateStatement) {
                SQLUpdateStatement updateStatement = (SQLUpdateStatement) sqlStatement;
                addUpdateStatementCondition(updateStatement, fieldName, fieldValue);
            } else if (sqlStatement instanceof SQLDeleteStatement) {
                SQLDeleteStatement deleteStatement = (SQLDeleteStatement) sqlStatement;
                addDeleteStatementCondition(deleteStatement, fieldName, fieldValue);
            } else if (sqlStatement instanceof SQLInsertStatement) {
                SQLInsertStatement insertStatement = (SQLInsertStatement) sqlStatement;
                addInsertStatementCondition(insertStatement, fieldName, fieldValue);
            }
        }
    
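        /**
         * Adds the where condition to the select part of an insert statement, if it has one
         *
         * @param insertStatement the insert statement to modify
         * @param fieldName       name of the filter field
         * @param fieldValue      value of the filter field
         */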
        private void addInsertStatementCondition(SQLInsertStatement insertStatement, String fieldName, String fieldValue) {
            if (insertStatement != null) {
                SQLInsertInto sqlInsertInto = insertStatement;
                SQLSelect sqlSelect = sqlInsertInto.getQuery();
                if (sqlSelect != null) {
                    SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) sqlSelect.getQuery();
                    addSelectStatementCondition(selectQueryBlock, selectQueryBlock.getFrom(), fieldName, fieldValue);
                }
            }
        }
    
    
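        /**
         * Adds the where condition to a delete statement
         *
         * @param deleteStatement the delete statement to modify
         * @param fieldName       name of the filter field
         * @param fieldValue      value of the filter field
         */
        private void addDeleteStatementCondition(SQLDeleteStatement deleteStatement, String fieldName, String fieldValue) {
            SQLExpr where = deleteStatement.getWhere();
            // Add the condition to subqueries inside the where clause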
            addSQLExprCondition(where, fieldName, fieldValue);
    
            SQLExpr newCondition = newEqualityCondition(deleteStatement.getTableName().getSimpleName(),
                    deleteStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
            deleteStatement.setWhere(newCondition);
    
        }
    
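        /**
         * Adds the given filter condition to the subqueries found inside a where expression
         *
         * @param where      the original where condition
         * @param fieldName  name of the filter field
         * @param fieldValue value of the filter field
         */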
        private void addSQLExprCondition(SQLExpr where, String fieldName, String fieldValue) {
            if (where instanceof SQLInSubQueryExpr) {
                SQLInSubQueryExpr inWhere = (SQLInSubQueryExpr) where;
                SQLSelect subSelectObject = inWhere.getSubQuery();
                SQLSelectQueryBlock subQueryObject = (SQLSelectQueryBlock) subSelectObject.getQuery();
                addSelectStatementCondition(subQueryObject, subQueryObject.getFrom(), fieldName, fieldValue);
            } else if (where instanceof SQLBinaryOpExpr) {
                SQLBinaryOpExpr opExpr = (SQLBinaryOpExpr) where;
                SQLExpr left = opExpr.getLeft();
                SQLExpr right = opExpr.getRight();
                addSQLExprCondition(left, fieldName, fieldValue);
                addSQLExprCondition(right, fieldName, fieldValue);
            } else if (where instanceof SQLQueryExpr) {
                SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) (((SQLQueryExpr) where).getSubQuery()).getQuery();
                addSelectStatementCondition(selectQueryBlock, selectQueryBlock.getFrom(), fieldName, fieldValue);
            }
        }
    
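        /**
         * Adds the where condition to an update statement
         *
         * @param updateStatement the update statement to modify
         * @param fieldName       name of the filter field
         * @param fieldValue      value of the filter field
         */
        private void addUpdateStatementCondition(SQLUpdateStatement updateStatement, String fieldName, String fieldValue) {
            SQLExpr where = updateStatement.getWhere();
            // Add the condition to subqueries inside the where clause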
            addSQLExprCondition(where, fieldName, fieldValue);
            SQLExpr newCondition = newEqualityCondition(updateStatement.getTableName().getSimpleName(),
                    updateStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
            updateStatement.setWhere(newCondition);
        }
    
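        /**
         * Adds a where condition to a query object
         *
         * @param queryObject the query block whose where clause is extended
         * @param from        the table source currently being visited
         * @param fieldName   name of the filter field
         * @param fieldValue  value of the filter field
         */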
        private void addSelectStatementCondition(SQLSelectQueryBlock queryObject, SQLTableSource from, String fieldName, String fieldValue) {
            if (StringUtils.isBlank(fieldName) || from == null || queryObject == null) return;
    
            SQLExpr originCondition = queryObject.getWhere();
            if (from instanceof SQLExprTableSource) {
                String tableName = ((SQLIdentifierExpr) ((SQLExprTableSource) from).getExpr()).getName();
                String alias = from.getAlias();
                SQLExpr newCondition = newEqualityCondition(tableName, alias, fieldName, fieldValue, originCondition);
                queryObject.setWhere(newCondition);
            } else if (from instanceof SQLJoinTableSource) {
                SQLJoinTableSource joinObject = (SQLJoinTableSource) from;
                SQLTableSource left = joinObject.getLeft();
                SQLTableSource right = joinObject.getRight();
    
                addSelectStatementCondition(queryObject, left, fieldName, fieldValue);
                addSelectStatementCondition(queryObject, right, fieldName, fieldValue);
    
            } else if (from instanceof SQLSubqueryTableSource) {
                SQLSelect subSelectObject = ((SQLSubqueryTableSource) from).getSelect();
                SQLSelectQueryBlock subQueryObject = (SQLSelectQueryBlock) subSelectObject.getQuery();
                addSelectStatementCondition(subQueryObject, subQueryObject.getFrom(), fieldName, fieldValue);
            } else {
                throw new NotImplementedException("Unsupported table source type: " + from.getClass().getName());
            }
        }
    
        /**
         * Builds a new condition on top of the original one
         *
         * @param tableName       table name
         * @param tableAlias      table alias
         * @param fieldName       name of the filter field
         * @param fieldValue      value of the filter field
         * @param originCondition the original where condition, may be null
         * @return the combined condition, or originCondition if no filter applies
         */
        private SQLExpr newEqualityCondition(String tableName, String tableAlias, String fieldName, String fieldValue, SQLExpr originCondition) {
            // The decision maker says this table needs no condition
            if (!conditionDecision.adjudge(tableName, fieldName)) return originCondition;
            // The value is null and null values are not allowed
            if (fieldValue == null && !conditionDecision.isAllowNullValue()) return originCondition;
    
            // Qualify the field with the table alias when there is one
            String qualifiedField = StringUtils.isBlank(tableAlias) ? fieldName : tableAlias + "." + fieldName;
            SQLExpr condition = new SQLBinaryOpExpr(new SQLIdentifierExpr(qualifiedField), new SQLCharExpr(fieldValue), SQLBinaryOperator.Equality);
            return SQLUtils.buildCondition(SQLBinaryOperator.BooleanAnd, condition, false, originCondition);
        }
    
    
        public static void main(String[] args) {
    //        String sql = "select * from user s  ";
    //        String sql = "select * from user s where s.name='333'";
    //        String sql = "select * from (select * from tab t where id = 2 and name = 'wenshao') s where s.name='333'";
    //        String sql="select u.*,g.name from user u join user_group g on u.groupId=g.groupId where u.name='123'";
    
    //        String sql = "update user set name=? where id =(select id from user s)";
    //        String sql = "delete from user where id = ( select id from user s )";
    
    //        String sql = "insert into user (id,name) select g.id,g.name from user_group g where id=1";
    
            String sql = "select u.*,g.name from user u join (select * from user_group g  join user_role r on g.role_code=r.code  ) g on u.groupId=g.groupId where u.name='123'";
            List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.POSTGRESQL);
            SQLStatement sqlStatement = statementList.get(0);
            // Define the decision maker
            SqlConditionHelper helper = new SqlConditionHelper(new ITableFieldConditionDecision() {
                @Override
                public boolean adjudge(String tableName, String fieldName) {
                    return true;
                }
    
                @Override
                public boolean isAllowNullValue() {
                    return false;
                }
            });
            // Add the multi-tenant condition: domain is the field name, yay is the filter value
            helper.addStatementCondition(sqlStatement, "domain", "yay");
            System.out.println("Original SQL: " + sql);
            System.out.println("Rewritten SQL: " + SQLUtils.toSQLString(statementList, JdbcConstants.POSTGRESQL));
        }
    
    
    }

    Because of time and environment constraints, this is only a basic version and may not be tested thoroughly; corrections are welcome.

  • Original article: https://www.cnblogs.com/yuananyun/p/8093853.html