zoukankan      html  css  js  c++  java
  • myBatis学习笔记(10)——使用拦截器实现分页查询

    1. Page

    package com.sm.model;
    
    import java.util.List;
    
    /**
     * Holds the paging request (page number and page size) together with the
     * paging outcome (total record count, total page count and the rows of the
     * current page).
     *
     * @param <T> element type of the rows on the current page
     */
    public class Page<T> {

        /** Page size applied when none is set explicitly. */
        public static final int DEFAULT_PAGE_SIZE = 20;

        /** Current page number; defaults to the first page (1). */
        protected int pageNo = 1;

        /** Number of records per page. */
        protected int pageSize = DEFAULT_PAGE_SIZE;

        /** Total number of records; -1 means it still has to be queried. */
        protected long totalRecord = -1;

        /** Total number of pages; -1 means it still has to be computed. */
        protected int totalPage = -1;

        /** Records of the current page. */
        protected List<T> results;

        /** @return the current page number */
        public int getPageNo() {
            return pageNo;
        }

        /** @param pageNo the page number to request */
        public void setPageNo(int pageNo) {
            this.pageNo = pageNo;
        }

        /** @return the number of records per page */
        public int getPageSize() {
            return pageSize;
        }

        /**
         * Sets the page size and recomputes the total page count.
         *
         * @param pageSize the number of records per page
         */
        public void setPageSize(int pageSize) {
            this.pageSize = pageSize;
            computeTotalPage();
        }

        /** @return the total record count, or -1 when it has not been set */
        public long getTotalRecord() {
            return totalRecord;
        }

        /** @return the total page count, or -1 when it has not been computed */
        public int getTotalPage() {
            return totalPage;
        }

        /**
         * Sets the total record count and recomputes the total page count.
         *
         * @param totalRecord the total number of records
         */
        public void setTotalRecord(long totalRecord) {
            this.totalRecord = totalRecord;
            computeTotalPage();
        }

        /**
         * Recomputes {@link #totalPage} as the record count divided by the page
         * size, rounded up. Nothing changes unless the page size is positive and
         * the record count is greater than -1.
         */
        protected void computeTotalPage() {
            int size = getPageSize();
            long records = getTotalRecord();
            if (size <= 0 || records <= -1) {
                return;
            }
            long fullPages = records / size;
            boolean hasPartialPage = records % size != 0;
            this.totalPage = (int) (hasPartialPage ? fullPages + 1 : fullPages);
        }

        /** @return the records of the current page */
        public List<T> getResults() {
            return results;
        }

        /** @param results the records of the current page */
        public void setResults(List<T> results) {
            this.results = results;
        }

        /**
         * Renders all fields; a negative total record or total page value is
         * printed as {@code null}.
         */
        @Override
        public String toString() {
            String recordText = totalRecord < 0 ? "null" : String.valueOf(totalRecord);
            String pageText = totalPage < 0 ? "null" : String.valueOf(totalPage);
            return "Page [pageNo=" + pageNo
                    + ", pageSize=" + pageSize
                    + ", totalRecord=" + recordText
                    + ", totalPage=" + pageText
                    + ", results=" + results
                    + "]";
        }
    }

    2. 实现拦截器

    package com.sm.model;
    
    import java.lang.reflect.Field;
    import java.sql.Connection;
    import java.sql.PreparedStatement;
    import java.sql.ResultSet;
    import java.sql.SQLException;
    import java.util.List;
    import java.util.Map;
    import java.util.Properties;
    
    import org.apache.ibatis.executor.Executor;
    import org.apache.ibatis.executor.parameter.DefaultParameterHandler;
    import org.apache.ibatis.executor.parameter.ParameterHandler;
    import org.apache.ibatis.executor.statement.RoutingStatementHandler;
    import org.apache.ibatis.executor.statement.StatementHandler;
    import org.apache.ibatis.mapping.BoundSql;
    import org.apache.ibatis.mapping.MappedStatement;
    import org.apache.ibatis.mapping.ParameterMapping;
    import org.apache.ibatis.plugin.Interceptor;
    import org.apache.ibatis.plugin.Intercepts;
    import org.apache.ibatis.plugin.Invocation;
    import org.apache.ibatis.plugin.Plugin;
    import org.apache.ibatis.plugin.Signature;
    import org.apache.ibatis.session.ResultHandler;
    import org.apache.ibatis.session.RowBounds;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    @Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }),
            @Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) })
    public class PageInterceptor implements Interceptor {
    
        private static final Logger log = LoggerFactory.getLogger(PageInterceptor.class);
    
        public static final String MYSQL = "mysql";
        public static final String ORACLE = "oracle";
    
        protected String databaseType;// 数据库类型。不同的数据库有不同的分页方法
    
        protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<Page>();
    
        public String getDatabaseType() {
            return databaseType;
        }
    
        public void setDatabaseType(String databaseType) {
            if (!databaseType.equalsIgnoreCase(MYSQL) && !databaseType.equalsIgnoreCase(ORACLE)) {
                throw new PageNotSupportException("Page not support for the type of database, database type [" + databaseType + "]");
            }
            this.databaseType = databaseType;
        }
    
        @Override
        public Object plugin(Object target) {
            return Plugin.wrap(target, this);
        }
    
        @Override
        public void setProperties(Properties properties) {
            String databaseType = properties.getProperty("databaseType");
            if (databaseType != null) {
                setDatabaseType(databaseType);
            }
        }
    
        @Override
        @SuppressWarnings({ "unchecked", "rawtypes" })
        public Object intercept(Invocation invocation) throws Throwable {
            if (invocation.getTarget() instanceof StatementHandler) {// 控制SQL和查询总数的地方
                Page page = pageThreadLocal.get();
                if (page == null) { //不是分页查询
                    return invocation.proceed();
                }
    
                RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
                StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
                BoundSql boundSql = delegate.getBoundSql();
                Connection connection = (Connection) invocation.getArgs()[0];
                prepareAndCheckDatabaseType(connection); // 准备数据库类型
    
                if (page.getTotalPage() > -1) {
                    if (log.isTraceEnabled()) {
                        log.trace("已经设置了总页数, 不须要再查询总数.");
                    }
                } else {
                    Object parameterObj = boundSql.getParameterObject();
                    MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
                    queryTotalRecord(page, parameterObj, mappedStatement, connection);
                }
    
                String sql = boundSql.getSql();
                String pageSql = buildPageSql(page, sql);
                if (log.isDebugEnabled()) {
                    log.debug("分页时, 生成分页pageSql: " + pageSql);
                }
                ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
    
                return invocation.proceed();
            } else { // 查询结果的地方
                // 获取是否有分页Page对象
                Page<?> page = findPageObject(invocation.getArgs()[1]);
                if (page == null) {
                    if (log.isTraceEnabled()) {
                        log.trace("没有Page对象作为參数, 不是分页查询.");
                    }
                    return invocation.proceed();
                } else {
                    if (log.isTraceEnabled()) {
                        log.trace("检測到分页Page对象, 使用分页查询.");
                    }
                }
                //设置真正的parameterObj
                invocation.getArgs()[1] = extractRealParameterObject(invocation.getArgs()[1]);
    
                pageThreadLocal.set(page);
                try {
                    Object resultObj = invocation.proceed(); // Executor.query(..)
                    if (resultObj instanceof List) {
                        /* @SuppressWarnings({ "unchecked", "rawtypes" }) */
                        page.setResults((List) resultObj);
                    }
                    return resultObj;
                } finally {
                    pageThreadLocal.remove();
                }
            }
        }
    
        protected Page<?> findPageObject(Object parameterObj) {
            if (parameterObj instanceof Page<?>) {
                return (Page<?>) parameterObj;
            } else if (parameterObj instanceof Map) {
                for (Object val : ((Map<?

    , ?

    >) parameterObj).values()) { if (val instanceof Page<?>) { return (Page<?

    >) val; } } } return null; } /** * <pre> * 把真正的參数对象解析出来 * Spring会自己主动封装对个參数对象为Map<String, Object>对象 * 对于通过@Param指定key值參数我们不做处理,由于XML文件须要该KEY值 * 而对于没有@Param指定时,Spring会使用0,1作为主键 * 对于没有@Param指定名称的參数,一般XML文件会直接对真正的參数对象解析。 * 此时解析出真正的參数作为根对象 * </pre> * @author jundong.xu_C * @param parameterObj * @return */ protected Object extractRealParameterObject(Object parameterObj) { if (parameterObj instanceof Map<?, ?>) { Map<?, ?

    > parameterMap = (Map<?, ?>) parameterObj; if (parameterMap.size() == 2) { boolean springMapWithNoParamName = true; for (Object key : parameterMap.keySet()) { if (!(key instanceof String)) { springMapWithNoParamName = false; break; } String keyStr = (String) key; if (!"0".equals(keyStr) && !"1".equals(keyStr)) { springMapWithNoParamName = false; break; } } if (springMapWithNoParamName) { for (Object value : parameterMap.values()) { if (!(value instanceof Page<?

    >)) { return value; } } } } } return parameterObj; } protected void prepareAndCheckDatabaseType(Connection connection) throws SQLException { if (databaseType == null) { String productName = connection.getMetaData().getDatabaseProductName(); if (log.isTraceEnabled()) { log.trace("Database productName: " + productName); } productName = productName.toLowerCase(); if (productName.indexOf(MYSQL) != -1) { databaseType = MYSQL; } else if (productName.indexOf(ORACLE) != -1) { databaseType = ORACLE; } else { throw new PageNotSupportException("Page not support for the type of database, database product name [" + productName + "]"); } if (log.isInfoEnabled()) { log.info("自己主动检測到的数据库类型为: " + databaseType); } } } /** * <pre> * 生成分页SQL * </pre> * * @author jundong.xu_C * @param page * @param sql * @return */ protected String buildPageSql(Page<?> page, String sql) { if (MYSQL.equalsIgnoreCase(databaseType)) { return buildMysqlPageSql(page, sql); } else if (ORACLE.equalsIgnoreCase(databaseType)) { return buildOraclePageSql(page, sql); } return sql; } /** * <pre> * 生成Mysql分页查询SQL * </pre> * * @author jundong.xu_C * @param page * @param sql * @return */ protected String buildMysqlPageSql(Page<?> page, String sql) { // 计算第一条记录的位置,Mysql中记录的位置是从0開始的。 int offset = (page.getPageNo() - 1) * page.getPageSize(); return new StringBuilder(sql).append(" limit ").append(offset).append(",").append(page.getPageSize()).toString(); } /** * <pre> * 生成Oracle分页查询SQL * </pre> * * @author jundong.xu_C * @param page * @param sql * @return */ protected String buildOraclePageSql(Page<?> page, String sql) { // 计算第一条记录的位置。Oracle分页是通过rownum进行的。而rownum是从1開始的 int offset = (page.getPageNo() - 1) * page.getPageSize() + 1; StringBuilder sb = new StringBuilder(sql); sb.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize()); sb.insert(0, "select * from (").append(") where r >= ").append(offset); return sb.toString(); } /** * <pre> * 查询总数 * </pre> * * @author 
jundong.xu_C * @param page * @param parameterObject * @param mappedStatement * @param connection * @throws SQLException */ protected void queryTotalRecord(Page<?> page, Object parameterObject, MappedStatement mappedStatement, Connection connection) throws SQLException { BoundSql boundSql = mappedStatement.getBoundSql(page); String sql = boundSql.getSql(); String countSql = this.buildCountSql(sql); if (log.isDebugEnabled()) { log.debug("分页时, 生成countSql: " + countSql); } List<ParameterMapping> parameterMappings = boundSql.getParameterMappings(); BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject); ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql); PreparedStatement pstmt = null; ResultSet rs = null; try { pstmt = connection.prepareStatement(countSql); parameterHandler.setParameters(pstmt); rs = pstmt.executeQuery(); if (rs.next()) { long totalRecord = rs.getLong(1); page.setTotalRecord(totalRecord); } } finally { if (rs != null) try { rs.close(); } catch (Exception e) { if (log.isWarnEnabled()) { log.warn("关闭ResultSet时异常.", e); } } if (pstmt != null) try { pstmt.close(); } catch (Exception e) { if (log.isWarnEnabled()) { log.warn("关闭PreparedStatement时异常.", e); } } } } /** * 依据原Sql语句获取相应的查询总记录数的Sql语句 * * @param sql * @return */ protected String buildCountSql(String sql) { int index = sql.indexOf("from"); return "select count(*) " + sql.substring(index); } /** * 利用反射进行操作的一个工具类 * */ private static class ReflectUtil { /** * 利用反射获取指定对象的指定属性 * * @param obj 目标对象 * @param fieldName 目标属性 * @return 目标属性的值 */ public static Object getFieldValue(Object obj, String fieldName) { Object result = null; Field field = ReflectUtil.getField(obj, fieldName); if (field != null) { field.setAccessible(true); try { result = field.get(obj); } catch (IllegalArgumentException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (IllegalAccessException 
e) { // TODO Auto-generated catch block e.printStackTrace(); } } return result; } /** * 利用反射获取指定对象里面的指定属性 * * @param obj 目标对象 * @param fieldName 目标属性 * @return 目标字段 */ private static Field getField(Object obj, String fieldName) { Field field = null; for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) { try { field = clazz.getDeclaredField(fieldName); break; } catch (NoSuchFieldException e) { // 杩欓噷涓嶇敤鍋氬�鐞嗭紝瀛愮被娌℃湁璇ュ瓧娈靛彲鑳藉�搴旂殑鐖剁被鏈夛紝閮芥病鏈夊氨杩斿洖null銆� } } return field; } /** * 利用反射设置指定对象的指定属性为指定的值 * * @param obj 目标对象 * @param fieldName 目标属性 * @param fieldValue 目标值 */ public static void setFieldValue(Object obj, String fieldName, String fieldValue) { Field field = ReflectUtil.getField(obj, fieldName); if (field != null) { try { field.setAccessible(true); field.set(obj, fieldValue); } catch (IllegalArgumentException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (IllegalAccessException e) { // TODO Auto-generated catch block e.printStackTrace(); } } } } public static class PageNotSupportException extends RuntimeException { public PageNotSupportException() { super(); } public PageNotSupportException(String message, Throwable cause) { super(message, cause); } public PageNotSupportException(String message) { super(message); } public PageNotSupportException(Throwable cause) { super(cause); } } }

    3. spring配置文件(mybatis已和spring整合)

        <!-- Configure the MyBatis sqlSessionFactory -->
        <bean id="sqlSessionFactoryBean" class="org.mybatis.spring.SqlSessionFactoryBean">
            <property name="dataSource" ref="dataSource"></property>
            <!-- With typeAliasesPackage set, mapper files can refer to the entity classes in this package without their fully qualified names -->
            <property name="typeAliasesPackage" value="com.sm.model"></property>
            <!-- Location of the mapper XML files -->
            <property name="mapperLocations" value="classpath:resources/mapper/*.xml"></property>
            <property name="plugins">
                <!-- Pagination interceptor -->
                <bean class="com.sm.model.PageInterceptor"></bean>
            </property>
        </bean>

    4. mapper.xml

    <!-- Selects users by username; the parameter is a Map whose "user" entry supplies the username. The statement itself contains no paging clause. -->
    <select id="getUsers" resultType="User" parameterType="Map">
        select * from user where username=#{user.username}
    </select>

    6. DAO

    /**
     * Queries users with the parameters held in {@code map}.
     *
     * @param map query parameters; the usage example in this article stores a
     *            {@code User} under "user" and a {@code Page} under "page"
     * @return the users returned by the query
     */
    List<User> getUsers(Map map);

    7. 测试

    // Parameterized types instead of raw Page/Map, so the compiler checks the element types.
    Page<User> page = new Page<>();
    // Paging parameters
    page.setPageNo(1);
    page.setPageSize(3);
    // Query condition
    User user = new User();
    user.setUsername("2");

    // Both the condition and the page travel in the same parameter map.
    Map<String, Object> map = new HashMap<>();
    map.put("user", user);
    map.put("page", page);

    List<User> list = userDAO.getUsers(map);
    System.out.println(list);
    System.out.println(page);

    这里写图片描述

    8. 总结

    上面的分页拦截器,拷下来直接用就好了。如果想了解实现原理,可以看慕课网的视频《通过自动回复机器人学Mybatis—加强版》。

  • 相关阅读:
    Python 工程管理及 virtualenv 的迁移
    Python基础系列讲解——random模块随机数的生成
    Python进阶量化交易场外篇5——标记A股市场涨跌周期
    Python学习案例之视频人脸检测识别
    基于python的Splash基本使用和负载均衡配置
    你所听到的技术原理、技术本质到底是什么?
    BAT大厂面试流程剖析
    基于Python的ModbusTCP客户端实现
    互联网寒冬,Python 程序员如何准备面试
    ES-查询后10000条数据的设置
  • 原文地址:https://www.cnblogs.com/cynchanpin/p/7215932.html
Copyright © 2011-2022 走看看