zoukankan      html  css  js  c++  java
  • 基于 @SelectProvider 注解实现无侵入的通用Dao

    基于 @SelectProvider 注解实现无侵入的通用Dao

    项目框架

    基于 SpringBoot 2.x 和 mybatis-spring-boot-starter

    代码设计

    通用Dao

    /**
     * Generic mapper: the SQL of every statement is generated at runtime by BaseSqlProvider from
     * the entity type, so a concrete mapper only needs to extend this interface.
     *
     * @param <E> entity type; table and column names are derived from it (see TableEntityMetaData)
     * @param <I> primary-key type
     */
    public interface BaseDao<E,I> {
    
        /** Selects the row whose id column equals {@code id}. */
        @SelectProvider(type = BaseSqlProvider.class,method = "getById")
        E getById(I id);
    
        /**
         * Selects the rows whose columns equal the entity's non-null fields (AND-ed); if every field
         * is null the WHERE clause is empty and all rows are returned. {@code e} must not be null.
         */
        @SelectProvider(type = BaseSqlProvider.class,method = "listByEntity")
        List<E> listByEntity(E e);
    
        /** Like listByEntity, but with LIMIT 1 appended to the generated SQL. */
        @SelectProvider(type = BaseSqlProvider.class,method = "getByEntity")
        E getByEntity(E e);
    
        /**
         * Selects the rows whose column for the getter {@code lambda} (e.g. a reference such as
         * User::getUsername) equals {@code val}; {@code val} must not be null.
         * NOTE(review): the provider looks the arguments up by the names "lambda" and "val";
         * without {@code @Param} this presumably relies on compiling with -parameters.
         */
        @SelectProvider(type = BaseSqlProvider.class,method = "listByLambdaQuery")
        List<E> listByLambdaQuery(GetterFunction<E,?> lambda, Object val);
        
        /**
         * Like listByLambdaQuery, with LIMIT 1 appended to the generated SQL.
         * NOTE(review): declared as List&lt;E&gt; although the list holds at most one element;
         * E would match getById/getByEntity.
         */
        @SelectProvider(type = BaseSqlProvider.class,method = "getByLambdaQuery")
        List<E> getByLambdaQuery(GetterFunction<E,?> lambda, Object val);
    
        /** Selects the rows whose id is in {@code collection}; the collection must not be empty. */
        @SelectProvider(type = BaseSqlProvider.class,method = "listByIds")
        List<E> listByIds(Collection<I> collection);
    
        /**
         * Inserts the entity's non-null fields (null fields are left out of the statement).
         * useGeneratedKeys asks MyBatis to write the generated key back into the "id" property.
         */
        @InsertProvider(type = BaseSqlProvider.class,method = "insert")
        @Options(keyProperty="id",useGeneratedKeys=true)
        int insert(E e);
    
        /**
         * Inserts all entities with one multi-row INSERT that lists every column, so null fields
         * are inserted as NULL. The list must not be empty. Whether generated keys are written back
         * for every element presumably depends on the JDBC driver -- verify for the target database.
         */
        @InsertProvider(type = BaseSqlProvider.class,method = "insertBatch")
        @Options(keyProperty="id",useGeneratedKeys=true)
        int insertBatch(Collection<E> list);
    
        /** Updates the non-null, non-id fields of the row whose id equals the entity's id. */
        @UpdateProvider(type = BaseSqlProvider.class,method = "update")
        int update(E e);
    
        /**
         * Updates every entity by id in one statement (one CASE expression per column). Unlike
         * update, null fields ARE written, as NULL. The list must not be empty.
         */
        @UpdateProvider(type = BaseSqlProvider.class,method = "updateBatch")
        int updateBatch(Collection<E> list);
    
        /** Deletes the row whose id column equals {@code id}. */
        @DeleteProvider(type = BaseSqlProvider.class,method = "deleteById")
        int deleteById(I id);
    
        /**
         * Deletes the rows matching the entity's non-null fields.
         * WARNING: if every field is null the WHERE clause is empty and every row is deleted.
         */
        @DeleteProvider(type = BaseSqlProvider.class,method = "deleteByEntity")
        int deleteByEntity(E e);
    
        /** Deletes the rows whose id is in {@code list}; the list must not be empty. */
        @DeleteProvider(type = BaseSqlProvider.class,method = "deleteByIds")
        int deleteByIds(Collection<I> list);
    
        /** Returns the number of rows in the table. */
        @SelectProvider(type = BaseSqlProvider.class,method = "countAll")
        int countAll();
    
        /** Counts the rows matching the entity's non-null fields (all rows if none is set). */
        @SelectProvider(type = BaseSqlProvider.class,method = "countByEntity")
        int countByEntity(E e);
    
    }
    

    通用SQL Provider

    //Caches and returns the generic SQL statements built by BaseSqlBuilder.
    public class BaseSqlProvider {
    
        /**
         * Generated SQL keyed by mapper interface + mapper method name (plus the field name for the
         * lambda-query methods). The cached text holds no argument values -- those are bound later
         * through #{...} placeholders and <if>/<foreach> tags -- so one entry per statement suffices.
         * Two threads may build the same entry concurrently; the later put simply overwrites it.
         */
        private static final Map<String,String> sqlCache = new ConcurrentHashMap<>();
    
        /** Builds one SQL string; {@code X} is the checked exception the builder can throw, if any. */
        @FunctionalInterface
        private interface SqlFactory<X extends Exception> {
            String build() throws X;
        }
    
        /**
         * Cache key of the mapper method behind this provider call. A name-based key cannot collide
         * between two different statements the way int hash codes can.
         */
        private static String keyOf(ProviderContext context) {
            return context.getMapperType().getName() + "." + context.getMapperMethod().getName();
        }
    
        /**
         * Returns the SQL cached under {@code key}, building and caching it on a miss.
         *
         * @throws X whatever {@code factory} throws; nothing is cached in that case
         */
        private static <X extends Exception> String cached(String key, SqlFactory<X> factory) throws X {
            String sql = sqlCache.get(key);
            if (sql == null) {
                sql = factory.build();
                sqlCache.put(key, sql);
            }
            return sql;
        }
    
        /** Rejects a null argument with {@code message}. */
        private static void checkNotNull(Object value, String message) {
            if (value == null) {
                throw new IllegalArgumentException(message);
            }
        }
    
        /** Rejects a null or empty collection argument with {@code message}. */
        private static void checkNotEmpty(Collection<?> values, String message) {
            if (values == null || values.isEmpty()) {
                throw new IllegalArgumentException(message);
            }
        }
    
        /**
         * Validates the lambda-query arguments and resolves the getter to an entity field name.
         *
         * @param params MyBatis parameter map; the keys "val" and "lambda" are read
         */
        private static String lambdaFieldName(Map<String,Object> params) {
            checkNotNull(params.get("val"), "value can not be null!");
            Object lambda = params.get("lambda");
            checkNotNull(lambda, "lambda can not be null!");
            return fieldNameOf((GetterFunction<?, ?>) lambda);
        }
    
        /** Binds the wildcard type so getFieldName can be invoked with the getter itself. */
        private static <T> String fieldNameOf(GetterFunction<T, ?> getter) {
            return getter.getFieldName(getter);
        }
    
        /** SQL for BaseDao.getById. */
        public String getById(ProviderContext context) {
            return cached(keyOf(context), () -> BaseSqlBuilder.getById(context));
        }
    
        /**
         * SQL for BaseDao.getByEntity.
         *
         * @throws IllegalArgumentException if {@code object} is null
         */
        public String getByEntity(Object object,ProviderContext context) throws Exception {
            checkNotNull(object, "entity can not be null!");
            return cached(keyOf(context), () -> BaseSqlBuilder.getByEntity(object));
        }
    
        /**
         * SQL for BaseDao.listByIds.
         *
         * @throws IllegalArgumentException if {@code collection} is null or empty
         */
        public String listByIds(Collection<?> collection, ProviderContext context) throws Exception {
            checkNotEmpty(collection, "id list can not be empty!");
            return cached(keyOf(context), () -> BaseSqlBuilder.listByIds(context));
        }
    
        /**
         * SQL for BaseDao.listByEntity.
         *
         * @throws IllegalArgumentException if {@code object} is null
         */
        public String listByEntity(Object object,ProviderContext context) throws Exception {
            checkNotNull(object, "entity can not be null!");
            return cached(keyOf(context), () -> BaseSqlBuilder.listByEntity(object));
        }
    
        /**
         * SQL for BaseDao.listByLambdaQuery; cached per mapper method AND field.
         *
         * @throws IllegalArgumentException if "val" or "lambda" is null
         * @throws Exception from BaseSqlBuilder.listByField if the field does not exist
         */
        public String listByLambdaQuery(Map<String,Object> params,ProviderContext context) throws Exception {
            String fieldName = lambdaFieldName(params);
            return cached(keyOf(context) + ":" + fieldName,
                    () -> BaseSqlBuilder.listByField(fieldName, context));
        }
    
        /**
         * SQL for BaseDao.getByLambdaQuery; cached per mapper method AND field.
         *
         * @throws IllegalArgumentException if "val" or "lambda" is null
         * @throws Exception from BaseSqlBuilder.getByField if the field does not exist
         */
        public String getByLambdaQuery(Map<String,Object> params,ProviderContext context) throws Exception {
            String fieldName = lambdaFieldName(params);
            return cached(keyOf(context) + ":" + fieldName,
                    () -> BaseSqlBuilder.getByField(fieldName, context));
        }
    
        /**
         * SQL for BaseDao.insert.
         *
         * @throws IllegalArgumentException if {@code object} is null
         */
        public String insert(Object object, ProviderContext context) throws Exception {
            checkNotNull(object, "entity can not be null!");
            return cached(keyOf(context), () -> BaseSqlBuilder.insert(object));
        }
    
        /**
         * SQL for BaseDao.insertBatch.
         *
         * @throws IllegalArgumentException if {@code collection} is null or empty
         */
        public String insertBatch(Collection<?> collection, ProviderContext context) throws Exception {
            checkNotEmpty(collection, "entity list can not be empty!");
            return cached(keyOf(context), () -> BaseSqlBuilder.insertBatch(context));
        }
    
        /**
         * SQL for BaseDao.update.
         *
         * @throws IllegalArgumentException if {@code object} is null
         */
        public String update(Object object, ProviderContext context) throws Exception {
            checkNotNull(object, "entity can not be null!");
            return cached(keyOf(context), () -> BaseSqlBuilder.update(object));
        }
    
        /**
         * SQL for BaseDao.updateBatch.
         *
         * @throws IllegalArgumentException if {@code collection} is null or empty
         */
        public String updateBatch(Collection<?> collection, ProviderContext context) throws Exception {
            checkNotEmpty(collection, "entity list can not be empty!");
            return cached(keyOf(context), () -> BaseSqlBuilder.updateBatch(context));
        }
    
        /** SQL for BaseDao.deleteById. */
        public String deleteById(ProviderContext context) {
            return cached(keyOf(context), () -> BaseSqlBuilder.deleteById(context));
        }
    
        /**
         * SQL for BaseDao.deleteByEntity.
         *
         * @throws IllegalArgumentException if {@code object} is null
         */
        public String deleteByEntity(Object object,ProviderContext context) throws Exception {
            checkNotNull(object, "entity can not be null!");
            return cached(keyOf(context), () -> BaseSqlBuilder.deleteByEntity(object));
        }
    
        /**
         * SQL for BaseDao.deleteByIds.
         *
         * @throws IllegalArgumentException if {@code collection} is null or empty
         */
        public String deleteByIds(Collection<?> collection, ProviderContext context) throws Exception {
            checkNotEmpty(collection, "id list can not be empty!");
            return cached(keyOf(context), () -> BaseSqlBuilder.deleteByIds(context));
        }
    
        /** SQL for BaseDao.countAll. */
        public String countAll(ProviderContext context) {
            return cached(keyOf(context), () -> BaseSqlBuilder.countAll(context));
        }
    
        /**
         * SQL for BaseDao.countByEntity.
         *
         * @throws IllegalArgumentException if {@code object} is null
         */
        public String countByEntity(Object object,ProviderContext context) throws Exception {
            checkNotNull(object, "entity can not be null!");
            return cached(keyOf(context), () -> BaseSqlBuilder.countByEntity(object));
        }
    
    }
    

    通用SQL构建类

    //生成通用SQL语句
    public class BaseSqlBuilder {
    
        public static String getById(ProviderContext context) {
            Class eClass = TableEntityMetaData.getEntityType(context);
            String tableName = TableEntityMetaData.tableName(eClass);
            List<String> fields = TableEntityMetaData.entityFields(eClass);
            List<String> columns = TableEntityMetaData.tableColumns(fields);
            return "SELECT "+String.join(",",columns)+" FROM "+tableName+" WHERE "+TableEntityMetaData.getIdColumn(eClass)+" = #{id}";
        }
    
        public static String listByEntity(Object object) {
            Class eClass = object.getClass();
            String tableName = TableEntityMetaData.tableName(eClass);
            List<String> fields = TableEntityMetaData.entityFields(eClass);
            List<String> columns = TableEntityMetaData.tableColumns(fields);
            StringBuilder sql = new StringBuilder("<script> SELECT ");
            sql.append(String.join(",",columns));
            sql.append(" FROM ").append(tableName);
            sql.append(" <where>");
            whereByEntity(fields,columns,sql);
            sql.append("</where></script>");
            return sql.toString();
        }
    
        public static String getByEntity(Object object) {
            return listByEntity(object)+" LIMIT 1";
        }
    
        public static String listByIds(ProviderContext context) {
            Class eClass = TableEntityMetaData.getEntityType(context);
            String tableName = TableEntityMetaData.tableName(eClass);
            List<String> fields = TableEntityMetaData.entityFields(eClass);
            List<String> columns = TableEntityMetaData.tableColumns(fields);
            StringBuilder sql = new StringBuilder("<script> SELECT ");
            sql.append(String.join(",",columns));
            sql.append(" FROM ").append(tableName);
            sql.append(" WHERE ").append(TableEntityMetaData.getIdColumn(eClass)).append(" IN ");
            sql.append("<foreach item="item" collection="list" separator="," open="(" close=")" index="index">");
            sql.append("#{item}</foreach></script>");
            return sql.toString();
        }
    
        public static String listByField(String fieldName, ProviderContext context) throws Exception {
            Class eClass = TableEntityMetaData.getEntityType(context);
            String tableName = TableEntityMetaData.tableName(eClass);
            List<String> fields = TableEntityMetaData.entityFields(eClass);
            if (!fields.contains(fieldName)) {
                throw new Exception("not exist column '"+fieldName+"'");
            }
            List<String> columns = TableEntityMetaData.tableColumns(fields);
            return "SELECT "+String.join(",",columns)+" FROM "+tableName+" WHERE "+TableEntityMetaData.toLowerCase(fieldName)+" = #{val}";
        }
    
        public static String getByField(String fieldName, ProviderContext context) throws Exception {
            return listByField(fieldName,context)+" LIMIT 1";
        }
    
        public static String insert(Object object) {
            Class eClass = object.getClass();
            String tableName = TableEntityMetaData.tableName(eClass);
            List<String> fields = TableEntityMetaData.entityFields(eClass);
            List<String> columns = TableEntityMetaData.tableColumns(fields);
            StringBuilder sql = new StringBuilder();
            sql.append("<script> INSERT INTO ").append(tableName);
            sql.append(" <trim prefix="(" suffix=")" suffixOverrides=",">");
            for (int i = 0; i < fields.size(); i++) {
                sql.append("<if test="").append(fields.get(i)).append(" != null">");
                sql.append(columns.get(i)).append(",").append("</if>");
            }
            sql.append("</trim><trim prefix="values (" suffix=")" suffixOverrides=",">");
            for (int i = 0; i < fields.size(); i++) {
                sql.append("<if test="").append(fields.get(i)).append(" != null">");
                sql.append("#{").append(fields.get(i)).append("},").append("</if>");
            }
            sql.append("</trim></script>");
            return sql.toString();
        }
    
        public static String insertBatch(ProviderContext context) {
            Class eClass = TableEntityMetaData.getEntityType(context);
            String tableName = TableEntityMetaData.tableName(eClass);
            List<String> fields = TableEntityMetaData.entityFields(eClass);
            List<String> columns = TableEntityMetaData.tableColumns(fields);
            StringBuilder sql = new StringBuilder();
            sql.append("<script> INSERT INTO ").append(tableName);
            sql.append("(").append(String.join(", ",columns)).append(") values ");
            sql.append("<foreach item="item" collection="list" separator="," open="" close="" index="index"> (");
            for (int i = 0; i < fields.size(); i++) {
                sql.append("#{item.").append(fields.get(i)).append("}");
                if (i<fields.size()-1){
                    sql.append(", ");
                }
            }
            sql.append(")</foreach></script>");
            return sql.toString();
        }
    
        public static String update(Object object) {
            Class eClass = object.getClass();
            String tableName = TableEntityMetaData.tableName(eClass);
            List<String> fields = TableEntityMetaData.entityFields(eClass);
            List<String> columns = TableEntityMetaData.tableColumns(fields);
            StringBuilder sql = new StringBuilder("<script> UPDATE ");
            sql.append(tableName).append(" <set>");
            for (int i = 1; i < fields.size(); i++) {
                sql.append("<if test="").append(fields.get(i)).append(" != null">");
                sql.append(columns.get(i)).append(" = #{").append(fields.get(i)).append("},</if>");
            }
            sql.append("</set> WHERE ").append(TableEntityMetaData.getIdColumn(eClass));
            sql.append(" = #{").append(TableEntityMetaData.getIdField(eClass)).append("} </script>");
            return sql.toString();
        }
    
        public static String updateBatch(ProviderContext context) {
            Class eClass = TableEntityMetaData.getEntityType(context);
            String tableName = TableEntityMetaData.tableName(eClass);
            List<String> fields = TableEntityMetaData.entityFields(eClass);
            List<String> columns = TableEntityMetaData.tableColumns(fields);
            StringBuilder sql = new StringBuilder("<script> UPDATE ");
            sql.append(tableName).append(" <trim prefix="set" suffixOverrides=",">");
            for (int i = 1; i < fields.size(); i++) {
                sql.append("<trim prefix="").append(columns.get(i)).append(" = case" suffix="end,">");
                sql.append("<foreach collection="list" item="item" index="index">");
                sql.append("when ").append(TableEntityMetaData.getIdColumn(eClass));
                sql.append(" = #{item.").append(TableEntityMetaData.getIdField(eClass)).append("} then #{item.").append(fields.get(i)).append("}");
                sql.append("</foreach></trim>");
            }
            sql.append("</trim> WHERE ").append(TableEntityMetaData.getIdColumn(eClass)).append(" IN ");
            sql.append("<foreach collection="list" index="index" item="item" separator="," open="(" close=")">");
            sql.append("#{item.").append(TableEntityMetaData.getIdField(eClass)).append("} </foreach></script>");
            return sql.toString();
        }
    
        public static String deleteById(ProviderContext context) {
            Class eClass = TableEntityMetaData.getEntityType(context);
            String tableName = TableEntityMetaData.tableName(eClass);
            return "DELETE FROM "+tableName+" WHERE "+TableEntityMetaData.getIdColumn(eClass)+" = #{id}";
        }
    
        public static String deleteByEntity(Object object) {
            Class eClass = object.getClass();
            String tableName = TableEntityMetaData.tableName(eClass);
            List<String> fields = TableEntityMetaData.entityFields(eClass);
            List<String> columns = TableEntityMetaData.tableColumns(fields);
            StringBuilder sql = new StringBuilder("<script> DELETE FROM ");
            sql.append(tableName).append(" <where> ");
            whereByEntity(fields,columns,sql);
            sql.append("</where></script>");
            return sql.toString();
        }
    
        public static String deleteByIds(ProviderContext context) {
            Class eClass = TableEntityMetaData.getEntityType(context);
            String tableName = TableEntityMetaData.tableName(eClass);
            StringBuilder sql = new StringBuilder("<script> DELETE FROM ");
            sql.append(tableName).append(" WHERE ").append(TableEntityMetaData.getIdColumn(eClass)).append(" IN ");
            sql.append("<foreach item="item" collection="list" separator="," open="(" close=")" index="index">");
            sql.append("#{item}</foreach></script>");
            return sql.toString();
        }
    
        public static String countAll(ProviderContext context) {
            Class eClass = TableEntityMetaData.getEntityType(context);
            String tableName = TableEntityMetaData.tableName(eClass);
            return "SELECT COUNT(*) FROM "+tableName;
        }
    
        public static String countByEntity(Object object) {
            Class eClass = object.getClass();
            String tableName = TableEntityMetaData.tableName(eClass);
            List<String> fields = TableEntityMetaData.entityFields(eClass);
            List<String> columns = TableEntityMetaData.tableColumns(fields);
            StringBuilder sql = new StringBuilder("<script> SELECT COUNT(*) FROM ");
            sql.append(tableName).append(" <where> ");
            whereByEntity(fields,columns,sql);
            sql.append("</where></script>");
            return sql.toString();
        }
    
        private static void whereByEntity(List<String> fields,List<String> columns,StringBuilder sql){
            for (int i = 0; i < fields.size(); i++) {
                sql.append("<if test="").append(fields.get(i)).append(" != null">");
                sql.append("and ").append(columns.get(i));
                sql.append(" = #{").append(fields.get(i)).append("}</if>");
            }
        }
    }
    

    表实体元数据工具类

    //Reads table and entity metadata from a ProviderContext or an entity class.
    //Table and column names are derived from the entity by camelCase -> snake_case conversion, so the
    //schema must follow that rule: table/column names all lower case with words separated by
    //underscores, entity class and field names in camelCase.
    public class TableEntityMetaData {
    
        /** Matches each upper-case ASCII letter, i.e. the start of a new word in a camelCase name. */
        private static final java.util.regex.Pattern UPPER_CASE = java.util.regex.Pattern.compile("[A-Z]");
    
        /**
         * Returns the entity type E of the mapper's BaseDao&lt;E, I&gt;: the first type argument of
         * the first parameterized interface the mapper declares whose first argument is a class.
         *
         * @throws IllegalStateException if no such interface is declared
         */
        public static Class getEntityType(ProviderContext context) {
            Class mClass = context.getMapperType();
            for (java.lang.reflect.Type type : mClass.getGenericInterfaces()) {
                if (type instanceof ParameterizedType) {
                    java.lang.reflect.Type first = ((ParameterizedType) type).getActualTypeArguments()[0];
                    if (first instanceof Class) {
                        return (Class) first;
                    }
                }
            }
            throw new IllegalStateException("cannot resolve the entity type of mapper " + mClass.getName());
        }
    
        /** Primary-key column name; fixed to "id". */
        public static String getIdColumn(Class eClass){
            return  "id";
        }
    
        /** Primary-key entity field name; fixed to "id". */
        public static String getIdField(Class eClass){
            return "id";
        }
    
        /** Table name: the entity's simple class name converted to snake_case (User -> user). */
        public static String tableName(Class eClass) {
            String entityName = eClass.getSimpleName();
            return toLowerCase(entityName);
        }
    
        /**
         * Names of the entity's instance fields declared by {@code eClass} itself (superclass fields
         * are not included), with the id field moved to index 0 if present. Static fields (e.g.
         * serialVersionUID) and compiler-generated fields are skipped because they are not columns.
         */
        public static List<String> entityFields(Class eClass) {
            Field[] declared = eClass.getDeclaredFields();
            List<String> entityFields = new ArrayList<>(declared.length);
            String idField = getIdField(eClass);
            for (Field field : declared) {
                if (java.lang.reflect.Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                    continue;
                }
                String name = field.getName();
                if (name.equals(idField)) {
                    entityFields.add(0, name);
                } else {
                    entityFields.add(name);
                }
            }
            return entityFields;
        }
    
        /** Column names for {@code entityFields}, in the same order. */
        public static List<String> tableColumns(List<String> entityFields) {
            List<String> tableColumns = new ArrayList<>(entityFields.size());
            for (String field : entityFields) {
                tableColumns.add(toLowerCase(field));
            }
            return tableColumns;
        }
    
        /**
         * Converts camelCase to snake_case ("gmtCreate" -> "gmt_create", "User" -> "user").
         * Locale.ROOT keeps the result independent of the default locale (in Turkish, 'I' would
         * otherwise lower-case to a dotless 'ı').
         */
        public static String toLowerCase(String camelStr) {
            String lowerCase = UPPER_CASE.matcher(camelStr).replaceAll("_$0").toLowerCase(java.util.Locale.ROOT);
            if (lowerCase.startsWith("_")){
                lowerCase = lowerCase.substring(1);
            }
            return lowerCase;
        }
    }
    

    lambda query function 接口

    /**
     * Serializable getter reference (e.g. {@code User::getUsername}) used by the lambda-query
     * methods to name an entity field in a type-safe way.
     */
    @FunctionalInterface
    public interface GetterFunction<T,R> extends Serializable,Function<T,R> {
        /**
         * Resolves the bean property name behind a getter method reference.
         *
         * <p>Because this interface is Serializable, the lambda class has a {@code writeReplace}
         * method returning a {@link SerializedLambda}; its implementation method name is the
         * referenced getter. Only a leading "get"/"is" followed by an upper-case letter is stripped
         * ({@code getTarget} -> {@code target}, {@code isRegistered} -> {@code registered}); other
         * names, such as a record accessor {@code username}, are used as they are.
         *
         * @param func the getter reference to inspect (normally {@code this})
         * @return the decapitalized property name
         * @throws RuntimeException wrapping any reflection failure, e.g. when {@code func} is not a
         *     lambda or method reference
         */
        default String getFieldName(GetterFunction<T,?> func) {
            try {
                Method method = func.getClass().getDeclaredMethod("writeReplace");
                method.setAccessible(true);
                SerializedLambda serializedLambda = (SerializedLambda) method.invoke(func);
                String getter = serializedLambda.getImplMethodName();
                String property = getter;
                // Strip the prefix only; String.replace would also delete "get"/"is" occurring inside
                // the name (getTarget -> "tar", isRegistered -> "regtered").
                for (String prefix : new String[] {"get", "is"}) {
                    if (getter.length() > prefix.length()
                            && getter.startsWith(prefix)
                            && Character.isUpperCase(getter.charAt(prefix.length()))) {
                        property = getter.substring(prefix.length());
                        break;
                    }
                }
                return Introspector.decapitalize(property);
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException(e);
            }
        }
    }
    

    实体类

    /**
     * Example entity. TableEntityMetaData maps it to table "user" and each field to its snake_case
     * column (gmtCreate -> gmt_create); the id field is the primary key.
     * NOTE(review): "user" is a reserved word in some databases (e.g. PostgreSQL) and the generated
     * SQL does not quote identifiers -- confirm for the target database.
     */
    @Data // Lombok generates the getters and setters
    public class User {
        /**
        * Primary key, auto-increment.
        */
        private Integer id;
        private String username;
        private String password;
        /**
        * Record creation time; defaults to the current time.
        */
        private Date gmtCreate;
        /**
        * Record modification time; defaults to the current time.
        */
        private Date gmtModified;
    }
    

    具体Dao

    /**
     * Mapper for User with Integer ids. All generic CRUD statements come from BaseDao; hand-written
     * statements can be declared here and implemented in an XML mapper under resources/mapper.
     */
    public interface UserDao extends BaseDao<User, Integer> {
        // No additional statements yet.
    }
    

    yml mybatis配置

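    一定要开启下划线转驼峰设置（map-underscore-to-camel-case: true）：通用 SQL 直接查询下划线风格的列名（如 gmt_create）且没有起别名，只有开启该设置，查询结果才能映射到实体的驼峰属性（如 gmtCreate）。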

    mybatis:
      # XML mapper files for hand-written SQL: resources/mapper and all its subdirectories
      mapper-locations: classpath*:mapper/**/*.xml
      configuration:
        # Required: the generated SELECTs return snake_case columns (e.g. gmt_create) without
        # aliases, so results map onto camelCase entity fields (gmtCreate) only with this enabled.
        map-underscore-to-camel-case: true
    

    使用

    • 自己写的SQL可以放在resources/mapper路径里的Mapper.xml中
    • 对应dao方法则放在具体的dao中
  • 相关阅读:
    IntelliJ IDEA快捷键
    【Java基础】Java 增强型的for循环 for each
    Markdown简易入门
    kafka性能调优
    百度地图 libBaiduMapSDK_base_v4_2_1.so" is 32-bit instead of 64-bit错误
    centos7防火墙firewalld拒绝某ip或者某ip段访问服务器任何服务
    华为策略路由配置
    Windows Server 2012 R2 英文版安装中文语言包教程更改为中文版
    linux修改网卡名为eth0
    华为路由配置IPSec
  • 原文地址:https://www.cnblogs.com/xiaogblog/p/14151888.html
Copyright © 2011-2022 走看看