zoukankan      html  css  js  c++  java
  • Mybatis拦截器之数据权限过滤与分页集成

    解决方案之改SQL

    原sql

    SELECT
    	a.id AS "id",
    	a.NAME AS "name",
    	a.sex_cd AS "sexCd",
    	a.org_id AS "orgId",
    	a.STATUS AS "status",
    	a.create_org_id AS "createOrgId"
    FROM
    	pty_person a
    WHERE
    	a. STATUS = 0

    org_id是单位的标识,也就是where条件里再加个单位标识的过滤。

    改后sql

    SELECT
    	a.id AS "id",
    	a.NAME AS "name",
    	a.sex_cd AS "sexCd",
    	a.org_id AS "orgId",
    	a.STATUS AS "status",
    	a.create_org_id AS "createOrgId"
    FROM
    	pty_person a
    WHERE
    	a. STATUS = 0
    	and a.org_id LIKE concat(710701070102, '%')

    当然通过这个办法也可以实现数据的过滤,但这样的话相比大家也都有同感,那就是每个业务模块 每个人都要进行SQL改动,这次是根据单位过滤、明天又再根据其他的属性过滤,意味着要不停的改来改去,可谓是场面壮观也,而且这种集体改造耗费了时间精力不说,还会有很多不确定因素,比如SQL写错,存在漏网之鱼等等。因此这个解决方案肯定是直接PASS掉咯;

    解决方案之拦截器

    由于项目大部分采用的持久层框架是Mybatis,也是使用的Mybatis进行分页拦截处理,因此直接采用了Mybatis拦截器实现数据权限过滤。

    1、自定义数据权限过滤注解 PermissionAop,负责过滤的开关 

    package com.raising.framework.annotation;
    
    import java.lang.annotation.*;
    
    /**
     * 数据权限过滤自定义注解
     * @author lihaoshan
     * @date 2018-07-19
     * */
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface PermissionAop {
    
        String value() default "";
    }
    

    2、定义全局配置 PermissionConfig 类加载 权限过滤配置文件

    package com.raising.framework.config;
    
    import com.raising.utils.PropertiesLoader;
    import org.apache.commons.lang.StringUtils;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    import java.util.HashMap;
    import java.util.Map;
    
    /**
     * 全局配置
     * 对应 permission.properties
     * @author lihaoshan
     */
    public class PermissionConfig {
        private static Logger logger = LoggerFactory.getLogger(PropertiesLoader.class);
    
        /**
         * 保存全局属性值
         */
        private static Map<String, String> map = new HashMap<>(16);
    
        /**
         * 属性文件加载对象
         */
        private static PropertiesLoader loader = new PropertiesLoader(
                "permission.properties");
    
        /**
         * 获取配置
         */
        public static String getConfig(String key) {
            if(loader == null){
                logger.info("缺失配置文件 - permission.properties");
                return null;
            }
            String value = map.get(key);
            if (value == null) {
                value = loader.getProperty(key);
                map.put(key, value != null ? value : StringUtils.EMPTY);
            }
            return value;
        }
    
    }

    3、创建权限过滤的配置文件 permission.properties,用于配置需要拦截的DAO的 namespace

    (由于注解@PermissionAop是加在DAO层某个接口上的,而我们分页接口为封装的公共BaseDAO,所以如果仅仅使用注解方式开关拦截的话,会影响到所有的业务模块,因此需要结合额外的配置文件)

    # 需要进行拦截的SQL所属namespace
    permission.intercept.namespace=com.raising.modules.pty.dao.PtyGroupDao,com.raising.modules.pty.dao.PtyPersonDao

    4、自定义权限工具类

    根据 StatementHandler 获取Permission注解对象:

    package com.raising.utils.permission;
    
    import com.raising.framework.annotation.PermissionAop;
    import org.apache.ibatis.mapping.MappedStatement;
    
    import java.lang.reflect.Method;
    
    /**
     * 自定义权限相关工具类
     * @author lihaoshan
     * @date 2018-07-20
     * */
    public class PermissionUtils {
    
        /**
         * 根据 StatementHandler 获取 注解对象
         * @author lihaoshan
         * @date 2018-07-20
         */
        public static PermissionAop getPermissionByDelegate(MappedStatement mappedStatement){
            PermissionAop permissionAop = null;
            try {
                String id = mappedStatement.getId();
                String className = id.substring(0, id.lastIndexOf("."));
                String methodName = id.substring(id.lastIndexOf(".") + 1, id.length());
                final Class cls = Class.forName(className);
                final Method[] method = cls.getMethods();
                for (Method me : method) {
                    if (me.getName().equals(methodName) && me.isAnnotationPresent(PermissionAop.class)) {
                        permissionAop = me.getAnnotation(PermissionAop.class);
                    }
                }
            }catch (Exception e){
                e.printStackTrace();
            }
            return permissionAop;
        }
    }
    

    5、创建分页拦截器 MybatisSpringPageInterceptor 或进行改造(本文是在Mybatis分页拦截器基础上进行的数据权限拦截改造,SQL包装一定要在执行分页之前,也就是获取到原始SQL后就进行数据过滤包装) 

    首先看数据权限拦截核心代码:

    • 获取需要进行拦截的DAO层namespace拼接串;
    • 获取当前mapped所属namespace;
    • 判断配置文件中的namespace是否包含当前的mapped所属的namespace,如果包含则继续,否则直接放行;
    • 获取数据权限注解对象,及注解的值;
    • 判断注解值是否为DATA_PERMISSION_INTERCEPT,是则拦截、并进行过滤SQL包装,否则放行;
    • 根据包装后的SQL查分页总数,不能使用原始SQL进行查询;
    • 执行请求方法,获取拦截后的分页结果;

    执行流程图:

    拦截器源码:

    package com.raising.framework.interceptor;
    
    import com.raising.StaticParam;
    import com.raising.framework.annotation.PermissionAop;
    import com.raising.framework.config.PermissionConfig;
    import com.raising.modules.sys.entity.User;
    import com.raising.utils.JStringUtils;
    import com.raising.utils.UserUtils;
    import com.raising.utils.permission.PermissionUtils;
    import org.apache.ibatis.executor.Executor;
    import org.apache.ibatis.executor.parameter.ParameterHandler;
    import org.apache.ibatis.executor.statement.RoutingStatementHandler;
    import org.apache.ibatis.executor.statement.StatementHandler;
    import org.apache.ibatis.mapping.BoundSql;
    import org.apache.ibatis.mapping.MappedStatement;
    import org.apache.ibatis.mapping.ParameterMapping;
    import org.apache.ibatis.plugin.*;
    import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
    import org.apache.ibatis.session.ResultHandler;
    import org.apache.ibatis.session.RowBounds;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    import java.lang.reflect.Field;
    import java.sql.Connection;
    import java.sql.PreparedStatement;
    import java.sql.ResultSet;
    import java.sql.SQLException;
    import java.util.List;
    import java.util.Map;
    import java.util.Properties;
    
    /**
     * 分页拦截器
     * @author GaoYuan
     * @author lihaoshan 增加了数据权限的拦截过滤
     * @datetime 2017/12/1 下午5:43
     */
    @Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }),
            @Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) })
    public class MybatisSpringPageInterceptor implements Interceptor {
        private static final Logger log = LoggerFactory.getLogger(MybatisSpringPageInterceptor.class);
    
        public static final String MYSQL = "mysql";
        public static final String ORACLE = "oracle";
        /**数据库类型,不同的数据库有不同的分页方法*/
        protected String databaseType;
    
        @SuppressWarnings("rawtypes")
        protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<Page>();
    
        public String getDatabaseType() {
            return databaseType;
        }
    
        public void setDatabaseType(String databaseType) {
            if (!databaseType.equalsIgnoreCase(MYSQL) && !databaseType.equalsIgnoreCase(ORACLE)) {
                throw new PageNotSupportException("Page not support for the type of database, database type [" + databaseType + "]");
            }
            this.databaseType = databaseType;
        }
    
        @Override
        public Object plugin(Object target) {
            return Plugin.wrap(target, this);
        }
    
        @Override
        public void setProperties(Properties properties) {
            String databaseType = properties.getProperty("databaseType");
            if (databaseType != null) {
                setDatabaseType(databaseType);
            }
        }
    
        @Override
        @SuppressWarnings({ "unchecked", "rawtypes" })
        public Object intercept(Invocation invocation) throws Throwable {
            // 控制SQL和查询总数的地方
            if (invocation.getTarget() instanceof StatementHandler) {
                Page page = pageThreadLocal.get();
                //不是分页查询
                if (page == null) {
                    return invocation.proceed();
                }
    
                RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
                StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
                BoundSql boundSql = delegate.getBoundSql();
    
                Connection connection = (Connection) invocation.getArgs()[0];
                // 准备数据库类型
                prepareAndCheckDatabaseType(connection);
                MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
    
                String sql = boundSql.getSql();
    
                /** 单位数据权限拦截 begin */
                //获取需要进行拦截的DAO层namespace拼接串
                String interceptNamespace = PermissionConfig.getConfig("permission.intercept.namespace");
    
                //获取当前mapped的namespace
                String mappedStatementId = mappedStatement.getId();
                String className = mappedStatementId.substring(0, mappedStatementId.lastIndexOf("."));
    
                if(JStringUtils.isNotBlank(interceptNamespace)){
                    //判断配置文件中的namespace是否与当前的mapped namespace匹配,如果包含则进行拦截,否则放行
                    if(interceptNamespace.contains(className)){
                        //获取数据权限注解对象
                        PermissionAop permissionAop = PermissionUtils.getPermissionByDelegate(mappedStatement);
                        if (permissionAop != null){
                            //获取注解的值
                            String permissionAopValue = permissionAop.value();
                            //判断注解是否开启拦截
                            if(StaticParam.DATA_PERMISSION_INTERCEPT.equals(permissionAopValue) ){
                                if(log.isInfoEnabled()){
                                    log.info("数据权限拦截【拼接SQL】...");
                                }
                                //返回拦截包装后的sql
                                sql = permissionSql(sql);
                                ReflectUtil.setFieldValue(boundSql, "sql", sql);
                            } else {
                                if(log.isInfoEnabled()){
                                    log.info("数据权限放行...");
                                }
                            }
                        }
    
                    }
                }
                /** 单位数据权限拦截 end */
    
                if (page.getTotalPage() > -1) {
                    if (log.isTraceEnabled()) {
                        log.trace("已经设置了总页数, 不需要再查询总数.");
                    }
                } else {
                    Object parameterObj = boundSql.getParameterObject();
    ///                MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
                    queryTotalRecord(page, parameterObj, mappedStatement, sql,connection);
                }
    
                String pageSql = buildPageSql(page, sql);
                if (log.isDebugEnabled()) {
                    log.debug("分页时, 生成分页pageSql......");
                }
                ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
    
                return invocation.proceed();
            } else { // 查询结果的地方
                // 获取是否有分页Page对象
                Page<?> page = findPageObject(invocation.getArgs()[1]);
                if (page == null) {
                    if (log.isTraceEnabled()) {
                        log.trace("没有Page对象作为参数, 不是分页查询.");
                    }
                    return invocation.proceed();
                } else {
                    if (log.isTraceEnabled()) {
                        log.trace("检测到分页Page对象, 使用分页查询.");
                    }
                }
                //设置真正的parameterObj
                invocation.getArgs()[1] = extractRealParameterObject(invocation.getArgs()[1]);
    
                pageThreadLocal.set(page);
                try {
                    // Executor.query(..)
                    Object resultObj = invocation.proceed();
                    if (resultObj instanceof List) {
                        /* @SuppressWarnings({ "unchecked", "rawtypes" }) */
                        page.setResults((List) resultObj);
                    }
                    return resultObj;
                } finally {
                    pageThreadLocal.remove();
                }
            }
        }
    
        protected Page<?> findPageObject(Object parameterObj) {
            if (parameterObj instanceof Page<?>) {
                return (Page<?>) parameterObj;
            } else if (parameterObj instanceof Map) {
                for (Object val : ((Map<?, ?>) parameterObj).values()) {
                    if (val instanceof Page<?>) {
                        return (Page<?>) val;
                    }
                }
            }
            return null;
        }
    
        /**
         * <pre>
         * 把真正的参数对象解析出来
         * Spring会自动封装对个参数对象为Map<String, Object>对象
         * 对于通过@Param指定key值参数我们不做处理,因为XML文件需要该KEY值
         * 而对于没有@Param指定时,Spring会使用0,1作为主键
         * 对于没有@Param指定名称的参数,一般XML文件会直接对真正的参数对象解析,
         * 此时解析出真正的参数作为根对象
         * </pre>
         * @param parameterObj
         * @return
         */
        protected Object extractRealParameterObject(Object parameterObj) {
            if (parameterObj instanceof Map<?, ?>) {
                Map<?, ?> parameterMap = (Map<?, ?>) parameterObj;
                if (parameterMap.size() == 2) {
                    boolean springMapWithNoParamName = true;
                    for (Object key : parameterMap.keySet()) {
                        if (!(key instanceof String)) {
                            springMapWithNoParamName = false;
                            break;
                        }
                        String keyStr = (String) key;
                        if (!"0".equals(keyStr) && !"1".equals(keyStr)) {
                            springMapWithNoParamName = false;
                            break;
                        }
                    }
                    if (springMapWithNoParamName) {
                        for (Object value : parameterMap.values()) {
                            if (!(value instanceof Page<?>)) {
                                return value;
                            }
                        }
                    }
                }
            }
            return parameterObj;
        }
    
        protected void prepareAndCheckDatabaseType(Connection connection) throws SQLException {
            if (databaseType == null) {
                String productName = connection.getMetaData().getDatabaseProductName();
                if (log.isTraceEnabled()) {
                    log.trace("Database productName: " + productName);
                }
                productName = productName.toLowerCase();
                if (productName.indexOf(MYSQL) != -1) {
                    databaseType = MYSQL;
                } else if (productName.indexOf(ORACLE) != -1) {
                    databaseType = ORACLE;
                } else {
                    throw new PageNotSupportException("Page not support for the type of database, database product name [" + productName + "]");
                }
                if (log.isInfoEnabled()) {
                    log.info("自动检测到的数据库类型为: " + databaseType);
                }
            }
        }
    
        /**
         * <pre>
         * 生成分页SQL
         * </pre>
         *
         * @param page
         * @param sql
         * @return
         */
        protected String buildPageSql(Page<?> page, String sql) {
            if (MYSQL.equalsIgnoreCase(databaseType)) {
                return buildMysqlPageSql(page, sql);
            } else if (ORACLE.equalsIgnoreCase(databaseType)) {
                return buildOraclePageSql(page, sql);
            }
            return sql;
        }
    
        /**
         * <pre>
         * 生成Mysql分页查询SQL
         * </pre>
         *
         * @param page
         * @param sql
         * @return
         */
        protected String buildMysqlPageSql(Page<?> page, String sql) {
            // 计算第一条记录的位置,Mysql中记录的位置是从0开始的。
            int offset = (page.getPageNo() - 1) * page.getPageSize();
            if(offset<0){
                return " limit 0 ";
            }
            return new StringBuilder(sql).append(" limit ").append(offset).append(",").append(page.getPageSize()).toString();
        }
    
        /**
         * <pre>
         * 生成Oracle分页查询SQL
         * </pre>
         *
         * @param page
         * @param sql
         * @return
         */
        protected String buildOraclePageSql(Page<?> page, String sql) {
            // 计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
            int offset = (page.getPageNo() - 1) * page.getPageSize() + 1;
            StringBuilder sb = new StringBuilder(sql);
            sb.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());
            sb.insert(0, "select * from (").append(") where r >= ").append(offset);
            return sb.toString();
        }
    
        /**
         * <pre>
         * 查询总数
         * </pre>
         *
         * @param page
         * @param parameterObject
         * @param mappedStatement
         * @param sql
         * @param connection
         * @throws SQLException
         */
        protected void queryTotalRecord(Page<?> page, Object parameterObject, MappedStatement mappedStatement, String sql, Connection connection) throws SQLException {
            BoundSql boundSql = mappedStatement.getBoundSql(page);
    ///        String sql = boundSql.getSql();
    
            String countSql = this.buildCountSql(sql);
            if (log.isDebugEnabled()) {
                log.debug("分页时, 生成countSql......");
            }
    
            List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
            BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject);
            ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);
            PreparedStatement pstmt = null;
            ResultSet rs = null;
            try {
                pstmt = connection.prepareStatement(countSql);
                parameterHandler.setParameters(pstmt);
                rs = pstmt.executeQuery();
                if (rs.next()) {
                    long totalRecord = rs.getLong(1);
                    page.setTotalRecord(totalRecord);
                }
            } finally {
                if (rs != null) {
                    try {
                        rs.close();
                    } catch (Exception e) {
                        if (log.isWarnEnabled()) {
                            log.warn("关闭ResultSet时异常.", e);
                        }
                    }
                }
                if (pstmt != null) {
                    try {
                        pstmt.close();
                    } catch (Exception e) {
                        if (log.isWarnEnabled()) {
                            log.warn("关闭PreparedStatement时异常.", e);
                        }
                    }
                }
            }
        }
    
        /**
         * 根据原Sql语句获取对应的查询总记录数的Sql语句
         *
         * @param sql
         * @return
         */
        protected String buildCountSql(String sql) {
            //查出第一个from,先转成小写
            sql = sql.toLowerCase();
            int index = sql.indexOf("from");
            return "select count(0) " + sql.substring(index);
        }
    
        /**
         * 利用反射进行操作的一个工具类
         *
         */
        private static class ReflectUtil {
            /**
             * 利用反射获取指定对象的指定属性
             *
             * @param obj 目标对象
             * @param fieldName 目标属性
             * @return 目标属性的值
             */
            public static Object getFieldValue(Object obj, String fieldName) {
                Object result = null;
                Field field = ReflectUtil.getField(obj, fieldName);
                if (field != null) {
                    field.setAccessible(true);
                    try {
                        result = field.get(obj);
                    } catch (IllegalArgumentException e) {
                        // TODO Auto-generated catch block
                        e.printStackTrace();
                    } catch (IllegalAccessException e) {
                        // TODO Auto-generated catch block
                        e.printStackTrace();
                    }
                }
                return result;
            }
    
            /**
             * 利用反射获取指定对象里面的指定属性
             *
             * @param obj 目标对象
             * @param fieldName 目标属性
             * @return 目标字段
             */
            private static Field getField(Object obj, String fieldName) {
                Field field = null;
                for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) {
                    try {
                        field = clazz.getDeclaredField(fieldName);
                        break;
                    } catch (NoSuchFieldException e) {
                        // 这里不用做处理,子类没有该字段可能对应的父类有,都没有就返回null。
                    }
                }
                return field;
            }
    
            /**
             * 利用反射设置指定对象的指定属性为指定的值
             *
             * @param obj 目标对象
             * @param fieldName 目标属性
             * @param fieldValue 目标值
             */
            public static void setFieldValue(Object obj, String fieldName, String fieldValue) {
                Field field = ReflectUtil.getField(obj, fieldName);
                if (field != null) {
                    try {
                        field.setAccessible(true);
                        field.set(obj, fieldValue);
                    } catch (IllegalArgumentException e) {
                        // TODO Auto-generated catch block
                        e.printStackTrace();
                    } catch (IllegalAccessException e) {
                        // TODO Auto-generated catch block
                        e.printStackTrace();
                    }
                }
            }
        }
    
        public static class PageNotSupportException extends RuntimeException {
    
            /** serialVersionUID*/
            private static final long serialVersionUID = 1L;
    
            public PageNotSupportException() {
                super();
            }
    
            public PageNotSupportException(String message, Throwable cause) {
                super(message, cause);
            }
    
            public PageNotSupportException(String message) {
                super(message);
            }
    
            public PageNotSupportException(Throwable cause) {
                super(cause);
            }
        }
    
        /**
         * 数据权限sql包装【只能查看本级单位及下属单位的数据】
         * @author lihaoshan
         * @date 2018-07-19
         */
        protected String permissionSql(String sql) {
            StringBuilder sbSql = new StringBuilder(sql);
            //获取当前登录人
            User user = UserUtils.getLoginUser();
            String orgId =null;
            if (user != null) {
                //获取当前登录人所属单位标识
                orgId = user.getOrganizationId();
            }
            //如果有动态参数 orgId
            if(orgId != null){
                sbSql = new StringBuilder("select * from (")
                        .append(sbSql)
                        .append(" ) s ")
                        .append(" where s.createOrgId like concat("+ orgId +",'%') ");
            }
            return sbSql.toString();
        }
    }
  • 相关阅读:
    使用百度网盘配置私有Git服务
    Linked dylibs built for GC-only but object files built for retain/release for architecture x86_64
    我的博客搬家啦!!!
    今日头条核心业务(高级)开发工程师,直接推给部门经理,HC很多,感兴趣的可以一起聊聊。
    学习Python的三种境界
    拿到阿里,网易游戏,腾讯,smartx的offer的过程
    关于计算机网络一些问题的思考
    网易游戏面试经验(三)
    网易游戏面试经验(二)
    网易游戏面试经验(一)
  • 原文地址:https://www.cnblogs.com/chenxiaoxian/p/9816817.html
Copyright © 2011-2022 走看看