Spring Boot 使用Mybatis拦截器结合Alibaba Druid解析SQL实现数据隔离
有些时候我们需要使用到Mybatis拦截器对SQL进行后期处理,比如根据用户的角色给WHERE
句子动态添加一些条件,比如如果用户是一个公司管理人员,我们希望对数据表的操作局限与该用户所对对应公司的数据。
比如一张简单的人员表:
t_person
{
long id; //主键
varchar name; //姓名
long com_id; //公司id
long dept_id; //部门id
}
当一个公司id为1
的公司管理员查询公司人员数据时,我们希望sql语句应该是这样的:
SELECT * FROM t_person WHERE com_id=1
这样查询出来的数据都是本公司的数据。我们可以在每个查询的地方设置这样的条件,比如Mybatis的Example。
当然我们最好把重复的代码抽象出来,便于复用。
这里我们使用Mybaits 拦截器来做。
思路:
- 注入一个Mybaits拦截器
- 重写拦截方法,获取一个标志位是否需要进行数据隔离,是进入3步,否则跳到4步
- 获取当前用户信息,如果是公司管理员,获取该公司管理员的公司id(假设为
1
),在拦截器获取拦截到SQL的语句加上WHERE com_id=1
- 返回处理后的
SQL
语句
伪代码:
@Intercepts({
@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
})
@Slf4j
@Component
public class SqlInterceptor implements Interceptor {
/**
* 拦截sql,并设置约束条件
*
* @param invocation
* @return
* @throws Throwable
*/
@Override
public Object intercept(Invocation invocation) throws Throwable {
boolean flag=XXXX.XXX();
//获取一个标志位是否需要进行数据隔离
if(flag){
//获取当前用户信息
User ust=XXX.getCurrentUser();
//如果是公司管理员
if(user.isCompanyAdmin()){
//获取该公司管理员的公司id
Long companyId=user.getCompanyId();
//拦截器获取拦截到SQL的语句加上限制条件
String newSql=oldSql+"where com_id="+companyId;
//设置经过处理的SQL语句
invocation.XXX(newSql);
}
}
//返回处理后的SQL语句
return invocation.proceed();
}
}
改进:
在拦截器中我们其实不应把业务逻辑写在里面,我们想要是Mybatis拦截器获取根据是否进行数据隔离条件,如果是则获取约束的条件,把约束条件加入到WHERE
的条件中。
@Override
public Object intercept(Invocation invocation) throws Throwable {
boolean flag=XXXX.XXX();
//获取一个标志位是否需要进行数据隔离
if(flag){
//获取约束条件
Constraint constraint=XXXX.getCurrentConstraint();
//根据当前sql语句和约束拼接sql语句
String newSql=concatConstraint(oldSql,constraint);
invocation.XXX(newSql);
}
//返回处理后的SQL语句
return invocation.proceed();
}
问题:
对于Constraint
我们Mybaitis拦截器想要一个Map, 该Map的key是字段名,value为字段值。拦截器处理的条件为AND连接,key和value通过‘=’连接。
比如Constraint
可能的结构:
Constraint
{
Map<String,Object> map;
}
假设map内容为["com_id":1],有如下处理代码:
//获取约束条件
Constraint constraint=XXXX.getCurrentConstraint();
//根据当前sql语句和约束拼接sql语句
String newSql=concatConstraint(oldSql,constraint);
假设oldSql
内容为SELECT * FROM t_person WHERE name like '张三%'
。
那么上面那段代码根据上面假设的约束条件constraint处理后的SQL语句(即newSql):
SELECT * FROM t_person WHERE name like '张三%' AND com_id=1
`
再比如oldSql
内容为SELECT * FROM t_person
。那么处理后的SQL语句应该为SELECT * FROM t_person WHERE com_id=1
显然我们需要解析SQL
语句结构根据是否有WHERE
关键动态修改的条件,解决什么时候要加WHERE com_id=1
还是加AND com_id=1
。
引入Alibaba Druid:
为了解决我们改进内容中问题,我们引入阿里巴巴Druid项目,来解析SQL结构,方便我们修改条件:
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
<version>1.1.20</version>
</dependency>
先看如何约束条件如何表示:
@Data
public class ConstraintContext {
/**
* @author Shen Zhifeng
* @version 1.0.0
* @class MethodType
* @classdesc 约束语句类型
* @date 2020/9/8 17:44
* @see
* @since
*/
public enum SqlType {
//选择语句
SELECT,
//更新语句
UPDATE,
//删除语句
DELETE
}
/**
* 构建一个约束
*
* @param sqlType 约束语句类型
* @param dbType 数据库类型
* @param constraintsMap 约束键值
* @param constraintString 约束sql语句
* @return * @return: null
* @throws java.lang.IllegalArgumentException dbType为空
* @see
* @since
**/
public ConstraintContext(ConstraintContext.SqlType sqlType, String dbType, Map<String, Object> constraintsMap, String constraintString) {
Assert.notBlank(dbType, "dbType不能为空");
this.sqlType = sqlType;
this.dbType = dbType;
this.constraintsMap = constraintsMap;
this.constraintString = constraintString;
}
/**
* 构建一个约束
*
* @param dbType 数据库类型
* @param constraintsMap 约束键值
* @return * @return: null
* @throws java.lang.IllegalArgumentException dbType为空
* @see
* @since
**/
public ConstraintContext(String dbType, Map<String, Object> constraintsMap) {
this(null, dbType, constraintsMap, null);
}
/**
* 构建一个约束
*
* @param dbType 数据库类型
* @param constraintString 约束sql语句
* @return * @return: null
* @throws java.lang.IllegalArgumentException dbType为空
* @see
* @since
**/
public ConstraintContext(String dbType, String constraintString) {
this(null, dbType, null, constraintString);
}
/**
* 要插入约束的语句类型
*/
private final SqlType sqlType;
/**
* 数据库类型
* oracle AliOracle mysql mariadb h2 postgresql edb sqlserver jtds db2 odps phoenix
*/
private final String dbType;
/**
* 约束条件
*/
private final Map<String, Object> constraintsMap;
/**
* 自定义的约束条件
*/
private final String constraintString;
}
约束条件内容主要是有数据库内容,要拦截SQL语句类型,约束条件Map(改进中提到的),约束字符串(可以自定拦截内容(直接拼接到where),比约束条件Map相对起来比较灵活)。
我们看一下如何保存约束条件,通常约束条件保存在RequestAttributes
。我们使用一个帮助类来完成设置和清除约束条件:
@Slf4j
public class ConstraintHelper {
private static final String CONTEXT_KET = "bes_constrain";
private ConstraintHelper() {
}
/**
* 设置当前的约束条件上下文
*
* @param context
* @return boss.xtrain.core.mybatis.interceptor.ConstraintContext
* @throws
* @see
* @since
**/
public static ConstraintContext setContext(ConstraintContext context) {
Assert.notNull(context, "context为null");
RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
attributes.setAttribute(CONTEXT_KET, context, RequestAttributes.SCOPE_REQUEST);
return context;
}
/**
* 获取当前的约束条件上下文
*
* @return boss.xtrain.core.mybatis.interceptor.ConstraintContext
* @throws
* @see
* @since
**/
public static ConstraintContext getContext() {
try {
RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
return (ConstraintContext) attributes.getAttribute(CONTEXT_KET, RequestAttributes.SCOPE_REQUEST);
} catch (IllegalStateException e) {
//不是走Controller方法进入,调用dao层代码,RequestContextHolder获取不到RequestAttributes会抛异常,这里捕获异常防止传播
//因此对于需要数据隔离必须要走Controller方法进入
log.error("数据约束失效:{}", e.getMessage());
}
return null;
}
/**
* 清除当前的约束条件上下文
*
* @return void
* @throws
* @see
* @since
**/
public static void clearContext() {
RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
attributes.removeAttribute(CONTEXT_KET, RequestAttributes.SCOPE_REQUEST);
}
}
最后我们看下完整拦截器代码:
@Intercepts({
@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
})
@Slf4j
@Component
public class SqlInterceptor implements Interceptor {
/**
* 拦截sql,并设置约束条件
*
* @param invocation
* @return
* @throws Throwable
*/
@Override
public Object intercept(Invocation invocation) throws Throwable {
ConstraintContext constraints = ConstraintHelper.getContext();
//没有设置约束上下文的sql不进行拦截
if (constraints != null) {
String constraintSql = getConstraintString(constraints.getConstraintsMap(), constraints.getConstraintString());
if (constraintSql != null) {
StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
MetaObject metaStatementHandler = SystemMetaObject.forObject(statementHandler);
//获取sql
String oldSql = String.valueOf(metaStatementHandler.getValue("delegate.boundSql.sql"));
if (log.isDebugEnabled()) {
log.info("拦截器处理前的sql语句为" + oldSql);
}
String newSql = contactConditions(oldSql, constraintSql, constraints.getDbType(), constraints.getSqlType());
//重新设置sql
metaStatementHandler.setValue("delegate.boundSql.sql", newSql);
if (log.isDebugEnabled()) {
log.info("经拦截器处理后的sql语句为" + newSql);
}
}
}
return invocation.proceed();
}
/**
* 根据map或constraintString生成约束sql字符串,优先使用map(如果设置的话)
* map的约束条件,默认and连接,=约束
*
* @param map 约束map
* @param constraintString 约束sql
* @return
*/
private String getConstraintString(Map<String, Object> map, String constraintString) {
if (map != null && !map.isEmpty()) {
StringBuilder constraintsBuffer = new StringBuilder();
Set<String> keys = map.keySet();
Iterator<String> keyIter = keys.iterator();
if (keyIter.hasNext()) {
String key = keyIter.next();
constraintsBuffer.append(key).append(" = " + getSqlByClass(map.get(key)));
}
while (keyIter.hasNext()) {
String key = keyIter.next();
constraintsBuffer.append(" AND ").append(key).append(" = " + getSqlByClass(map.get(key)));
}
return constraintsBuffer.toString();
}
if (!StrUtil.isBlank(constraintString)) {
return constraintString;
}
return null;
}
@Override
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
@Override
public void setProperties(Properties properties) {
//nothing to do
}
/**
* 根据语句类型和sqlType决定是否添加约束条件
*
* @param oldSql 旧sql
* @param constraintSql 约束sql
* @param dbType 数据库类型
* @param sqlType 语句类型
* @return
*/
private static String contactConditions(String oldSql, String constraintSql, String dbType, ConstraintContext.SqlType sqlType) {
SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(oldSql, dbType);
List<SQLStatement> stmtList = parser.parseStatementList();
SQLStatement stmt = stmtList.get(0);
SQLExprParser constraintsParser = SQLParserUtils.createExprParser(constraintSql, dbType);
SQLExpr constraintsExpr = constraintsParser.expr();
//选择语句
boolean useSelection = (sqlType == null || sqlType == ConstraintContext.SqlType.SELECT) && stmt instanceof SQLSelectStatement;
if (useSelection) {
SQLSelectStatement selectStmt = (SQLSelectStatement) stmt;
// 拿到SQLSelect
SQLSelect sqlselect = selectStmt.getSelect();
SQLSelectQueryBlock query = (SQLSelectQueryBlock) sqlselect.getQuery();
SQLExpr whereExpr = query.getWhere();
// 修改where表达式
if (whereExpr == null) {
query.setWhere(constraintsExpr);
} else {
SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr);
query.setWhere(newWhereExpr);
}
sqlselect.setQuery(query);
return sqlselect.toString();
}
//更新语句
boolean useUpgrade = (sqlType == null || sqlType == ConstraintContext.SqlType.UPDATE) && stmt instanceof SQLUpdateStatement;
if (useUpgrade) {
SQLUpdateStatement updateStatement = (SQLUpdateStatement) stmt;
// 拿到SQLSelect
SQLExpr whereExpr = updateStatement.getWhere();
// 修改where表达式
if (whereExpr == null) {
updateStatement.setWhere(constraintsExpr);
} else {
SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr);
updateStatement.setWhere(newWhereExpr);
}
return updateStatement.toString();
}
//删除语句
boolean useDeleting = (sqlType == null || sqlType == ConstraintContext.SqlType.DELETE) && stmt instanceof SQLDeleteStatement;
if (useDeleting) {
SQLDeleteStatement deleteStatement = (SQLDeleteStatement) stmt;
// 拿到SQLSelect
SQLExpr whereExpr = deleteStatement.getWhere();
// 修改where表达式
if (whereExpr == null) {
deleteStatement.setWhere(constraintsExpr);
} else {
SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr);
deleteStatement.setWhere(newWhereExpr);
}
return deleteStatement.toString();
}
return oldSql;
}
/**
* 转换java对象到sql支持的类型值
*
* @param value
* @return
*/
private static String getSqlByClass(Object value) {
if (value instanceof Number) {
return value + "";
} else if (value instanceof String) {
return "'" + value + "'";
}
return "'" + value.toString() + "'";
}
}
我们为SQL语句设置where子句条件时,如果SQL语句本身没有WHERE
关键字,我们就设置一个否则把原来WHERE
条件拼接上我们的约束条件。解析SQL语句并结构化减少我们代码工作量。
if (useSelection) {
SQLSelectStatement selectStmt = (SQLSelectStatement) stmt;
// 拿到SQLSelect
SQLSelect sqlselect = selectStmt.getSelect();
SQLSelectQueryBlock query = (SQLSelectQueryBlock) sqlselect.getQuery();
SQLExpr whereExpr = query.getWhere();
// 修改where表达式
if (whereExpr == null) {
query.setWhere(constraintsExpr);
} else {
SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr);
query.setWhere(newWhereExpr);
}
sqlselect.setQuery(query);
return sqlselect.toString();
}
使用:
比如我们有一个UserDao查询用户方法:
public List<User> query(){
//约束条件构建
Map<String,Object> map=new HashMap<>();
//获取当前用户
SecurityUser user = SecurityContextHolder.getCurrentUser();
//如果时超级管理员
if(user.isSystemAdmin()){
//没什么可以做的
}
//如果是公司管理员
if(user.isCompanyAdmin()){
//获取用户的公司id
Long companyId = user.getCompanyId();
//设置约束条件 相当于 com_id=${companyId}
map.put("com_id",companyId);
}else{
//其他人员就无法查询了,因为0=1条件无法满足
map.put("0",1);
}
ConstraintContext context = new ConstraintContext("mysql",map);
//设置条件
ConstraintHelper.setContext(context);
List<User> users = mapper.selectAll();
//清除条件,防止约束条件污染其他语句
ConstraintHelper.clearContext();
return users;
}
以下代码其实可以进行复用,可以做一个切面类配和注解解决。就不在贴代码了。
public List<User> query(){
//约束条件构建
Map<String,Object> map=new HashMap<>();
//获取当前用户
SecurityUser user = SecurityContextHolder.getCurrentUser();
//如果时超级管理员
if(user.isSystemAdmin()){
//没什么可以做的
}
//如果是公司管理员
if(user.isCompanyAdmin()){
//获取用户的公司id
Long companyId = user.getCompanyId();
//设置约束条件 相当于 com_id=${companyId}
map.put("com_id",companyId);
}else{
//其他人员就无法查询了,因为0=1条件无法满足
map.put("0",1);
}
ConstraintContext context = new ConstraintContext("mysql",map);
//设置条件
ConstraintHelper.setContext(context);
//...
}