之前在《Python 封装 DBUtils 和 pymysql》一文中写过一个 basedao.py,最近几天重新整理了思路,对 basedao.py 做了优化。目前支持的方法还不多,后续会继续改进和添加。
主要功能:
1.查询单个对象:
所需参数:表名,过滤条件
2.查询多个对象:
所需参数:表名,过滤条件
3.按主键查询:
所需参数:表名,值
4.分页查询:
所需参数:表名,页码,每页记录数,过滤条件
调用方法所获取的对象都以字典形式存储。例如:查询 user 表(字段有 id、name、age)中 id=1 的数据,返回的对象为 user = {"id": 1, "name": "zhangsan", "age": 18},我们可以通过 user.get("id") 来获取 id 值,非常方便,不需要另外定义类来表示。如果查询结果有多条,那么多个字典对象会放在一个列表里返回。
具体代码如下:
import json, os, sys, time, pymysql, pprint

from DBUtils import PooledDB


def print(*args):
    """Module-wide replacement for the builtin: pretty-print all arguments as one tuple."""
    pprint.pprint(args)


def get_time():
    """Return the current local time formatted as ``YYYY-mm-dd HH:MM:SS``."""
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())


def stitch_sequence(seq=None, suf=None):
    """Join the string elements of *seq* with the separator *suf* (default ``","``).

    :raises Exception: if *seq* is None.
    """
    if seq is None:
        raise Exception("Parameter seq is None")
    if suf is None:
        suf = ","
    # str.join also copes with an empty separator, which the old
    # "append, then slice the last separator off" approach turned into "".
    return suf.join(seq)


def _sql_literal(value):
    """Render *value* as an escaped MySQL literal (``'abc'``, ``5``, ``NULL`` ...).

    Filter values are caller data; escaping them (instead of pasting them
    between quotes) prevents SQL injection.
    """
    from pymysql.converters import escape_item
    return escape_item(value, "utf8")


def _quote_ident(name):
    """Quote a table/column name with backticks, doubling embedded backticks."""
    return "`%s`" % str(name).replace("`", "``")


