首页 > 代码库 > mybatis拦截器分页

mybatis拦截器分页

package com.test.interceptor;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import com.mysql.jdbc.PreparedStatement;
import com.test.util.Page;

@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = { Connection.class }),
        @Signature(method = "query", type = Executor.class, args = {
                MappedStatement.class, Object.class, RowBounds.class,
                ResultHandler.class }) })
public class StatementHandleInterceptor implements Interceptor {    
    public static final String MYSQL = "mysql";    
    protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<Page>();     
   
    public Object intercept(Invocation invocation) throws Throwable {
        if (invocation.getTarget() instanceof StatementHandler){
            Page<?> page = pageThreadLocal.get();
            if(page==null){
                return invocation.proceed();
            }            
            RoutingStatementHandler statementHandler = (RoutingStatementHandler) invocation
                    .getTarget();
            StatementHandler delegate = ReflectUtil.getFieldValue(
                    statementHandler, "delegate");
            BoundSql boundSql = delegate.getBoundSql();
            Connection connection = (Connection) invocation.getArgs()[0];
            
            if(page.getTotalPage()>-1){
                System.out.println("总页数:"+page.getTotalPage());                
            }else{
                Object obj = boundSql.getParameterObject();
                MetaObject metaStatementHandler = SystemMetaObject.forObject(statementHandler);
                MappedStatement mappedStatement=(MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");                
                queryTotalRecord(page, obj, mappedStatement, connection);
            }
            String sql = boundSql.getSql();
            String pageSql = buildPageSql(page,sql);
            System.out.println("分页时,生成pageSql:"+pageSql);            
            ReflectUtil.setFieldValue((Object)boundSql, "sql",pageSql);
            return invocation.proceed();
        }else{
            Page<?> page = findPageObject(invocation.getArgs()[1]);
            if(page==null){
                System.out.println("没有page参数对象,不是分页查询");                
                return invocation.proceed();
            }else{
                System.out.println("检测到page对象!使用分页查询");                
            }            
            
            pageThreadLocal.set(page);
            try{
                return invocation.proceed();
                //可setpage  Results
                /*Object resultObj = invocation.proceed();
                if(resultObj instanceof List){
                    page.setResults((List)resultObj);
                }
                return resultObj;*/
                
            }finally{
                pageThreadLocal.remove();
            }
        }         
    }
    
     private String buildPageSql(Page page,String sql) {
        // 计算第一条记录的位置,Mysql中记录的位置是从0开始的。
        int offset = (page.getPageNo() - 1) * page.getPageSize();
        return new StringBuilder(sql).append(" limit ").append(offset)
                .append(",").append(page.getPageSize()).toString();
    }    

    
    /**
     * 判定是否需要分页拦截
     * @param object
     * @return
     */
    private Page<?> findPageObject(Object object) {
        if(object instanceof Page<?>){
            return (Page<?>) object;
        }else if(object instanceof Map){
            for(Object o:((Map<?,?>) object).values()){
                if(o instanceof Page<?>){
                    return (Page<?>) o;
                }
            }
        }
        return null;
    }
    /**
     * 查询总记录数
     * @param page
     * @param obj
     * @param mappedStatement
     * @param connection
     * @throws SQLException
     */
    private void queryTotalRecord(Page<?> page, Object obj,
            MappedStatement mappedStatement, Connection connection) throws SQLException {
        BoundSql boundSql = mappedStatement.getBoundSql(page);
        String sql = boundSql.getSql();
        String countSql = this.buildCountSql(sql);
        System.out.println("分页时,生成countSql:"+countSql);        
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(),countSql,parameterMappings,obj);
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, obj, countBoundSql);
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try{
            pstmt = (PreparedStatement) connection.prepareStatement(countSql);
            parameterHandler.setParameters(pstmt);
            rs = pstmt.executeQuery();
            if(rs.next()){
                 long totalRecord = rs.getLong(1);
                 page.setTotalRecord(totalRecord);
            }    
        }finally{
            if(rs!=null){
                rs.close();
            }
            if(pstmt!=null){
                pstmt.close();
            }
        }
        
    }
    /**
     * 构造查询总记录数sql
     * @param sql
     * @return
     */
    private String buildCountSql(String sql) {
        int index = sql.toLowerCase().indexOf("from");        
        return "select count(*)"+sql.substring(index);
    }
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    public void setProperties(Properties properties) {

    }

}

调用示例:

[图片缺失:调用代码截图]

运行结果:

[图片缺失:运行结果截图]

 

mybatis拦截器分页