zoukankan      html  css  js  c++  java
  • JPA利用Java反射机制动态构建sql

    第一步:在项目pom.xml 加入 JPA 框架的 maven 依赖坐标

    <!-- 数据库 ORM 框架 -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-jpa</artifactId>
    </dependency>
    JPA的maven坐标

    第二步:在项目中的 model 层下创建一个实体类 CityData

    /**
     * @author chaoyou
     * @email 
     * @date 2019-10-8 17:55
     * @Description 城市信息实体类
     */
    @Entity
    @Table(name = "city_data")
    public class CityData implements Serializable {
        @Id
        @GeneratedValue(strategy = GenerationType.IDENTITY)
        private Long id;
        @Column(unique = true)
        private String cityCode;    // 城市编码
        @Column(unique = true)
        private String cityName;    // 城市名字
        @Column
        private String area;    // 区域
        @Column
        private String province;    // 省份/直辖市
    
    
        @Column
        private String prefectureLevel;  // 地级城市
        @Column
        private String town;    // 县城
        @Column
        private Integer level;  // 城市等级:1、省会城市,2、地级城市,3、直辖市,4、县级市
        @Column(name = "create_time", updatable = false)
        @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8")
        private Date createTime;    // 创建时间
        @Column(name = "update_time", insertable = false)
        @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8")
        private Date updateTime;    // 更新时间
        @Column
        private String cityRank;        //城市级别  (T1,T2)
    
        public CityData() {
        }
    
        public CityData(Long id) {
            this.id = id;
        }
    
        public Long getId() {
            return id;
        }
    
        public void setId(Long id) {
            this.id = id;
        }
    
        public String getCityCode() {
            return cityCode;
        }
    
        public void setCityCode(String cityCode) {
            this.cityCode = cityCode;
        }
    
        public String getCityName() {
            return cityName;
        }
    
        public void setCityName(String cityName) {
            this.cityName = cityName;
        }
    
        public String getArea() {
            return area;
        }
    
        public void setArea(String area) {
            this.area = area;
        }
    
        public String getProvince() {
            return province;
        }
    
        public void setProvince(String province) {
            this.province = province;
        }
    
        public String getPrefectureLevel() {
            return prefectureLevel;
        }
    
        public void setPrefectureLevel(String prefectureLevel) {
            this.prefectureLevel = prefectureLevel;
        }
    
        public String getTown() {
            return town;
        }
    
        public void setTown(String town) {
            this.town = town;
        }
    
        public Integer getLevel() {
            return level;
        }
    
        public void setLevel(Integer level) {
            this.level = level;
        }
    
        public Date getCreateTime() {
            return createTime;
        }
    
        public void setCreateTime(Date createTime) {
            this.createTime = createTime;
        }
    
        public Date getUpdateTime() {
            return updateTime;
        }
    
        public void setUpdateTime(Date updateTime) {
            this.updateTime = updateTime;
        }
    
        public String getCityRank() {
            return cityRank;
        }
    
        public void setCityRank(String cityRank) {
            this.cityRank = cityRank;
        }
    }
    CityData实体类

    第三步:在项目的 dao 层下创建一个供外部调用的接口 EntryMapper

    import java.lang.reflect.Field;
    import java.util.List;
    
    /**
     * @author chaoyou
     * @email 
     * @date 2020-6-11 14:56
     * @Description 定义接口用于 获取实体
     */
    public interface EntryMapper {
    
        /**
         * 获取实体类对应数据表的表名
         */
        String getEntryTableName(Class<?> clazz);
    
        /**
         * 获取实体类对应数据表的主键字段
         */
        String getPKFieldName(Class<?> clazz);
    
        /**
         * 获取实体类对应数据表的外键字段
         */
        List<String> getFKFieldName(Class<?> clazz);
    
        /**
         * 获取实体类对应数据表的所有字段列名
         */
        List<String> getSequenceName(Class<?> clazz);
    
        /**
         * 获取实体类对应数据表的所有变量
         */
        List<Field> getFieldList(Class<?> clazz);
    
        /**
         * 获取实体类对应数据表的常规插入操作的 sql 语句
         */
        String getSqlToSave(Class<?> clazz);
    
        /**
         * 获取实体类对应数据表的常规更新操作的 sql 语句
         */
        String getSqlToUpdate(Class<?> clazz, String field);
    }
    实体类操作接口

    第四步:在项目的 impl 层下创建一个对 EntryMapper 接口的实现类

    import org.apache.commons.lang3.StringUtils;
    import org.springframework.stereotype.Component;
    import org.springframework.util.Assert;
    
    import javax.persistence.*;
    import java.lang.reflect.Field;
    import java.util.ArrayList;
    import java.util.List;
    
    /**
     * @author chaoyou
     * @email 
     * @date 2020-6-11 14:47
     * @Description 这是获取实体类型信息的工具类
     */
    @Component
    public class JpaEntryMapper implements EntryMapper {
    
        /**
         * 获取实体类的名字
         *
         * @param clazz 实体对象
         * @return
         */
        @Override
        public String getEntryTableName(Class<?> clazz) {
            // 校验参数对象是否为空
            Assert.notNull(clazz, "clazz不能为空");
            // 获取「Table」注解的控制权
            Table  tableAnno = clazz.getAnnotation(Table.class);
            // 校验是否有「table」注解
            Assert.notNull(tableAnno, "@Table注解未设置");
            // 检验「table」注解是否设置 name 属性值
            Assert.state(StringUtils.isNotEmpty(tableAnno.name()), "@Table 的 name 属性未设置");
    
            return tableAnno.name();
        }
    
        /**
         * 获取实体类中的主键属性名
         *
         * @param clazz     实体对象
         * @return
         */
        @Override
        public String getPKFieldName(Class<?> clazz) {
            Assert.notNull(clazz, "clazz不能为空");
    
            // 通过「Java反射机制」拿到实体类的所有被 private 修饰的属性(不包括继承属性)
            Field[] fields = clazz.getDeclaredFields();
    
            // 获取参数类中所有public的属性,包括继承的public属性
    //        Field[] fields = clazz.getFields();
    
            String pk = null;
    
            if (fields.length == 0){
                return pk;
            }
    
            /**
             * 遍历属性列表,找到被「@Id」注解修饰的属性
             */
            for (Field field : fields){
                if (field.getAnnotation(Id.class) != null){
                    pk = field.getName();
                    break;
                }
            }
            return pk;
        }
    
        /**
         * 获取实体类中的外键属性
         *
         * @param clazz
         * @return
         */
        @Override
        public List<String> getFKFieldName(Class<?> clazz) {
            Assert.notNull(clazz, "clazz不能为空");
    
            List<String> fks = new ArrayList<>();
            String fk = null;
    
            Field[] fields = clazz.getDeclaredFields();
            if (fields.length == 0){
                return fks;
            }
    
            /**
             * 遍历属性列表,找到被「JoinColumn」注解修饰的属性
             */
            for (Field field : fields){
                JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
                if (joinColumn != null){
                    if (!"".equals(joinColumn.name()) && !"".equals(joinColumn.referencedColumnName())){
                        fks.add(joinColumn.name() + "-" + joinColumn.referencedColumnName() + "-" + field.getName());
                    } else if (!"".equals(joinColumn.name()) && "".equals(joinColumn.referencedColumnName())){
                        fks.add(joinColumn.name() + "-id" + "-" + field.getName());
                    } else if ("".equals(joinColumn.name()) && !"".equals(joinColumn.referencedColumnName())){
                        fks.add(field.getName() + "-" + joinColumn.referencedColumnName() + "-" + field.getName());
                    } else if ("".equals(joinColumn.name()) && "".equals(joinColumn.referencedColumnName())){
                        fks.add(field.getName() + "-id" + "-" + field.getName());
                    }
                    break;
                }
            }
    
            return fks;
        }
    
        /**
         * 获取实体类中所有属性对应的持久化字段(主键除外)
         *
         * @param clazz
         * @return
         */
        @Override
        public List<String> getSequenceName(Class<?> clazz) {
            Assert.notNull(clazz, "clazz不能为空");
            List<String> fieldList = null;
            Field[] fields = clazz.getDeclaredFields();
            if (fields.length == 0){
                return fieldList;
            }
            fieldList = new ArrayList<>();
            // 普通持久化字段
            Column column = null;
            // 非持久化字段
            Transient tran = null;
            // 外键字段
            JoinColumn joinColumn = null;
            for (Field field : fields){
                tran = field.getAnnotation(Transient.class);
                if (tran != null){
                    continue;
                }
                column = field.getAnnotation(Column.class);
                joinColumn = field.getAnnotation(JoinColumn.class);
                if (column != null){
                    if (!"".equals(column.name())){
                        // Column 注解的 name 属性作为其对应的数据表映射字段
                        fieldList.add(column.name());
                    } else{
                        fieldList.add(field.getName());
                    }
                } else if (joinColumn != null){
                    if (!"".equals(joinColumn.name())){
                        // JoinColumn 注解的 name 属性作为其对应的数据表映射字段
                        fieldList.add(joinColumn.name());
                    } else{
                        fieldList.add(field.getName());
                    }
                }
            }
            return fieldList;
        }
    
        /**
         * 获取实体类中所有持久化字段的属性
         *
         * @param clazz
         * @return
         */
        @Override
        public List<Field> getFieldList(Class<?> clazz) {
            Assert.notNull(clazz, "clazz不能为空");
            List<Field> fieldList = null;
            Field[] fields = clazz.getDeclaredFields();
            if (fields.length == 0){
                return fieldList;
            }
            fieldList = new ArrayList<>();
            Column column = null;
            // 非持久化字段注解
            Transient tran = null;
            JoinColumn joinColumn = null;
            for (Field field : fields){
                tran = field.getAnnotation(Transient.class);
                if (tran != null){
                    continue;
                }
                column = field.getAnnotation(Column.class);
                joinColumn = field.getAnnotation(JoinColumn.class);
                if (column != null || joinColumn != null){
                    fieldList.add(field);
                }
            }
            return fieldList;
        }
    
        /**
         * 设置一个该实体类的 insert 持久化 sql
         *
         * @param clazz
         * @return
         */
        @Override
        public String getSqlToSave(Class<?> clazz) {
            List<String> list = getSequenceName(clazz);
            List<String> sqls = new ArrayList<>();
            for (int i=0; i<list.size(); i++){
                sqls.add("?");
            }
            String sql = "insert into " + getEntryTableName(clazz) + "("
                    + ArrayUtil.getStringByArray(list.toArray(), ", ")
                    + ") values(" + ArrayUtil.getStringByArray(sqls.toArray(), ", ") + ")";
            return sql;
        }
    
        /**
         * 设置一个该实体类的 update 持久化 sql
         *
         * @param clazz
         * @param field
         * @return
         */
        @Override
        public String getSqlToUpdate(Class<?> clazz, String field) {
            List<String> list = getSequenceName(clazz);
            List<String> sqls = new ArrayList<>();
            for (int i=0; i<list.size(); i++){
                sqls.add(list.get(i) + "=?");
            }
            String sql = "update " + getEntryTableName(clazz) + " set "
                    + ArrayUtil.getStringByArray(sqls.toArray(), ", ")
                    + " where " + field + "=?";
            return sql;
        }
    }
    接口实现类

    第五步:在 test 层下创建一个测试类LCYTest

    import org.junit.Test;
    import org.junit.runner.RunWith;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.boot.test.context.SpringBootTest;
    import org.springframework.test.context.ActiveProfiles;
    import org.springframework.test.context.junit4.SpringRunner;
    
    import java.lang.reflect.Field;
    import java.text.ParseException;
    import java.util.Date;
    import java.util.List;
    
    /**
     * @author chaoyou
     * @email 
     * @date 2020-8-12 16:44
     * @Description
     * @Reference
     */
    @RunWith(SpringRunner.class)
    @SpringBootTest
    public class LCYTest {
    
        @Autowired
        private EntryMapper entryMapper;
    
        @Test
        public  void test09(){
            String entryTableName = entryMapper.getEntryTableName(CityData.class);
            String pkFieldName = entryMapper.getPKFieldName(CityData.class);
            List<String> fkFieldNameList = entryMapper.getFKFieldName(CityData.class);
            List<String> sequenceNameList = entryMapper.getSequenceName(CityData.class);
            List<Field> fieldList = entryMapper.getFieldList(CityData.class);
            String sqlToSave = entryMapper.getSqlToSave(CityData.class);
            String sqlToUpdate = entryMapper.getSqlToUpdate(CityData.class, "id");
            System.out.println("entryTableName:" + entryTableName);
            System.out.println("pkFieldName:" + pkFieldName);
            System.out.println("fkFieldNameList:" + fkFieldNameList);
            System.out.println("sequenceNameList:" + sequenceNameList);
            System.out.println("sqlToSave:" + sqlToSave);
            System.out.println("sqlToUpdate:" + sqlToUpdate);
        }
    }
    测试类

    第六步:当然就是看结果了

  • 相关阅读:
    Linux cron
    web报表工具FineReport常用函数的用法总结(文本函数)
    web报表工具FineReport常用函数的用法总结(文本函数)
    oracle instr函数
    死锁的例子和 synchronized 嵌套使用
    死锁的例子和 synchronized 嵌套使用
    Perl 监控批量错误
    Linux以百万兆字节显示内存大小
    Linux以GB显示内存大小
    Linux以KB显示内存大小
  • 原文地址:https://www.cnblogs.com/chaoyou/p/13541811.html
Copyright © 2011-2022 走看看