class BaseDao(object):
    """
    Simple table-oriented MySQL helper.

    The column metadata of every table in *database* is read from
    ``information_schema`` at construction time. Query helpers return a row
    as a ``{column_name: value}`` dict, or a list of such dicts.
    """

    def __init__(self, creator=pymysql, host="localhost", port=3306, user=None, password="",
                 database=None, charset="utf8"):
        """Validate the settings, open the connection and load table metadata.

        :param creator: DB-API module handed to the connection factory (default: pymysql).
        :param host, port, user, password, database, charset: connection settings.
        :raises Exception: if host, port, user, password or database is None.
        """
        if host is None:
            raise Exception("Parameter [host] is None.")
        if port is None:
            raise Exception("Parameter [port] is None.")
        if user is None:
            raise Exception("Parameter [user] is None.")
        if password is None:
            raise Exception("Parameter [password] is None.")
        if database is None:
            raise Exception("Parameter [database] is None.")
        # Connection settings, forwarded as keyword arguments to the factory.
        self.__config = dict({
            "creator": creator, "charset": charset, "host": host, "port": port,
            "user": user, "password": password, "database": database
        })
        self.__database = self.__config["database"]  # schema whose tables are queried
        self.__tableName = None  # table selected by the most recent call
        self.__init_connect()
        self.__init_params()
        print(get_time(), self.__database, "数据库初始化成功。")

    def __del__(self):
        """Close the cursor and connection when the object is garbage-collected."""
        # __init__ may have raised before connecting. In that case the
        # name-mangled attributes do not exist, so look them up defensively.
        cursor = getattr(self, "_BaseDao__cursor", None)
        conn = getattr(self, "_BaseDao__conn", None)
        if cursor is None and conn is None:
            return
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
        print(get_time(), self.__database, "连接关闭")

    def __init_connect(self):
        """Open the connection and the single cursor shared by all queries."""
        # NOTE(review): this is the module-level ``connect`` reachable through
        # the DBUtils.PooledDB module, not a ``PooledDB.PooledDB`` pool
        # instance. Confirm this is the intended factory for the installed
        # DBUtils version.
        self.__conn = PooledDB.connect(**self.__config)
        self.__cursor = self.__conn.cursor()

    def __init_params(self):
        """Load the table/column metadata of the configured schema."""
        self.__init_table_dict()
        self.__init__table_column_dict_list()

    def __init__information_schema_columns(self):
        """Return the column names of ``information_schema.COLUMNS`` itself."""
        sql = """ SELECT COLUMN_NAME FROM information_schema.`COLUMNS`
                  WHERE TABLE_SCHEMA='information_schema' AND TABLE_NAME='COLUMNS'
              """
        result_tuple = self.__exec_query(sql)
        return [r[0] for r in result_tuple]

    def __init_table_dict(self):
        """Build ``{table: {column: COLUMNS row as dict}}`` for every table of the schema."""
        schema_column_list = self.__init__information_schema_columns()
        stitch_str = stitch_sequence([_quote_ident(c) for c in schema_column_list])
        sql1 = """ SELECT TABLE_NAME FROM information_schema.`TABLES`
                   WHERE TABLE_SCHEMA=%s
               """ % _sql_literal(self.__database)
        table_tuple = self.__exec_query(sql1)
        self.__table_dict = {}
        for row in table_tuple:
            table = row[0]
            # ORDER BY keeps the column order identical to the table definition.
            sql = """ SELECT %s FROM information_schema.`COLUMNS`
                      WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s
                      ORDER BY ORDINAL_POSITION
                  """ % (stitch_str, _sql_literal(self.__database), _sql_literal(table))
            column_dict = {}
            for vs in self.__exec_query(sql):
                d = dict(zip(schema_column_list, vs))
                column_dict[d["COLUMN_NAME"]] = d
            self.__table_dict[table] = column_dict

    def __init__table_column_dict_list(self):
        """Cache the ordered column-name list of every loaded table."""
        self.__table_column_dict_list = {
            table: list(column_dict) for table, column_dict in self.__table_dict.items()
        }

    def __exec_query(self, sql, single=False):
        """Execute a SELECT statement.

        :param sql: the statement to run.
        :param single: fetch one row (fetchone) instead of all rows (fetchall).
        :return: the fetched row(s), or None if the statement failed (the error is printed).
        """
        try:
            # Log first so that failing statements are visible too.
            print(get_time(), "SQL[%s]" % sql)
            self.__cursor.execute(sql)
            return self.__cursor.fetchone() if single else self.__cursor.fetchall()
        except Exception as e:
            print(e)

    def __exec_update(self, sql):
        """Execute a write statement and commit.

        On error the message is printed, the transaction is rolled back and None is returned.
        """
        try:
            print(get_time(), "SQL[%s]" % sql)
            result = self.__cursor.execute(sql)
            self.__conn.commit()
            return result
        except Exception as e:
            print(e)
            self.__conn.rollback()

    def __parse_result(self, result):
        """Turn one row tuple into a ``{column: value}`` dict (None stays None)."""
        if result is None:
            return None
        return dict(zip(self.__column_list, result))

    def __parse_results(self, results):
        """Turn a sequence of row tuples into a list of dicts (None stays None)."""
        if results is None:
            return None
        return [self.__parse_result(result) for result in results]

    def __getpk(self, tableName):
        """Return the first primary-key column of *tableName*, or None if it has none.

        :raises Exception: if the table was not loaded.
        """
        if self.__table_dict.get(tableName) is None:
            raise Exception(tableName, "is not exist.")
        for column, column_dict in self.__table_dict[tableName].items():
            if column_dict["COLUMN_KEY"] == "PRI":
                return column

    def __get_table_column_list(self, tableName=None):
        """Return the cached ordered column names of *tableName*."""
        return self.__table_column_dict_list[tableName]

    def __filter_condition(self, key, value):
        """Render one filter entry as ``"AND <condition> "``, with the value escaped."""
        for prefix, keyword in (("_in_", "IN"), ("_nein_", "NOT IN")):
            if key.startswith(prefix):
                items = value.split(",") if isinstance(value, str) else list(value)
                literals = [_sql_literal(i.strip() if isinstance(i, str) else i) for i in items]
                if not literals:
                    # "IN ()" is invalid SQL: nothing is IN an empty set,
                    # and everything is NOT IN it.
                    return "AND 1=0 " if keyword == "IN" else ""
                return "AND %s %s (%s) " % (
                    _quote_ident(key[len(prefix):]), keyword, ",".join(literals))
        if key.startswith("_like_"):
            return "AND %s like %s " % (
                _quote_ident(key[len("_like_"):]), _sql_literal("%%%s%%" % (value,)))
        for prefix, op in (("_ne_", "!="), ("_lt_", "<"), ("_le_", "<="),
                           ("_gt_", ">"), ("_ge_", ">=")):
            if key.startswith(prefix):
                return "AND %s %s %s " % (_quote_ident(key[len(prefix):]), op, _sql_literal(value))
        if value is None:
            return "AND %s IS NULL " % _quote_ident(key)
        return "AND %s=%s " % (_quote_ident(key), _sql_literal(value))

    def __query_util(self, filters=None):
        """Build ``SELECT <all columns> FROM <current table> WHERE 1=1 ...``.

        Supported filter keys:
        - ``col``: equality (None becomes ``IS NULL``)
        - ``_ne_`` / ``_lt_`` / ``_le_`` / ``_gt_`` / ``_ge_`` + col: comparisons
        - ``_in_col`` / ``_nein_col``: ``"1,2,3"`` or a list/tuple of values
        - ``_like_col``: ``LIKE '%value%'``
        - ``_limit_``: ``(offset, row_count)`` appended as ``LIMIT``

        The caller's *filters* dict is not modified.

        :raises Exception: if *filters* is not a dict.
        """
        fields = stitch_sequence(
            [_quote_ident(c) for c in self.__get_table_column_list(self.__tableName)])
        sql = "SELECT %s FROM %s WHERE 1=1 " % (fields, _quote_ident(self.__tableName))
        if filters is None:
            return sql
        if not isinstance(filters, dict):
            raise Exception("Parameter [filters] must be dict type.")
        filters = dict(filters)  # work on a copy; "_limit_" is consumed below
        limit = filters.pop("_limit_", None)
        sql += "".join(self.__filter_condition(k, v) for k, v in filters.items())
        if limit:
            beginindex, rows = limit
            sql += "LIMIT %d,%d" % (beginindex, rows)
        return sql

    def __check_params(self, tableName):
        """Make *tableName* the current table. None keeps the previous one.

        :raises Exception: if no table was given and none was selected before.
        :raises KeyError: if *tableName* is unknown. The current table is left unchanged.
        """
        if tableName is None:
            if self.__tableName is None:
                raise Exception("Parameter [tableName] is None.")
            return  # keep using the previously selected table
        if tableName != self.__tableName:
            # Look up first, so an unknown table leaves the state untouched.
            self.__column_list = self.__table_column_dict_list[tableName]
            self.__tableName = tableName

    def select_one(self, tableName=None, filters=None):
        """Fetch the first row that matches *filters*.

        @tableName table name (None: reuse the current table)
        @filters filter dict (see __query_util)
        @return {column: value} dict, or None if no row matched
        """
        self.__check_params(tableName)
        sql = self.__query_util(filters)
        result = self.__exec_query(sql, single=True)
        return self.__parse_result(result)

    def select_pk(self, tableName=None, primaryKey=None):
        """Fetch a row by primary-key value.

        @tableName table name (None: reuse the current table)
        @primaryKey primary-key value
        @raises Exception if the table has no primary key
        """
        self.__check_params(tableName)
        pk = self.__getpk(self.__tableName)
        if pk is None:
            raise Exception(self.__tableName, "has no primary key.")
        sql = self.__query_util({pk: primaryKey})
        result = self.__exec_query(sql, single=True)
        return self.__parse_result(result)

    def select_all(self, tableName=None, filters=None):
        """Fetch all rows that match *filters*.

        @tableName table name (None: reuse the current table)
        @filters filter dict (see __query_util)
        @return list of {column: value} dicts
        """
        self.__check_params(tableName)
        sql = self.__query_util(filters)
        results = self.__exec_query(sql)
        return self.__parse_results(results)

    def count(self, tableName=None):
        """Return the total number of rows in the table."""
        self.__check_params(tableName)
        sql = "SELECT count(*) FROM %s" % _quote_ident(self.__tableName)
        result = self.__exec_query(sql, single=True)
        return result[0]

    def select_page(self, tableName=None, pageNum=1, limit=10, filters=None):
        """Fetch one page of rows.

        The page count is derived from count() of the whole table; *filters*
        are not applied to the count. Out-of-range page numbers are clamped
        into ``[1, totalPage]``.

        @tableName table name (None: reuse the current table)
        @pageNum 1-based page number
        @limit rows per page, must be >= 1
        @filters filter dict (see __query_util); it is not modified
        @return list of {column: value} dicts
        """
        self.__check_params(tableName)
        if limit < 1:
            raise ValueError("Parameter [limit] must be >= 1.")
        totalCount = self.count()
        # Ceiling division; an empty table still has one (empty) page.
        totalPage = max(1, (totalCount + limit - 1) // limit)
        if pageNum > totalPage:
            print("最大页数为%d" % totalPage)
            pageNum = totalPage
        elif pageNum < 1:
            print("页数不能小于1")
            pageNum = 1
        beginindex = (pageNum - 1) * limit
        page_filters = dict(filters or {})
        page_filters.setdefault("_limit_", (beginindex, limit))
        sql = self.__query_util(page_filters)
        result_tuple = self.__exec_query(sql)
        return self.__parse_results(result_tuple)


if __name__ == "__main__":
    config = {
        # "creator": pymysql,
        # "host" : "127.0.0.1",
        "user": "root",
        "password": "root",
        "database": "test",
        # "port" : 3306,
        # "charset" : 'utf8'
    }
    base = BaseDao(**config)
    ########################################################################
    user = base.select_one("user")
    print(user)
    ########################################################################
    # users = base.select_all("user")
    # print(users)
    ########################################################################
    filter1 = {
        "status": 1,
        "_in_id": "1,2,3,4,5",
        "_like_name": "zhang",
        "_ne_name": "wangwu"
    }
    user_filters = base.select_all("user", filter1)
    print(user_filters)
    ########################################################################
    role = base.select_one("role")
    print(role)
    ########################################################################
    user_pk = base.select_pk("user", 2)
    print(user_pk)
    ########################################################################
    user_limit = base.select_page("user", 1, 10)
    print(user_limit)
    ########################################################################
更新:2017-08-25
import json, os, sys, time, pymysql, pprint, logging

logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%a, %d %b %Y %H:%M:%S')

from DBUtils import PooledDB


def print(*args):
    """Module-wide replacement for the builtin: pretty-print all arguments as one tuple."""
    pprint.pprint(args)


def get_time():
    """Return the current local time formatted as ``YYYY-mm-dd HH:MM:SS``."""
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())


