package com.dx.efuwu.core

import org.apache.commons.lang.StringUtils
import java.sql.PreparedStatement

/**
 * SQL template processing.
 * @author sunzq
 * 2017/06/02
 */

/** Matches the placeholder part (`#{…}`, `##{…}`, `@{…}`, `@@{…}`) of a single condition clause. */
private val PLACEHOLDER_REGEX = """(?:##?|@@?)\{[^}]*}""".toRegex()

/**
 * Matches one condition clause of a template: a non-blank token, a run of whitespace / `=` / `>` / `<` /
 * `in` / `not` / `like` (lower or upper case), then a placeholder.
 * Group 1 is the marker (`#`, `##`, `@`, `@@`), group 2 is the dictionary key.
 */
private val CONDITION_REGEX = """\S+(?:\s|=|>|<|in|not|like|IN|NOT|LIKE)+(#|##|@|@@)\{([^}]+)}""".toRegex()

/** A line break (LF or CRLF); templates are flattened to a single line before parsing. */
private val LINE_BREAK_REGEX = """\r?\n""".toRegex()

/**
 * Runs [block] with this resource and closes it afterwards, whether [block] returns or throws.
 * Stand-in for `use` on [AutoCloseable] that does not need the jdk7 stdlib extension.
 */
private inline fun <T : AutoCloseable, R> T.closeAfter(block: (T) -> R): R =
    try {
        block(this)
    } finally {
        close()
    }

/**
 * A single condition clause of a query template, e.g. `name = #{userName}`.
 *
 * @property content the full clause text as matched in the template, placeholder included
 * @property key the name looked up in the parameter dictionary
 * @property type the placeholder marker: `#`, `##`, `@` or `@@`
 */
class QueryBranch(val content: String, val key: String, val type: String) {

    override fun toString(): String = "{content:$content, key:$key, type: $type}"

    /**
     * Renders this clause for the final SQL.
     *
     * @param input the parameter value for [key], possibly null
     * @return [content] with its placeholder replaced by a JDBC `?` when [input] is not blank;
     *   otherwise `" 1=0 "` for the double markers (`##`, `@@`) and `" 1=1 "` for the single ones.
     */
    fun build(input: String?): String = when {
        StringUtils.isNotBlank(input) -> content.replace(PLACEHOLDER_REGEX, "?")
        type.length > 1 -> " 1=0 "
        else -> " 1=1 "
    }

    /**
     * Binds [input] to [pstmt] at position [index] when it is not blank.
     *
     * `#` / `##` values are bound with `setString`; `@` / `@@` values are parsed as a `Long`
     * (surrounding whitespace ignored) and bound with `setLong`.
     *
     * @return `true` if a parameter was bound (the caller should advance its index), `false` when
     *   [input] is null or blank — in that case [build] emitted no `?` for this clause.
     * @throws NumberFormatException if an `@` / `@@` value is not a valid long
     * @throws IllegalStateException if [type] is not one of the supported markers
     */
    fun doSetParameter(input: String?, index: Int, pstmt: PreparedStatement): Boolean {
        if (input == null || StringUtils.isBlank(input)) return false
        when (type) {
            "#", "##" -> pstmt.setString(index, input)
            "@", "@@" -> pstmt.setLong(index, input.trim().toLong())
            // Binding nothing here would shift every later parameter index, so fail loudly instead.
            else -> error("unsupported placeholder type '$type' for key '$key'")
        }
        return true
    }
}


/**
 * SQL query template processor. It only substitutes placeholders and does NOT check whether the
 * resulting SQL is valid.
 *
 * Placeholder markers (values are looked up by key in the dictionary passed to [doQuery]):
 * - `@{key}`  — parsed as a `Long` and bound with `setLong`; a blank value turns the clause into ` 1=1 ` (true).
 * - `@@{key}` — like `@`, but a blank value turns the clause into ` 1=0 ` (false).
 * - `#{key}`  — bound with `setString` (treated as a string value); blank → ` 1=1 ` (true).
 * - `##{key}` — like `#`, but blank → ` 1=0 ` (false).
 *
 * @param sql the template; line breaks are replaced with spaces before parsing
 */
class SQLTemplateExecutor(sql: String) {

    /** Condition clauses found in the template, in template order. */
    val queryBranches = ArrayList<QueryBranch>()

    /** Template text around the clauses; [doQuery] emits piece `i` right before clause `i`. */
    val splitQueryStrings = ArrayList<String>()

    init {
        val flatSql = sql.replace(LINE_BREAK_REGEX, " ")
        splitQueryStrings.addAll(flatSql.split(CONDITION_REGEX))
        CONDITION_REGEX.findAll(flatSql).mapTo(queryBranches) { match ->
            val (type, key) = match.destructured
            QueryBranch(content = match.value, key = key.trim(), type = type)
        }
    }

    /**
     * Renders the template with [dict], runs it on `SQL_SERVER_DATA_SOURCE` and returns every row.
     *
     * The connection, statement and result set are closed even when preparing or executing throws.
     *
     * @param dict placeholder values by key; a missing, null or blank value replaces its clause with
     *   ` 1=1 ` or ` 1=0 ` depending on the marker
     * @return one map per row, keyed by column label, in column order. Columns whose type name is
     *   `int` are read with `getInt` (JDBC returns 0 for SQL NULL there); all other columns are read
     *   with `getString`, which yields null for SQL NULL despite the declared `Any` value type.
     */
    fun doQuery(dict: Map<String, String?>): ArrayList<Map<String, Any>> {
        val renderedSql = buildString {
            splitQueryStrings.forEachIndexed { i, piece ->
                append(piece)
                queryBranches.getOrNull(i)?.let { branch -> append(branch.build(dict[branch.key])) }
            }
        }

        val rows = ArrayList<Map<String, Any>>()
        SQL_SERVER_DATA_SOURCE.connection.closeAfter { connection ->
            connection.prepareStatement(renderedSql).closeAfter { statement ->
                bindParameters(statement, dict)
                statement.executeQuery().closeAfter { resultSet ->
                    // Column metadata is identical for every row, so read it once.
                    val meta = resultSet.metaData
                    val columns = 1..meta.columnCount
                    val labels = columns.map { meta.getColumnLabel(it) }
                    val typeNames = columns.map { meta.getColumnTypeName(it) }
                    while (resultSet.next()) {
                        val row = LinkedHashMap<String, Any>()
                        for (col in columns) {
                            row[labels[col - 1]] =
                                when (typeNames[col - 1]) {
                                    "int" -> resultSet.getInt(col)
                                    else -> resultSet.getString(col)
                                }
                        }
                        rows.add(row)
                    }
                }
            }
        }
        return rows
    }

    /**
     * Binds the non-blank dictionary values in template order. Indices count only the clauses that
     * [QueryBranch.build] rendered with a `?`, so they line up with the rendered SQL.
     */
    private fun bindParameters(statement: PreparedStatement, dict: Map<String, String?>) {
        var paramIndex = 1
        for (branch in queryBranches) {
            if (branch.doSetParameter(dict[branch.key], paramIndex, statement)) {
                paramIndex++
            }
        }
    }
}

/** Manual smoke run: executes a trivial query against the configured data source and prints the rows. */
fun main(args: Array<String>) {
    val sql = """select * from test"""
    println(SQLTemplateExecutor(sql).doQuery(hashMapOf(
            "userName1" to "大%",
            "age" to "18",
            "s2" to "hello,world",
            "userId" to "10000"
    )))
}