def stitch_sequence(seq=None, suf=None):
    """Join the string elements of *seq* with the separator *suf* (default ``","``).

    :raises Exception: if *seq* is None.
    """
    if seq is None:
        raise Exception("Parameter seq is None")
    if suf is None:
        suf = ","
    # str.join also copes with an empty separator, which the old
    # "append, then slice the last separator off" approach turned into "".
    return suf.join(seq)


def _sql_literal(value):
    """Render *value* as an escaped MySQL literal (``'abc'``, ``5``, ``NULL`` ...).

    Filter and row values are caller data; escaping them (instead of pasting
    them between quotes) prevents SQL injection.
    """
    from pymysql.converters import escape_item
    return escape_item(value, "utf8")


def _quote_ident(name):
    """Quote a table/column name with backticks, doubling embedded backticks."""
    return "`%s`" % str(name).replace("`", "``")


class BaseDao(object):
    """
    Simple table-oriented database helper. Every table it operates on must
    have a primary key.

    Constructor parameters:
    - creator: DB-API module used to create the connection (default: pymysql)
    - host: database host (default: localhost)
    - port: database port (default: 3306)
    - user: user name (default: None); None raises an exception
    - password: password (default: None); None raises an exception
    - database: schema to use (default: None); None raises an exception
    - charset: connection charset (default: utf8)
    - tableName: table to initialise (default: None). When None, the
      metadata of every table in the schema is loaded up front. Otherwise
      only this table is loaded, and other tables are loaded on first use.
    """

    def __init__(self, creator=pymysql, host="localhost", port=3306, user=None, password=None,
                 database=None, charset="utf8", tableName=None):
        if host is None:
            raise Exception("Parameter [host] is None.")
        if port is None:
            raise Exception("Parameter [port] is None.")
        if user is None:
            raise Exception("Parameter [user] is None.")
        if password is None:
            raise Exception("Parameter [password] is None.")
        if database is None:
            raise Exception("Parameter [database] is None.")
        if tableName is None:
            print("WARNING >>> Parameter [tableName] is None. All tables will be initialized.")
        logging.debug("[%s] 数据库初始化>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>开始" % (database))
        start = time.time()
        # Connection settings, forwarded as keyword arguments to the factory.
        self.__config = dict({
            "creator": creator, "charset": charset, "host": host, "port": port,
            "user": user, "password": password, "database": database
        })
        self.__database = database  # schema whose tables are queried
        self.__tableName = tableName  # table selected by the most recent call
        self.__init_connect()
        self.__init_params()
        end = time.time()
        logging.debug("[%s] 数据库初始化>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>结束" % (database))
        # time.time() is in seconds; convert so the number matches the "ms" unit.
        logging.debug("[%s] 数据库初始化成功。耗时:%d ms。" % (database, (end - start) * 1000))

    def __del__(self):
        """Close the cursor and connection when the object is garbage-collected."""
        # __init__ may have raised before connecting. In that case the
        # name-mangled attributes do not exist, so look them up defensively.
        cursor = getattr(self, "_BaseDao__cursor", None)
        conn = getattr(self, "_BaseDao__conn", None)
        if cursor is None and conn is None:
            return
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
        logging.debug("[%s] 连接关闭。" % (self.__database))

    def __init_connect(self):
        """Open the connection and the single cursor shared by all queries."""
        # NOTE(review): this is the module-level ``connect`` reachable through
        # the DBUtils.PooledDB module, not a ``PooledDB.PooledDB`` pool
        # instance. Confirm this is the intended factory for the installed
        # DBUtils version.
        self.__conn = PooledDB.connect(**self.__config)
        self.__cursor = self.__conn.cursor()

    def __init_params(self):
        """Reset the metadata caches and load either one table or the whole schema."""
        self.__table_dict = {}  # {table: {column: information_schema.COLUMNS row dict}}
        self.__information_schema_columns = []  # column names of information_schema.COLUMNS
        self.__table_column_dict_list = {}  # {table: ordered column names}
        if self.__tableName is None:
            self.__init_table_dict_list()
        else:
            self.__init_table_dict(self.__tableName)
        self.__init__table_column_dict_list()
        if self.__tableName is not None:
            self.__column_list = self.__table_column_dict_list[self.__tableName]

    def __init__information_schema_columns(self):
        """Cache the column names of ``information_schema.COLUMNS`` itself."""
        sql = """ SELECT COLUMN_NAME
                  FROM information_schema.`COLUMNS`
                  WHERE TABLE_SCHEMA='information_schema' AND TABLE_NAME='COLUMNS'
              """
        result_tuple = self.__exec_query(sql)
        self.__information_schema_columns = [r[0] for r in result_tuple]

    def __init_table_dict(self, tableName):
        """Load the ``information_schema.COLUMNS`` rows of *tableName*.

        :raises Exception: if no column metadata is found, i.e. the table does not exist.
        """
        if not self.__information_schema_columns:
            self.__init__information_schema_columns()
        stitch_str = stitch_sequence([_quote_ident(c) for c in self.__information_schema_columns])
        # ORDER BY keeps the column order identical to the table definition.
        sql = """ SELECT %s FROM information_schema.`COLUMNS`
                  WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s
                  ORDER BY ORDINAL_POSITION
              """ % (stitch_str, _sql_literal(self.__database), _sql_literal(tableName))
        column_dict = {}
        for vs in self.__exec_query(sql) or ():
            d = dict(zip(self.__information_schema_columns, vs))
            column_dict[d["COLUMN_NAME"]] = d
        if not column_dict:
            raise Exception(tableName, "is not exist.")
        self.__table_dict[tableName] = column_dict

    def __init_table_dict_list(self):
        """Load the column metadata of every table in the schema."""
        sql = """ SELECT TABLE_NAME FROM information_schema.`TABLES` WHERE TABLE_SCHEMA=%s
              """ % _sql_literal(self.__database)
        table_tuple = self.__exec_query(sql)
        self.__table_dict = {t[0]: {} for t in table_tuple}
        for t in table_tuple:
            self.__init_table_dict(t[0])

    def __init__table_column_dict_list(self):
        """Rebuild the cached ordered column-name list of every loaded table."""
        for table, column_dict in self.__table_dict.items():
            self.__table_column_dict_list[table] = list(column_dict)

    def __exec_query(self, sql, single=False):
        """Execute a SELECT statement.

        - @sql the statement to run
        - @single fetch one row (fetchone) instead of all rows (fetchall); default False
        - @return the fetched row(s), or None if the statement failed (the error is printed)
        """
        try:
            logging.debug("[%s] SQL >>> [%s]" % (self.__database, sql))
            self.__cursor.execute(sql)
            return self.__cursor.fetchone() if single else self.__cursor.fetchall()
        except Exception as e:
            print(e)

    def __exec_update(self, sql):
        """Execute a write statement and commit.

        Returns the affected row count. On error the message is printed, the
        transaction is rolled back and None is returned.
        """
        try:
            logging.debug("[%s] SQL >>> [%s]" % (self.__database, sql))
            result = self.__cursor.execute(sql)
            self.__conn.commit()
            return result
        except Exception as e:
            print(e)
            self.__conn.rollback()

    def __parse_result(self, result):
        """Turn one row tuple into a ``{column: value}`` dict (None stays None)."""
        if result is None:
            return None
        return dict(zip(self.__column_list, result))

    def __parse_results(self, results):
        """Turn a sequence of row tuples into a list of dicts (None stays None)."""
        if results is None:
            return None
        return [self.__parse_result(result) for result in results]

    def __getpk(self, tableName):
        """Return the first primary-key column of *tableName*, or None if it has none.

        :raises Exception: if the table was not loaded.
        """
        if self.__table_dict.get(tableName) is None:
            raise Exception(tableName, "is not exist.")
        for column, column_dict in self.__table_dict[tableName].items():
            if column_dict["COLUMN_KEY"] == "PRI":
                return column

    def __get_table_column_list(self, tableName=None):
        """Return the cached ordered column names of *tableName*."""
        return self.__table_column_dict_list[tableName]

    def __check_tableName(self, tableName):
        """Make *tableName* the current table, loading its metadata on first use.

        None keeps the previously selected table.

        :raises Exception: if no table was given and none was selected before,
            or if the table does not exist.
        """
        if tableName is None:
            if self.__tableName is None:
                raise Exception("Parameter [tableName] is None.")
            return
        if tableName not in self.__table_column_dict_list:
            # When a tableName was given to __init__, only that table was
            # loaded; other tables are loaded lazily here.
            self.__init_table_dict(tableName)
            self.__init__table_column_dict_list()
        if self.__tableName != tableName:
            self.__column_list = self.__table_column_dict_list[tableName]
            self.__tableName = tableName

    def __select_sql(self, filters=None):
        """Build ``SELECT <all columns> FROM <current table>`` plus the clauses from *filters*."""
        fields = stitch_sequence(
            [_quote_ident(c) for c in self.__get_table_column_list(self.__tableName)])
        sql = "SELECT %s FROM %s" % (fields, _quote_ident(self.__tableName))
        return QueryUtil.query_sql(sql, filters)

    def select_one(self, tableName=None, filters=None):
        """Fetch the first row that matches *filters*.

        - @tableName table name (None: reuse the current table)
        - @filters filter dict (see QueryUtil); it is not modified
        - @return {column: value} dict, or None if no row matched
        """
        self.__check_tableName(tableName)
        result = self.__exec_query(self.__select_sql(filters), single=True)
        return self.__parse_result(result)

    def select_pk(self, tableName=None, primaryKey=None):
        """Fetch a row by primary-key value.

        - @tableName table name (None: reuse the current table)
        - @primaryKey primary-key value
        """
        self.__check_tableName(tableName)
        pk = self.__getpk(self.__tableName)
        result = self.__exec_query(self.__select_sql({pk: primaryKey}), single=True)
        return self.__parse_result(result)

    def select_all(self, tableName=None, filters=None):
        """Fetch all rows that match *filters*.

        - @tableName table name (None: reuse the current table)
        - @filters filter dict (see QueryUtil); it is not modified
        - @return list of {column: value} dicts
        """
        self.__check_tableName(tableName)
        results = self.__exec_query(self.__select_sql(filters))
        return self.__parse_results(results)

    def count(self, tableName=None):
        """Return the total number of rows in the table."""
        self.__check_tableName(tableName)
        sql = "SELECT count(*) FROM %s" % _quote_ident(self.__tableName)
        result = self.__exec_query(sql, single=True)
        return result[0]

    def select_page(self, tableName=None, page=None, filters=None):
        """Fetch one page of rows.

        - @tableName table name (None: reuse the current table)
        - @page Page object (default: Page(), i.e. page 1 with 10 rows)
        - @filters filter dict (see QueryUtil); it is not modified
        - @return list of {column: value} dicts
        """
        self.__check_tableName(tableName)
        if page is None:
            page = Page()
        page_filters = dict(filters or {})
        page_filters["page"] = page
        result_tuple = self.__exec_query(self.__select_sql(page_filters))
        return self.__parse_results(result_tuple)

    def save(self, tableName=None, obj=None):
        """Insert *obj* as a new row.

        Pass None for an AUTO_INCREMENT primary key to let MySQL assign it.
        - @param tableName table name (None: reuse the current table)
        - @param obj {column: value}; None values are stored as NULL
        - @return affected row count
        - @raises KeyError if *obj* contains an unknown column
        """
        self.__check_tableName(tableName)
        obj = obj or {}
        columns = self.__table_dict[self.__tableName]
        unknown = [k for k in obj if k not in columns]
        if unknown:
            raise KeyError(unknown[0])
        fields = stitch_sequence([_quote_ident(k) for k in obj])
        values = stitch_sequence([_sql_literal(v) for v in obj.values()])
        sql = ' INSERT INTO %s (%s) VALUES(%s)' % (_quote_ident(self.__tableName), fields, values)
        return self.__exec_update(sql)

    def __update_by_pk(self, fields, values, conditions):
        """Run ``UPDATE <table> SET ... WHERE <pk conditions joined by AND>``.

        - @return affected row count; 0 without touching the database when nothing is set
        - @raises Exception if no primary-key value was supplied (the statement would be invalid)
        """
        if not conditions:
            raise Exception("Primary key value is missing, refusing to update [%s]." % self.__tableName)
        if not fields:
            return 0
        sql = "UPDATE %s%s WHERE %s" % (
            _quote_ident(self.__tableName), QueryUtil.query_set(fields, values),
            " AND ".join(conditions))
        return self.__exec_update(sql)

    def update(self, tableName=None, obj=None):
        """Update one row by primary key, writing every column in *obj*, None included.

        None is written as NULL for nullable columns and as '' for NOT NULL columns.
        - @param tableName table name (None: reuse the current table)
        - @param obj {column: value}; must contain the primary-key column(s)
        - @return affected row count
        """
        self.__check_tableName(tableName)
        columns = self.__table_dict[self.__tableName]
        fields, values, conditions = [], [], []
        for k, v in (obj or {}).items():
            column = columns[k]  # unknown column -> KeyError
            if column["COLUMN_KEY"] == "PRI":
                conditions.append("%s=%s" % (_quote_ident(k), _sql_literal(v)))
                continue
            if v is None and column["IS_NULLABLE"] != "YES":
                v = ""
            fields.append(k)
            values.append(v)
        return self.__update_by_pk(fields, values, conditions)

    def update_selective(self, tableName=None, obj=None):
        """Update one row by primary key, skipping columns whose value is None.

        - @param tableName table name (None: reuse the current table)
        - @param obj {column: value}; must contain the primary-key column(s)
        - @return affected row count (0 if every non-key value is None)
        """
        self.__check_tableName(tableName)
        columns = self.__table_dict[self.__tableName]
        fields, values, conditions = [], [], []
        for k, v in (obj or {}).items():
            if columns[k]["COLUMN_KEY"] == "PRI":
                conditions.append("%s=%s" % (_quote_ident(k), _sql_literal(v)))
            elif v is not None:
                fields.append(k)
                values.append(v)
        return self.__update_by_pk(fields, values, conditions)

    def remove(self, tableName=None, obj=None):
        """Delete the row whose primary key equals ``obj[pk]``.

        - @param tableName table name (None: reuse the current table)
        - @param obj dict containing the primary-key value
        - @return affected row count
        - @raises KeyError if *obj* has no primary-key entry
        """
        self.__check_tableName(tableName)
        pk = self.__getpk(self.__tableName)
        sql = "DELETE FROM %s WHERE %s=%s" % (
            _quote_ident(self.__tableName), _quote_ident(pk), _sql_literal((obj or {})[pk]))
        return self.__exec_update(sql)


class Page(object):
    """Pagination parameters for BaseDao.select_page."""

    def __init__(self, pageNum=1, pageSize=10, count=False):
        """
        - @param pageNum 1-based page number; values < 1 fall back to 1
        - @param pageSize rows per page; values < 1 fall back to 10
        - @param count whether a count query should be included (not yet used by select_page)
        """
        self.pageNum = pageNum if pageNum > 0 else 1  # current page
        self.pageSize = pageSize if pageSize > 0 else 10  # rows per page
        self.count = count  # requested count query; currently not consulted
        self.total = 0  # total rows (not filled in by select_page yet)
        self.pages = 1  # total pages (not filled in by select_page yet)
        self.startRow = (self.pageNum - 1) * self.pageSize  # LIMIT offset
        self.endRow = self.startRow + self.pageSize  # exclusive end row (informational)


class QueryUtil(object):
    """
    SQL clause builder.

    Main entry point: query_sql(sql, filters) appends WHERE / GROUP BY /
    ORDER BY / LIMIT clauses built from *filters* to *sql*.

    Supported filter keys (values are escaped as SQL literals):
    - 1. equal:           {"id": 2}                -> id=2 (None -> id IS NULL)
    - 2. not equal:       {"_ne_id": 2}            -> id != 2
    - 3. less than:       {"_lt_id": 2}            -> id < 2
    - 4. less or equal:   {"_le_id": 2}            -> id <= 2
    - 5. greater than:    {"_gt_id": 2}            -> id > 2
    - 6. greater/equal:   {"_ge_id": 2}            -> id >= 2
    - 7. in:              {"_in_id": "1,2,3"}      -> id IN (1,2,3) (a list/tuple also works)
    - 8. not in:          {"_nein_id": "4,5,6"}    -> id NOT IN (4,5,6)
    - 9. like:            {"_like_name": "zhang"}  -> name LIKE '%zhang%'
    - 10. left like:      {"_llike_name": "zhang"} -> name LIKE '%zhang'
    - 11. right like:     {"_rlike_name": "zhang"} -> name LIKE 'zhang%'
    - 12. group by:       {"groupby": "status"}    -> GROUP BY status
    - 13. order by:       {"orderby": "createDate", "ordertype": "desc"}
                          -> ORDER BY `createDate` desc
    - 14. paging:         {"page": Page(...)}      -> LIMIT startRow,pageSize
    """

    NE = "_ne_"  # not equal
    LT = "_lt_"  # less than
    LE = "_le_"  # less than or equal
    GT = "_gt_"  # greater than
    GE = "_ge_"  # greater than or equal
    IN = "_in_"  # IN
    NE_IN = "_nein_"  # NOT IN
    LIKE = "_like_"  # LIKE '%v%'
    LEFT_LIKE = "_llike_"  # LIKE '%v'
    RIGHT_LIKE = "_rlike_"  # LIKE 'v%'
    GROUP = "groupby"  # GROUP BY
    ORDER = "orderby"  # ORDER BY
    ORDER_TYPE = "ordertype"  # sort direction: asc or desc

    @staticmethod
    def __condition(key, value):
        """Render one filter entry as ``" AND <condition> "``."""
        for prefix, keyword in ((QueryUtil.IN, "IN"), (QueryUtil.NE_IN, "NOT IN")):
            if key.startswith(prefix):
                items = value.split(",") if isinstance(value, str) else list(value)
                literals = [_sql_literal(i.strip() if isinstance(i, str) else i) for i in items]
                if not literals:
                    # "IN ()" is invalid SQL: nothing is IN an empty set,
                    # and everything is NOT IN it.
                    return " AND 1=0 " if keyword == "IN" else ""
                return " AND %s %s (%s) " % (
                    _quote_ident(key[len(prefix):]), keyword, ", ".join(literals))
        for prefix, pattern in ((QueryUtil.LIKE, "%%%s%%"), (QueryUtil.LEFT_LIKE, "%%%s"),
                                (QueryUtil.RIGHT_LIKE, "%s%%")):
            if key.startswith(prefix):
                return " AND %s LIKE %s " % (
                    _quote_ident(key[len(prefix):]), _sql_literal(pattern % (value,)))
        for prefix, op in ((QueryUtil.NE, "!="), (QueryUtil.LT, "<"), (QueryUtil.LE, "<="),
                           (QueryUtil.GT, ">"), (QueryUtil.GE, ">=")):
            if key.startswith(prefix):
                return " AND %s %s %s " % (_quote_ident(key[len(prefix):]), op, _sql_literal(value))
        if value is None:
            return " AND %s IS NULL " % _quote_ident(key)
        return " AND %s=%s " % (_quote_ident(key), _sql_literal(value))

    @staticmethod
    def __filter_params(filters):
        """Render the remaining (non-special) filter keys as a WHERE clause."""
        return " WHERE 1=1" + "".join(QueryUtil.__condition(k, v) for k, v in filters.items())

    @staticmethod
    def __filter_group(filters):
        """Pop ``groupby`` and render a GROUP BY clause."""
        # NOTE: inserted verbatim because it may list several columns, so it
        # must never come from untrusted input.
        group = filters.pop(QueryUtil.GROUP)
        return " GROUP BY %s" % (group)

    @staticmethod
    def __filter_order(filters):
        """Pop ``orderby`` and the optional ``ordertype``, and render an ORDER BY clause.

        :raises ValueError: if ordertype is not asc/desc (case-insensitive).
        """
        order = filters.pop(QueryUtil.ORDER)
        order_type = filters.pop(QueryUtil.ORDER_TYPE, None) or ""
        if order_type and order_type.lower() not in ("asc", "desc"):
            raise ValueError("Parameter [ordertype] must be 'asc' or 'desc'.")
        return (" ORDER BY %s %s" % (_quote_ident(order), order_type)).rstrip()

    @staticmethod
    def __filter_page(filters):
        """Pop the Page object and render ``LIMIT offset,row_count``."""
        page = filters.pop("page")
        # MySQL's second LIMIT argument is a row count, not an end row.
        return " LIMIT %d,%d" % (page.startRow, page.pageSize)

    @staticmethod
    def query_sql(sql=None, filters=None):
        """Append the clauses described by *filters* to *sql*.

        - @param sql base SQL statement
        - @param filters filter dict (see the class docstring); it is not modified
        - @return the combined SQL
        - @raises Exception if *filters* is not a dict
        """
        if not filters:
            return sql
        if not isinstance(filters, dict):
            raise Exception("Parameter [filters] must be dict.")
        filters = dict(filters)  # work on a copy; the special keys are popped below
        group = order = page = None
        if filters.get(QueryUtil.GROUP) is not None:
            group = QueryUtil.__filter_group(filters)
        if filters.get(QueryUtil.ORDER) is not None:
            order = QueryUtil.__filter_order(filters)
        if filters.get("page") is not None:
            page = QueryUtil.__filter_page(filters)
        sql += QueryUtil.__filter_params(filters)
        # Clause order required by MySQL: WHERE, GROUP BY, ORDER BY, LIMIT.
        for clause in (group, order, page):
            if clause:
                sql += clause
        return sql

    @staticmethod
    def query_set(fields, values):
        """Render ``" SET `f1`=v1, `f2`=v2"`` from parallel *fields*/*values*, with values escaped."""
        pairs = ["%s=%s" % (_quote_ident(f), _sql_literal(v)) for f, v in zip(fields, values)]
        return " SET " + stitch_sequence(pairs, ", ")


def test():
    config = {
        # "creator": pymysql,
        # "host" : "127.0.0.1",
        "user": "root",
        "password": "root",
        "database": "py",
        # "port" : 3306,
        # "charset" : 'utf8'
        # "tableName" : "fake",
    }
    base = BaseDao(**config)
    ########################################################################
    # fake = base.select_one("fake")
    # print(fake)
    ########################################################################
    # users = base.select_all("fake")
    # print(users)
    ########################################################################
    # filter1 = {
    #     "status":1,
    #     "_in_id":"1,2,3,4,5",
    #     "_like_name":"zhang",
    #     "_ne_name":"wangwu"
    # }
    # user_filters = base.select_all("user", filter1)
    # print(user_filters)
    ########################################################################
    # role = base.select_one("role")
    # print(role)
    ########################################################################
    # fake = base.select_pk("fake", 2)
    # print(fake)
    # base.update("fake", fake)
    # base.update_selective("fake", fake)
    # base.remove("fake", fake)
    ########################################################################
    # user_limit = base.select_page("user")
    # print(user_limit)
    ########################################################################
    # fake = {
    #     "id": None,  # None lets AUTO_INCREMENT assign the key
    #     "name": "test",
    #     "value": "test"
    # }
    # flag = base.save("fake", fake)
    # print(flag)


if __name__ == "__main__":
    test()
以上更新改动较多,整体上做了优化,并新增了 save、update、update_selective 和 remove(删除)方法。
更新:2017-10-26
1 #!/usr/bin/env python3 2 # -*- coding=utf-8 -*- 3 4 import json 5 import logging 6 import os 7 import sys 8 import time 9 10 import pymysql 11 from DBUtils import PooledDB 12 13 __author__ = "阮程" 14 15 logging.basicConfig( 16 level=logging.INFO, 17 datefmt='%Y-%m-%d %H:%M:%S', 18 format='%(asctime)s [%(levelname)s] %(message)s' 19 ) 20 21 22 def get_time(format=None): 23 '获取时间' 24 format = format or "%Y-%m-%d %H:%M:%S" 25 return time.strftime(format, time.localtime()) 26 27 28 def stitch_sequence(seq=None, suf=None, isField=True): 29 '如果参数("suf")不为空,则根据特殊的suf拼接列表元素,返回一个字符串。默认使用 ","。' 30 if seq is None: 31 raise Exception("Parameter seq is None") 32 suf = suf or "," 33 r = str() 34 for s in seq: 35 r += '`%s`%s' % (s, suf) if isField else '%s%s' % (s, suf) 36 # if isField: 37 # r += '`%s`%s' % (s, suf) 38 # else: 39 # r += '%s%s' % (s, suf) 40 return r[:-len(suf)] 41 42 43 class BaseDao(object): 44 """ 45 简便的数据库操作基类,该类所操作的表必须有主键 46 初始化参数如下: 47 - :creator: 创建连接对象(默认: pymysql) 48 - :host: 连接数据库主机地址(默认: localhost) 49 - :port: 连接数据库端口(默认: 3306) 50 - :user: 连接数据库用户名(默认: None), 如果为空,则会抛异常 51 - :password: 连接数据库密码(默认: None), 如果为空,则会抛异常 52 - :database: 连接数据库(默认: None), 如果为空,则会抛异常 53 - :chatset: 编码(默认: utf8) 54 - :tableName: 初始化 BaseDao 对象的数据库表名(默认: None), 如果为空, 55 则会初始化该数据库下所有表的信息, 如果不为空,则只初始化传入的 tableName 的表 56 """ 57 58 def __init__(self, creator=pymysql, host="localhost", port=3306, user=None, password=None, 59 database=None, charset="utf8", tableName=None, *args, **kwargs): 60 if host is None: 61 raise ValueError("Parameter [host] is None.") 62 if port is None: 63 raise ValueError("Parameter [port] is None.") 64 if user is None: 65 raise ValueError("Parameter [user] is None.") 66 if password is None: 67 raise ValueError("Parameter [password] is None.") 68 if database is None: 69 raise ValueError("Parameter [database] is None.") 70 if tableName is None: 71 print( 72 "WARNING >>> Parameter [tableName] is None. 
All tables will be initialized.") 73 logging.debug( 74 "[%s] 数据库初始化>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>开始" % (database)) 75 start = time.time() 76 # 数据库连接配置 77 self.__config = dict({ 78 "creator": creator, "charset": charset, "host": host, "port": port, 79 "user": user, "password": password, "database": database 80 }) 81 self.__database = database # 用于存储查询数据库 82 self.__tableName = tableName # 用于临时存储当前查询表名 83 # 初始化 84 self.__init_connect() # 初始化连接 85 self.__init_params() # 初始化参数 86 end = time.time() 87 logging.debug( 88 "[%s] 数据库初始化>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>结束" % (database)) 89 logging.info("[%s] 数据库初始化成功。耗时:%d ms。" % (database, (end - start))) 90 91 def __del__(self): 92 '重写类被清除时调用的方法' 93 if self.__cursor: 94 self.__cursor.close() 95 if self.__conn: 96 self.__conn.close() 97 logging.debug("[%s] 连接关闭。" % (self.__database)) 98 99 def __init_connect(self): 100 '初始化连接' 101 self.__conn = PooledDB.connect(**self.__config) 102 self.__cursor = self.__conn.cursor() 103 104 def __init_params(self): 105 '初始化参数' 106 self._table_dict = {} 107 self.__information_schema_columns = [] 108 self.__table_column_dict_list = {} 109 if self.__tableName is None: 110 self.__init_table_dict_list() 111 self.__init__table_column_dict_list() 112 else: 113 self.__init_table_dict(self.__tableName) 114 self.__init__table_column_dict_list() 115 self.__column_list = self.__table_column_dict_list[self.__tableName] 116 117 def __init__information_schema_columns(self): 118 "查询 information_schema.`COLUMNS` 中的列" 119 sql = """ SELECT COLUMN_NAME 120 FROM information_schema.`COLUMNS` 121 WHERE TABLE_SCHEMA='information_schema' AND TABLE_NAME='COLUMNS' 122 """ 123 result_tuple = self.execute_query(sql) 124 column_list = [r[0] for r in result_tuple] 125 self.__information_schema_columns = column_list 126 127 def __init_table_dict(self, tableName): 128 '初始化表' 129 if not self.__information_schema_columns: 130 self.__init__information_schema_columns() 131 stitch_str = 
stitch_sequence(self.__information_schema_columns) 132 sql = """ SELECT %s FROM information_schema.`COLUMNS` 133 WHERE TABLE_SCHEMA='%s' AND TABLE_NAME='%s' 134 """ % (stitch_str, self.__database, tableName) 135 column_tuple = self.execute_query(sql) 136 column_dict = {} 137 for vs in column_tuple: 138 d = {k: v for k, v in zip(self.__information_schema_columns, vs)} 139 column_dict[d["COLUMN_NAME"]] = d 140 self._table_dict[tableName] = column_dict 141 142 def __init_table_dict_list(self): 143 "初始化表字典对象" 144 if not self.__information_schema_columns: 145 self.__init__information_schema_columns() 146 stitch_str = stitch_sequence(self.__information_schema_columns) 147 sql = """ SELECT TABLE_NAME FROM information_schema.`TABLES` 148 WHERE TABLE_SCHEMA='%s' 149 """ % (self.__database) 150 table_tuple = self.execute_query(sql) 151 self._table_dict = {t[0]: {} for t in table_tuple} 152 for table in table_tuple: 153 self.__init_table_dict(table[0]) 154 155 def __init__table_column_dict_list(self): 156 '''初始化表字段字典列表''' 157 for table, column_dict in self._table_dict.items(): 158 column_list = [column for column in column_dict.keys()] 159 self.__table_column_dict_list[table] = column_list 160 161 def __parse_result(self, result): 162 '用于解析单个查询结果,返回字典对象' 163 if result is None: 164 return None 165 obj = {k: v for k, v in zip(self.__column_list, result)} 166 return obj 167 168 def __parse_results(self, results): 169 '用于解析多个查询结果,返回字典列表对象' 170 if results is None: 171 return None 172 objs = [self.__parse_result(result) for result in results] 173 return objs 174 175 def __getpk(self, tableName): 176 '获取表对应的主键字段' 177 if self._table_dict.get(tableName) is None: 178 raise Exception(tableName, "is not exist.") 179 for column, column_dict in self._table_dict[tableName].items(): 180 if column_dict["COLUMN_KEY"] == "PRI": 181 return column 182 183 def __get_table_column_list(self, tableName=None): 184 '查询表的字段列表, 将查询出来的字段列表存入 __fields 中' 185 return self.__table_column_dict_list[tableName] 
186 187 def __check_tableName(self, tableName): 188 '''验证 tableName 参数''' 189 if tableName is None: 190 if self.__tableName is None: 191 raise Exception("Parameter [tableName] is None.") 192 else: 193 if self.__tableName != tableName: 194 self.__tableName = tableName 195 self.__column_list = self.__table_column_dict_list[self.__tableName] 196 197 def execute_query(self, sql, single=False): 198 '''执行查询 SQL 语句 199 - @sql 查询 sql 200 - @single 是否查询单个结果集,默认False 201 ''' 202 try: 203 logging.info("[%s] SQL >>> [%s]" % (self.__database, sql)) 204 self.__cursor.execute(sql) 205 if single: 206 result_tuple = self.__cursor.fetchone() 207 else: 208 result_tuple = self.__cursor.fetchall() 209 return result_tuple 210 except Exception as e: 211 logging.error(e) 212 213 def execute_update(self, sql): 214 '''执行更新 SQL 语句''' 215 try: 216 # 获取数据库游标 217 logging.info("[%s] SQL >>> [%s]" % (self.__database, sql)) 218 result = self.__cursor.execute(sql) 219 self.__conn.commit() 220 return result 221 except Exception as e: 222 logging.error(e) 223 self.__conn.rollback() 224 225 def select_one(self, tableName=None, filters={}): 226 '''查询单个对象 227 - @tableName 表名 228 - @filters 过滤条件 229 - @return 返回字典集合,集合中以表字段作为 key,字段值作为 value 230 ''' 231 self.__check_tableName(tableName) 232 FIELDS = stitch_sequence( 233 self.__get_table_column_list(self.__tableName)) 234 sql = "SELECT %s FROM %s" % (FIELDS, self.__tableName) 235 sql = QueryUtil.query_sql(sql, filters) 236 result = self.execute_query(sql, single=True) 237 return self.__parse_result(result) 238 239 def select_pk(self, tableName=None, primaryKey=None): 240 '''按主键查询 241 - @tableName 表名 242 - @primaryKey 主键值 243 ''' 244 self.__check_tableName(tableName) 245 FIELDS = stitch_sequence( 246 self.__get_table_column_list(self.__tableName)) 247 sql = "SELECT %s FROM %s" % (FIELDS, self.__tableName) 248 sql = QueryUtil.query_sql(sql, {self.__getpk(tableName): primaryKey}) 249 result = self.execute_query(sql, single=True) 250 return 
self.__parse_result(result) 251 252 def select_all(self, tableName=None, filters={}): 253 '''查询所有 254 - @tableName 表名 255 - @filters 过滤条件 256 - @return 返回字典集合,集合中以表字段作为 key,字段值作为 value 257 ''' 258 self.__check_tableName(tableName) 259 FIELDS = stitch_sequence( 260 self.__get_table_column_list(self.__tableName)) 261 sql = "SELECT %s FROM %s" % (FIELDS, self.__tableName) 262 sql = QueryUtil.query_sql(sql, filters) 263 results = self.execute_query(sql) 264 return self.__parse_results(results) 265 266 def count(self, tableName=None): 267 '''统计记录数''' 268 self.__check_tableName(tableName) 269 sql = "SELECT count(*) FROM %s" % (self.__tableName) 270 result = self.execute_query(sql, single=True) 271 return result[0] 272 273 def select_page(self, tableName=None, page=None, filters={}): 274 '''分页查询 275 - @tableName 表名 276 - @return 返回字典集合,集合中以表字段作为 key,字段值作为 value 277 ''' 278 self.__check_tableName(tableName) 279 if page is None: 280 page = Page() 281 filters["page"] = page 282 FIELDS = stitch_sequence( 283 self.__get_table_column_list(self.__tableName)) 284 sql = "SELECT %s FROM %s" % (FIELDS, self.__tableName) 285 sql = QueryUtil.query_sql(sql, filters) 286 result_tuple = self.execute_query(sql) 287 return self.__parse_results(result_tuple) 288 289 def save(self, tableName=None, obj=dict()): 290 '''保存方法 291 - @param tableName 表名 292 - @param obj 对象 293 - @return 影响行数 294 ''' 295 self.__check_tableName(tableName) 296 pk = self.__getpk(self.__tableName) 297 if pk not in obj.keys(): 298 obj[pk] = None 299 FIELDS = stitch_sequence(obj.keys()) 300 VALUES = [] 301 for k, v in obj.items(): 302 if self._table_dict[self.__tableName][k]["COLUMN_KEY"] != "PKI": 303 v = "null" if v is None else '"%s"' % v 304 # if v is None: 305 # v = "null" 306 # else: 307 # v = '"%s"' % v 308 VALUES.append(v) 309 VALUES = stitch_sequence(VALUES, isField=False) 310 sql = 'INSERT INTO `%s` (%s) VALUES(%s)' % ( 311 self.__tableName, FIELDS, VALUES) 312 return self.execute_update(sql) 313 314 def 
update_by_primarykey(self, tableName=None, obj={}): 315 '''更新方法(根据主键更新,包含空值) 316 - @param tableName 表名 317 - @param obj 对象 318 - @return 影响行数 319 ''' 320 self.__check_tableName(tableName) 321 pk = self.__getpk(self.__tableName) 322 if pk not in obj.keys() or obj.get(pk) is None: 323 raise ValueError("Parameter [obj.%s] is None." % pk) 324 l = [] 325 where = "WHERE " 326 for k, v in obj.items(): 327 if self._table_dict[tableName][k]["COLUMN_KEY"] != "PRI": 328 if v is None: 329 if self._table_dict[tableName][k]["IS_NULLABLE"] == "YES": 330 l.append("%s=null" % (k)) 331 else: 332 l.append("%s=''" % (k)) 333 else: 334 l.append("%s='%s'" % (k, v)) 335 else: 336 where += "%s='%s'" % (k, v) 337 sql = "UPDATE `%s` SET %s %s" % ( 338 self.__tableName, stitch_sequence(l, isField=False), where) 339 return self.execute_update(sql) 340 341 def update_by_primarikey_selective(self, tableName=None, obj={}): 342 '''更新方法(根据主键更新,不包含空值) 343 - @param tableName 表名 344 - @param obj 对象 345 - @return 影响行数 346 ''' 347 self.__check_tableName(tableName) 348 pk = self.__getpk(self.__tableName) 349 if pk not in obj.keys() or obj.get(pk) is None: 350 raise ValueError("Parameter [obj.%s] is None." 
% pk) 351 where = "WHERE " 352 l = [] 353 for k, v in obj.items(): 354 if self._table_dict[self.__tableName][k]["COLUMN_KEY"] != "PRI": 355 if v is None: 356 continue 357 l.append("%s='%s'" % (k, v)) 358 else: 359 where += "%s='%s'" % (k, v) 360 sql = "UPDATE `%s` SET %s %s" % ( 361 self.__tableName, stitch_sequence(l, isField=False), where) 362 return self.execute_update(sql) 363 364 def remove_by_primarykey(self, tableName=None, value=None): 365 '''删除方法(根据主键删除) 366 - @param tableName 表名 367 - @param valuej 主键值 368 - @return 影响行数 369 ''' 370 self.__check_tableName(tableName) 371 if value is None: 372 raise ValueError("Parameter [value] can not be None.") 373 pk = self.__getpk(self.__tableName) 374 sql = "DELETE FROM `%s` WHERE `%s`='%s'" % ( 375 self.__tableName, pk, value) 376 return self.execute_update(sql) 377 378 379 class Page(object): 380 '分页对象' 381 382 def __init__(self, pageNum=1, pageSize=10, count=False): 383 ''' 384 Page 初始化方法 385 - @param pageNum 页码,默认为1 386 - @param pageSize 页面大小, 默认为10 387 - @param count 是否包含 count 查询 388 ''' 389 self.pageNum = pageNum if pageNum > 0 else 1 # 当前页数 390 self.pageSize = pageSize if pageSize > 0 else 10 # 分页大小 391 self.total = 0 # 总记录数 392 self.pages = 1 # 总页数 393 self.startRow = (self.pageNum - 1) * 394 self.pageSize # 起始行(用于 mysql 分页查询) 395 self.endRow = self.startRow + self.pageSize # 结束行(用于 mysql 分页查询) 396 397 398 class QueryUtil(object): 399 ''' 400 SQL 语句拼接工具类: 401 - 主方法:querySql(sql, filters) 402 - 参数说明: 403 - @param sql:需要拼接的 SQL 语句 404 - @param filters:拼接 SQL 的过滤条件 405 filters 过滤条件说明: 406 - 支持拼接条件如下: 407 - 1、等于(如:{"id": 2}, 拼接后为:id=2) 408 - 2、不等于(如:{"_ne_id": 2}, 拼接后为:id != 2) 409 - 3、小于(如:{"_lt_id": 2},拼接后为:id < 2) 410 - 4、小于等于(如:{"_le_id": 2},拼接后为:id <= 2) 411 - 5、大于(如:{"_gt_id": },拼接后为:id > 2) 412 - 6、大于等于(如:{"_ge_id": },拼接后为:id >=2) 413 - 7、in(如:{"_in_id": "1,2,3"},拼接后为:id IN(1,2,3)) 414 - 8、not in(如:{"_nein_id": "4,5,6"},拼接后为:id NOT IN(4,5,6)) 415 - 9、like(如:{"_like_name": },拼接后为:name LIKE '%zhang%') 416 - 
10、like(如:{"_llike_name": },拼接后为:name LIKE '%zhang') 417 - 11、like(如:{"_rlike_name": },拼接后为:name LIKE 'zhang%') 418 - 12、分组(如:{"groupby": "status"},拼接后为:GROUP BY status) 419 - 13、排序(如:{"orderby": "createDate"},拼接后为:ORDER BY createDate) 420 ''' 421 422 NE = "_ne_" # 拼接不等于 423 LT = "_lt_" # 拼接小于 424 LE = "_le_" # 拼接小于等于 425 GT = "_gt_" # 拼接大于 426 GE = "_ge_" # 拼接大于等于 427 IN = "_in_" # 拼接 in 428 NE_IN = "_nein_" # 拼接 not in 429 LIKE = "_like_" # 拼接 like 430 LEFT_LIKE = "_llike_" # 拼接左 like 431 RIGHT_LIKE = "_rlike_" # 拼接右 like 432 GROUP = "groupby" # 拼接分组 433 ORDER = "orderby" # 拼接排序 434 ORDER_TYPE = "ordertype" # 排序类型:asc(升序)、desc(降序) 435 436 @staticmethod 437 def __filter_params(filters): 438 '''过滤参数条件''' 439 s = " WHERE 1=1" 440 for k, v in filters.items(): 441 if k.startswith(QueryUtil.IN): # 拼接 in 442 s += " AND `%s` IN (" % (k[len(QueryUtil.IN):]) 443 values = v.split(",") 444 for value in values: 445 s += " %s," % value 446 s = s[0:len(s) - 1] + ") " 447 elif k.startswith(QueryUtil.NE_IN): # 拼接 not in 448 s += " AND `%s` NOT IN (" % (k[len(QueryUtil.NE_IN):]) 449 values = v.split(",") 450 for value in values: 451 s += " %s," % value 452 s = s[0:len(s) - 1] + ") " 453 elif k.startswith(QueryUtil.LIKE): # 拼接 like 454 s += " AND `%s` LIKE '%%%s%%' " % (k[len(QueryUtil.LIKE):], v) 455 elif k.startswith(QueryUtil.LEFT_LIKE): # 拼接左 like 456 s += " AND `%s` LIKE '%%%s' " % ( 457 k[len(QueryUtil.LEFT_LIKE):], v) 458 elif k.startswith(QueryUtil.RIGHT_LIKE): # 拼接右 like 459 s += " AND `%s` LIKE '%s%%' " % ( 460 k[len(QueryUtil.RIGHT_LIKE):], v) 461 elif k.startswith(QueryUtil.NE): # 拼接不等于 462 s += " AND `%s` != '%s' " % (k[len(QueryUtil.NE):], v) 463 elif k.startswith(QueryUtil.LT): # 拼接小于 464 s += " AND `%s` < '%s' " % (k[len(QueryUtil.LT):], v) 465 elif k.startswith(QueryUtil.LE): # 拼接小于等于 466 s += " AND `%s` <= '%s' " % (k[len(QueryUtil.LE):], v) 467 elif k.startswith(QueryUtil.GT): # 拼接大于 468 s += " AND `%s` > '%s' " % (k[len(QueryUtil.GT):], v) 469 elif 
k.startswith(QueryUtil.GE): # 拼接大于等于 470 s += " AND `%s` >= '%s' " % (k[len(QueryUtil.GE):], v) 471 else: # 拼接等于 472 if isinstance(v, str): 473 s += " AND `%s`='%s' " % (k, v) 474 elif isinstance(v, int): 475 s += " AND `%s`=%d " % (k, v) 476 return s 477 478 @staticmethod 479 def __filter_group(filters): 480 '''过滤分组''' 481 group = filters.pop(QueryUtil.GROUP) 482 s = " GROUP BY %s" % (group) 483 return s 484 485 @staticmethod 486 def __filter_order(filters): 487 '''过滤排序''' 488 order = filters.pop(QueryUtil.ORDER) 489 type = filters.pop(QueryUtil.ORDER_TYPE, "asc") 490 s = " ORDER BY `%s` %s" % (order, type) 491 return s 492 493 @staticmethod 494 def __filter_page(filters): 495 '''过滤 page 对象''' 496 page = filters.pop("page") 497 return " LIMIT %d,%d" % (page.startRow, page.endRow) 498 499 @staticmethod 500 def query_sql(sql=None, filters=dict()): 501 '''拼接 SQL 查询条件 502 - @param sql SQL 语句 503 - @param filters 过滤条件 504 - @return 返回拼接 SQL 505 ''' 506 if not filters: 507 return sql 508 else: 509 if not isinstance(filters, dict): 510 raise Exception("Parameter [filters] must be dict.") 511 group = None 512 order = None 513 page = None 514 if filters.get("groupby") != None: 515 group = QueryUtil.__filter_group(filters) 516 if filters.get("orderby") != None: 517 order = QueryUtil.__filter_order(filters) 518 if filters.get("page") != None: 519 page = QueryUtil.__filter_page(filters) 520 sql += QueryUtil.__filter_params(filters) 521 if group: 522 sql += group 523 if order: 524 sql += order 525 if page: 526 sql += page 527 return sql
代码下载地址(GitHub): https://github.com/ruancheng77/baseDao
代码中已经给出了几个具体示例,大家可以参考使用。
如果有兴趣一起学习、讨论 Python,欢迎加入 QQ 群:626787819;有任何意见或建议,也可以发邮件到我的邮箱:410093793@qq.com。