zoukankan      html  css  js  c++  java
  • Python3

      之前一篇Python 封装DBUtils 和pymysql 中写过一个basedao.py,最近几天又重新整理了下思绪,优化了下 basedao.py,目前支持的方法还不多,后续会进行改进、添加。

      主要功能:

        1.查询单个对象:

          所需参数:表名,过滤条件

        2.查询多个对象:
          所需参数:表名,过滤条件

        3.按主键查询:
          所需参数:表名,值

        4.分页查询:
          所需参数:表名,页码,每页记录数,过滤条件

      调用方法锁获取的对象都是以字典形式存储,例如:查询user表(字段有id,name,age)里的id=1的数据返回的对象为user = {"id":1,"name","zhangsan","age":18},我们可以通过user.get("id")来获取id值,非常方便,不用定义什么类对象来表示。如果查询的是多个,那么多个字典对象将会存放在一个列表里返回。

      具体代码如下:  

      1 import json, os, sys, time, pymysql, pprint
      2 
      3 from DBUtils import PooledDB
      4 
      5 def print(*args):
      6     pprint.pprint(args)
      7 
      8 def get_time():
      9     '获取时间'
     10     return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
     11 
     12 def stitch_sequence(seq=None, suf=None):
     13     '如果参数("suf")不为空,则根据特殊的suf拼接列表元素,返回一个字符串'
     14     if seq is None: raise Exception("Parameter seq is None");
     15     if suf is None: suf = ","
     16     r = str()
     17     for s in seq:
     18         r += s + suf
     19     return r[:-len(suf)]
     20 
     21 class BaseDao(object):
     22     """
     23     简便的数据库操作基类
     24     """
     25     def __init__(self, creator=pymysql, host="localhost",port=3306, user=None, password="",
     26                     database=None, charset="utf8"):
     27         if host is None: raise Exception("Parameter [host] is None.")
     28         if port is None: raise Exception("Parameter [port] is None.")
     29         if user is None: raise Exception("Parameter [user] is None.")
     30         if password is None: raise Exception("Parameter [password] is None.")
     31         if database is None: raise Exception("Parameter [database] is None.")
     32         # 数据库连接配置
     33         self.__config = dict({
     34             "creator" : creator, "charset":charset, "host":host, "port":port, 
     35             "user":user, "password":password, "database":database
     36         })
     37         self.__database = self.__config["database"]     # 用于存储查询数据库
     38         self.__tableName = None                         # 用于临时存储当前查询表名
     39         # 初始化
     40         self.__init_connect()                           # 初始化连接
     41         self.__init_params()                            # 初始化参数
     42         print(get_time(), self.__database, "数据库初始化成功。")
     43         
     44     def __del__(self):
     45         '重写类被清除时调用的方法'
     46         if self.__cursor: self.__cursor.close()
     47         if self.__conn: self.__conn.close()
     48         print(get_time(), self.__database, "连接关闭")
     49 
     50     def __init_connect(self):
     51         self.__conn = PooledDB.connect(**self.__config)
     52         self.__cursor = self.__conn.cursor()
     53 
     54     def __init_params(self):
     55         '初始化参数'
     56         self.__init_table_dict()
     57         self.__init__table_column_dict_list()
     58 
     59     def __init__information_schema_columns(self):
     60         "查询 information_schema.`COLUMNS` 中的列"
     61         sql =   """ SELECT COLUMN_NAME FROM information_schema.`COLUMNS`
     62                     WHERE TABLE_SCHEMA='information_schema' AND TABLE_NAME='COLUMNS'
     63                 """
     64         result_tuple = self.__exec_query(sql)
     65         column_list = [r[0] for r in result_tuple]
     66         return column_list
     67 
     68     def __init_table_dict(self):
     69         "查询配置数据库中改的所有表"
     70         schema_column_list = self.__init__information_schema_columns()
     71         stitch_str = stitch_sequence(schema_column_list)
     72         sql1 =  """ SELECT TABLE_NAME FROM information_schema.`TABLES`
     73                     WHERE TABLE_SCHEMA='%s'
     74                 """ %(self.__database)
     75         table_tuple = self.__exec_query(sql1)
     76         self.__table_dict = {t[0]:{} for t in table_tuple}
     77         for table in self.__table_dict.keys():
     78             sql =   """ SELECT %s FROM information_schema.`COLUMNS`
     79                         WHERE TABLE_SCHEMA='%s' AND TABLE_NAME='%s'
     80                     """ %(stitch_str, self.__database, table)
     81             column_tuple = self.__exec_query(sql)
     82             column_dict = {}
     83             for vs in column_tuple:
     84                 d = {k:v for k,v in zip(schema_column_list, vs)}
     85                 column_dict[d["COLUMN_NAME"]] = d
     86             self.__table_dict[table] = column_dict
     87 
     88     def __init__table_column_dict_list(self):
     89         self.__table_column_dict_list = {}
     90         for table, column_dict in self.__table_dict.items():
     91             column_list = [column for column in column_dict.keys()]
     92             self.__table_column_dict_list[table] = column_list
     93         
     94     def __exec_query(self, sql, single=False):
     95         '''
     96         执行查询方法
     97         - @sql    查询 sql
     98         - @single 是否查询单个结果集,默认False
     99         '''
    100         try:
    101             self.__cursor.execute(sql)
    102             print(get_time(), "SQL[%s]"%sql)
    103             if single:
    104                 result_tuple = self.__cursor.fetchone()
    105             else:
    106                 result_tuple = self.__cursor.fetchall()
    107             return result_tuple
    108         except Exception as e:
    109             print(e)
    110 
    111     def __exec_update(self, sql):
    112         try:
    113             # 获取数据库游标
    114             result = self.__cursor.execute(sql)
    115             print(get_time(), "SQL[%s]"%sql)
    116             self.__conn.commit()
    117             return result
    118         except Exception as e:
    119             print(e)
    120             self.__conn.rollback()
    121 
    122     def __parse_result(self, result):
    123         '用于解析单个查询结果,返回字典对象'
    124         if result is None: return None
    125         obj = {k:v for k,v in zip(self.__column_list, result)}
    126         return obj
    127 
    128     def __parse_results(self, results):
    129         '用于解析多个查询结果,返回字典列表对象'
    130         if results is None: return None
    131         objs = [self.__parse_result(result) for result in results]
    132         return objs
    133 
    134     def __getpk(self, tableName):
    135         if self.__table_dict.get(tableName) is None: raise Exception(tableName, "is not exist.")
    136         for column, column_dict in self.__table_dict[tableName].items():
    137             if column_dict["COLUMN_KEY"] == "PRI": return column
    138 
    139     def __get_table_column_list(self, tableName=None):
    140         '查询表的字段列表, 将查询出来的字段列表存入 __fields 中'
    141         return self.__table_column_dict_list[tableName]
    142 
    143     def __query_util(self, filters=None):
    144         """
    145         SQL 语句拼接方法
    146         @filters 过滤条件
    147         """
    148         sql = r'SELECT #{FIELDS} FROM #{TABLE_NAME} WHERE 1=1 #{FILTERS}'
    149         # 拼接查询表
    150         sql = sql.replace("#{TABLE_NAME}", self.__tableName)
    151         # 拼接查询字段
    152         FIELDS = stitch_sequence(self.__get_table_column_list(self.__tableName))
    153         sql = sql.replace("#{FIELDS}", FIELDS)
    154         # 拼接查询条件(待优化)
    155         if filters is None:
    156             sql = sql.replace("#{FILTERS}", "")
    157         else:
    158             FILTERS =  ""
    159             if not isinstance(filters, dict):
    160                 raise Exception("Parameter [filters] must be dict type. ")
    161             isPage = False
    162             if filters.get("_limit_"): isPage = True
    163             if isPage: beginindex, limit = filters.pop("_limit_")
    164             for k, v in filters.items():
    165                 if k.startswith("_in_"):                # 拼接 in
    166                     FILTERS += "AND %s IN (" %(k[4:])
    167                     values = v.split(",")
    168                     for value in values:
    169                         FILTERS += "%s,"%value
    170                     FILTERS = FILTERS[0:len(FILTERS)-1] + ") "
    171                 elif k.startswith("_nein_"):            # 拼接 not in
    172                     FILTERS += "AND %s NOT IN (" %(k[4:])
    173                     values = v.split(",")
    174                     for value in values:
    175                         FILTERS += "%s,"%value
    176                     FILTERS = FILTERS[0:len(FILTERS)-1] + ") "
    177                 elif k.startswith("_like_"):            # 拼接 like
    178                     FILTERS += "AND %s like '%%%s%%' " %(k[6:], v)
    179                 elif k.startswith("_ne_"):              # 拼接不等于
    180                     FILTERS += "AND %s != '%s' " %(k[4:], v)
    181                 elif k.startswith("_lt_"):              # 拼接小于
    182                     FILTERS += "AND %s < '%s' " %(k[4:], v)
    183                 elif k.startswith("_le_"):              # 拼接小于等于
    184                     FILTERS += "AND %s <= '%s' " %(k[4:], v)
    185                 elif k.startswith("_gt_"):              # 拼接大于
    186                     FILTERS += "AND %s > '%s' " %(k[4:], v)
    187                 elif k.startswith("_ge_"):              # 拼接大于等于
    188                     FILTERS += "AND %s >= '%s' " %(k[4:], v)
    189                 else:                # 拼接等于
    190                     FILTERS += "AND %s='%s' "%(k, v)
    191             sql = sql.replace("#{FILTERS}", FILTERS)
    192             if isPage: sql += "LIMIT %d,%d"%(beginindex, limit)
    193         return sql
    194 
    195     def __check_params(self, tableName):
    196         '''
    197         检查参数
    198         '''
    199         if tableName is None and self.__tableName is None:
    200             raise Exception("Parameter [tableName] is None.")
    201         elif self.__tableName is None or self.__tableName != tableName:
    202             self.__tableName = tableName
    203             self.__column_list = self.__table_column_dict_list[self.__tableName]
    204 
    205     def select_one(self, tableName=None, filters={}):
    206         '''
    207         查询单个对象
    208         @tableName 表名
    209         @filters 过滤条件
    210         @return 返回字典集合,集合中以表字段作为 key,字段值作为 value
    211         '''
    212         self.__check_params(tableName)
    213         sql = self.__query_util(filters)
    214         result = self.__exec_query(sql, single=True)
    215         return self.__parse_result(result) 
    216 
    217     def select_pk(self, tableName=None, primaryKey=None):
    218         '''
    219         按主键查询
    220         @tableName 表名
    221         @primaryKey 主键值
    222         '''
    223         self.__check_params(tableName)
    224         filters = {}
    225         filters.setdefault(self.__getpk(tableName), primaryKey)
    226         sql = self.__query_util(filters)
    227         result = self.__exec_query(sql, single=True)
    228         return self.__parse_result(result)
    229         
    230     def select_all(self, tableName=None, filters={}):
    231         '''
    232         查询所有
    233         @tableName 表名
    234         @filters 过滤条件
    235         @return 返回字典集合,集合中以表字段作为 key,字段值作为 value
    236         '''
    237         self.__check_params(tableName)
    238         sql = self.__query_util(filters)
    239         results = self.__exec_query(sql)
    240         return self.__parse_results(results)
    241 
    242     def count(self, tableName=None):
    243         '''
    244         统计记录数
    245         '''
    246         self.__check_params(tableName)
    247         sql = "SELECT count(*) FROM %s"%(self.__tableName)
    248         result = self.__exec_query(sql, single=True)
    249         return result[0]
    250 
    251     def select_page(self, tableName=None, pageNum=1, limit=10, filters={}):
    252         '''
    253         分页查询
    254         @tableName 表名
    255         @return 返回字典集合,集合中以表字段作为 key,字段值作为 value
    256         '''
    257         self.__check_params(tableName)
    258         totalCount = self.count(tableName)
    259         if totalCount / limit == 0 :
    260             totalPage = totalCount / limit
    261         else:
    262             totalPage = totalCount // limit + 1
    263         if pageNum > totalPage:
    264             print("最大页数为%d"%totalPage)
    265             pageNum = totalPage
    266         elif pageNum < 1:
    267             print("页数不能小于1")
    268             pageNum = 1
    269         beginindex = (pageNum-1) * limit
    270         filters.setdefault("_limit_", (beginindex, limit))
    271         sql = self.__query_util(filters)
    272         result_tuple = self.__exec_query(sql)
    273         return self.__parse_results(result_tuple)
    274 
    275 if __name__ == "__main__":
    276     config = {
    277         # "creator": pymysql,
    278         # "host" : "127.0.0.1", 
    279         "user" : "root", 
    280         "password" : "root",
    281         "database" : "test", 
    282         # "port" : 3306,
    283         # "charset" : 'utf8'
    284     }
    285     base = BaseDao(**config)
    286     ########################################################################
    287     user = base.select_one("user")
    288     print(user)
    289     ########################################################################
    290     # users = base.select_all("user")
    291     # print(users)
    292     ########################################################################
    293     filter1 = {
    294         "status":1,
    295         "_in_id":"1,2,3,4,5",
    296         "_like_name":"zhang",
    297         "_ne_name":"wangwu"
    298     }
    299     user_filters = base.select_all("user", filter1)
    300     print(user_filters)
    301     ########################################################################
    302     role = base.select_one("role")
    303     print(role)
    304     ########################################################################
    305     user_pk = base.select_pk("user", 2)
    306     print(user_pk)
    307     ########################################################################
    308     user_limit = base.select_page("user", 1, 10)
    309     print(user_limit)
    310     ########################################################################
    View Code

      更新:2017-08-25

      1 import json, os, sys, time, pymysql, pprint, logging
      2 
      3 logging.basicConfig(
      4     level=logging.DEBUG, 
      5     format='%(asctime)s [%(levelname)s] %(message)s',
      6     datefmt='%a, %d %b %Y %H:%M:%S')
      7 
      8 from DBUtils import PooledDB
      9 
     10 def print(*args):
     11     pprint.pprint(args)
     12 
     13 def get_time():
     14     '获取时间'
     15     return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
     16 
     17 def stitch_sequence(seq=None, suf=None):
     18     '如果参数("suf")不为空,则根据特殊的suf拼接列表元素,返回一个字符串。默认使用 ","。'
     19     if seq is None: raise Exception("Parameter seq is None");
     20     if suf is None: suf = ","
     21     r = str()
     22     for s in seq:
     23         r += s + suf
     24     return r[:-len(suf)]
     25 
     26 class BaseDao(object):
     27     """
     28     简便的数据库操作基类,该类所操作的表必须有主键
     29     初始化参数如下:
     30     - creator: 创建连接对象(默认: pymysql)
     31     - host: 连接数据库主机地址(默认: localhost)
     32     - port: 连接数据库端口(默认: 3306)
     33     - user: 连接数据库用户名(默认: None), 如果为空,则会抛异常
     34     - password: 连接数据库密码(默认: None), 如果为空,则会抛异常
     35     - database: 连接数据库(默认: None), 如果为空,则会抛异常
     36     - chatset: 编码(默认: utf8)
     37     - tableName: 初始化 BaseDao 对象的数据库表名(默认: None), 如果为空,
     38     则会初始化该数据库下所有表的信息, 如果不为空,则只初始化传入的 tableName 的表
     39     """
     40     def __init__(self, creator=pymysql, host="localhost",port=3306, user=None, password=None,
     41                     database=None, charset="utf8", tableName=None):
     42         if host is None: raise Exception("Parameter [host] is None.")
     43         if port is None: raise Exception("Parameter [port] is None.")
     44         if user is None: raise Exception("Parameter [user] is None.")
     45         if password is None: raise Exception("Parameter [password] is None.")
     46         if database is None: raise Exception("Parameter [database] is None.")
     47         if tableName is None: print("WARNING >>> Parameter [tableName] is None. All tables will be initialized.")
     48         logging.debug("[%s] 数据库初始化>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>开始"%(database))
     49         start = time.time()
     50         # 数据库连接配置
     51         self.__config = dict({
     52             "creator" : creator, "charset":charset, "host":host, "port":port, 
     53             "user":user, "password":password, "database":database
     54         })
     55         self.__database = database                      # 用于存储查询数据库
     56         self.__tableName = tableName                    # 用于临时存储当前查询表名
     57         # 初始化
     58         self.__init_connect()                           # 初始化连接
     59         self.__init_params()                            # 初始化参数
     60         end = time.time()
     61         logging.debug("[%s] 数据库初始化>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>结束"%(database))
     62         logging.debug("[%s] 数据库初始化成功。耗时:%d ms。"%(database, (end-start)))
     63         
     64     def __del__(self):
     65         '重写类被清除时调用的方法'
     66         if self.__cursor: self.__cursor.close()
     67         if self.__conn: self.__conn.close()
     68         logging.debug("[%s] 连接关闭。"%(self.__database))
     69 
     70     def __init_connect(self):
     71         '初始化连接'
     72         self.__conn = PooledDB.connect(**self.__config)
     73         self.__cursor = self.__conn.cursor()
     74 
     75     def __init_params(self):
     76         '初始化参数'
     77         self.__table_dict = {}
     78         self.__information_schema_columns = []
     79         self.__table_column_dict_list = {}
     80         if self.__tableName is None:
     81             self.__init_table_dict_list()
     82             self.__init__table_column_dict_list()
     83         else:
     84             self.__init_table_dict(self.__tableName)
     85             self.__init__table_column_dict_list()
     86             self.__column_list = self.__table_column_dict_list[self.__tableName]
     87 
     88     def __init__information_schema_columns(self):
     89         "查询 information_schema.`COLUMNS` 中的列"
     90         sql =   """ SELECT COLUMN_NAME 
     91                     FROM information_schema.`COLUMNS`
     92                     WHERE TABLE_SCHEMA='information_schema' AND TABLE_NAME='COLUMNS'
     93                 """
     94         result_tuple = self.__exec_query(sql)
     95         column_list = [r[0] for r in result_tuple]
     96         self.__information_schema_columns = column_list
     97 
     98     def __init_table_dict(self, tableName):
     99         '初始化表'
    100         if not self.__information_schema_columns:
    101             self.__init__information_schema_columns()
    102         stitch_str = stitch_sequence(self.__information_schema_columns)
    103         sql =   """ SELECT %s FROM information_schema.`COLUMNS`
    104                     WHERE TABLE_SCHEMA='%s' AND TABLE_NAME='%s'
    105                 """ %(stitch_str, self.__database, tableName)
    106         column_tuple = self.__exec_query(sql)
    107         column_dict = {}
    108         for vs in column_tuple:
    109             d = {k:v for k,v in zip(self.__information_schema_columns, vs)}
    110             column_dict[d["COLUMN_NAME"]] = d
    111         self.__table_dict[tableName] = column_dict
    112 
    113     def __init_table_dict_list(self):
    114         "初始化表字典对象"
    115         if not self.__information_schema_columns:
    116             self.__init__information_schema_columns()
    117         stitch_str = stitch_sequence(self.__information_schema_columns)
    118         sql1 =  """
    119                 SELECT TABLE_NAME FROM information_schema.`TABLES` WHERE TABLE_SCHEMA='%s'
    120                 """ %(self.__database)
    121         table_tuple = self.__exec_query(sql1)
    122         self.__table_dict = {t[0]:{} for t in table_tuple}
    123         for table in table_tuple:
    124             self.__init_table_dict(table[0])
    125 
    126     def __init__table_column_dict_list(self):
    127         '''初始化表字段字典列表'''
    128         for table, column_dict in self.__table_dict.items():
    129             column_list = [column for column in column_dict.keys()]
    130             self.__table_column_dict_list[table] = column_list
    131         
    132     def __exec_query(self, sql, single=False):
    133         '''执行查询 SQL 语句
    134         - @sql    查询 sql
    135         - @single 是否查询单个结果集,默认False
    136         '''
    137         try:
    138             logging.debug("[%s] SQL >>> [%s]"%(self.__database, sql))
    139             self.__cursor.execute(sql)
    140             if single:
    141                 result_tuple = self.__cursor.fetchone()
    142             else:
    143                 result_tuple = self.__cursor.fetchall()
    144             return result_tuple
    145         except Exception as e:
    146             print(e)
    147 
    148     def __exec_update(self, sql):
    149         '''执行更新 SQL 语句'''
    150         try:
    151             # 获取数据库游标
    152             logging.debug("[%s] SQL >>> [%s]"%(self.__database, sql))
    153             result = self.__cursor.execute(sql)
    154             self.__conn.commit()
    155             return result
    156         except Exception as e:
    157             print(e)
    158             self.__conn.rollback()
    159 
    160     def __parse_result(self, result):
    161         '用于解析单个查询结果,返回字典对象'
    162         if result is None: return None
    163         obj = {k:v for k,v in zip(self.__column_list, result)}
    164         return obj
    165 
    166     def __parse_results(self, results):
    167         '用于解析多个查询结果,返回字典列表对象'
    168         if results is None: return None
    169         objs = [self.__parse_result(result) for result in results]
    170         return objs
    171 
    172     def __getpk(self, tableName):
    173         '获取表对应的主键字段'
    174         if self.__table_dict.get(tableName) is None: raise Exception(tableName, "is not exist.")
    175         for column, column_dict in self.__table_dict[tableName].items():
    176             if column_dict["COLUMN_KEY"] == "PRI": return column
    177 
    178     def __get_table_column_list(self, tableName=None):
    179         '查询表的字段列表, 将查询出来的字段列表存入 __fields 中'
    180         return self.__table_column_dict_list[tableName]
    181 
    182     def __check_tableName(self, tableName):
    183         '''验证 tableName 参数'''
    184         if tableName is None:
    185             if self.__tableName is None:
    186                 raise Exception("Parameter [tableName] is None.")
    187         else:
    188             if self.__tableName != tableName:
    189                 self.__tableName = tableName
    190                 self.__column_list = self.__table_column_dict_list[self.__tableName]
    191 
    192     def select_one(self, tableName=None, filters={}):
    193         '''查询单个对象
    194         - @tableName 表名
    195         - @filters 过滤条件
    196         - @return 返回字典集合,集合中以表字段作为 key,字段值作为 value
    197         '''
    198         self.__check_tableName(tableName)
    199         FIELDS = stitch_sequence(self.__get_table_column_list(self.__tableName))
    200         sql = "SELECT %s FROM %s"%(FIELDS ,self.__tableName)
    201         sql = QueryUtil.query_sql(sql, filters)
    202         result = self.__exec_query(sql, single=True)
    203         return self.__parse_result(result) 
    204 
    205     def select_pk(self, tableName=None, primaryKey=None):
    206         '''按主键查询
    207         - @tableName 表名
    208         - @primaryKey 主键值
    209         '''
    210         self.__check_tableName(tableName)
    211         FIELDS = stitch_sequence(self.__get_table_column_list(self.__tableName))
    212         sql = "SELECT %s FROM %s"%(FIELDS, self.__tableName)
    213         sql = QueryUtil.query_sql(sql, {self.__getpk(tableName):primaryKey})
    214         result = self.__exec_query(sql, single=True)
    215         return self.__parse_result(result)
    216         
    217     def select_all(self, tableName=None, filters={}):
    218         '''查询所有
    219         - @tableName 表名
    220         - @filters 过滤条件
    221         - @return 返回字典集合,集合中以表字段作为 key,字段值作为 value
    222         '''
    223         self.__check_tableName(tableName)
    224         FIELDS = stitch_sequence(self.__get_table_column_list(self.__tableName))
    225         sql = "SELECT %s FROM %s"%(FIELDS ,self.__tableName)
    226         sql = QueryUtil.query_sql(sql, filters)
    227         results = self.__exec_query(sql)
    228         return self.__parse_results(results)
    229 
    230     def count(self, tableName=None):
    231         '''统计记录数'''
    232         self.__check_tableName(tableName)
    233         sql = "SELECT count(*) FROM %s"%(self.__tableName)
    234         result = self.__exec_query(sql, single=True)
    235         return result[0]
    236 
    237     def select_page(self, tableName=None, page=None, filters={}):
    238         '''分页查询
    239         - @tableName 表名
    240         - @return 返回字典集合,集合中以表字段作为 key,字段值作为 value
    241         '''
    242         self.__check_tableName(tableName)
    243         if page is None:
    244             page = Page()
    245         filters["page"] = page
    246         FIELDS = stitch_sequence(self.__get_table_column_list(self.__tableName))
    247         sql = "SELECT %s FROM %s"%(FIELDS ,self.__tableName)
    248         sql = QueryUtil.query_sql(sql, filters)
    249         result_tuple = self.__exec_query(sql)
    250         return self.__parse_results(result_tuple)
    251 
    252     def save(self, tableName=None, obj=dict()):
    253         '''保存方法
    254         - @param tableName 表名
    255         - @param obj 对象
    256         - @return 影响行数
    257         '''
    258         self.__check_tableName(tableName)
    259         FIELDS = stitch_sequence(seq=obj.keys())
    260         VALUES = []
    261         for k, v in obj.items():
    262             if self.__table_dict[self.__tableName][k]["COLUMN_KEY"] != "PKI":
    263                 if v is None:
    264                     v = "null"
    265                 else:
    266                     v = '"%s"'%v
    267             VALUES.append(v)
    268         VALUES = stitch_sequence(seq=VALUES)
    269         sql = ' INSERT INTO `%s` (%s) VALUES(%s)'%(self.__tableName, FIELDS, VALUES)
    270         return self.__exec_update(sql)
    271     
    272     def update(self, tableName=None, obj={}):
    273         '''更新方法(根据主键更新,包含空值)
    274         - @param tableName 表名
    275         - @param obj 对象
    276         - @return 影响行数
    277         '''
    278         self.__check_tableName(tableName)
    279         l = []
    280         where = "WHERE "
    281         for k, v in obj.items():
    282             if self.__table_dict[self.__tableName][k]["COLUMN_KEY"] != "PRI":
    283                 if v is None:
    284                     if self.__table_dict[self.__tableName][k]["IS_NULLABLE"] == "YES":
    285                         l.append("%s=null"%(k))
    286                     else:
    287                         l.append("%s=''"%(k))
    288                 else:
    289                     l.append("%s='%s'"%(k, v))
    290             else:
    291                 where += "%s='%s'"%(k, v)
    292         sql = "UPDATE `%s` SET %s %s"%(self.__tableName, stitch_sequence(l), where)
    293         return self.__exec_update(sql)
    294 
    295     def update_selective(self, tableName=None, obj={}):
    296         '''更新方法(根据主键更新,不包含空值)
    297         - @param tableName 表名
    298         - @param obj 对象
    299         - @return 影响行数
    300         '''
    301         self.__check_tableName(tableName)
    302         where = "WHERE "
    303         l = []
    304         for k, v in obj.items():
    305             if self.__table_dict[self.__tableName][k]["COLUMN_KEY"] != "PRI":
    306                 if v is None:
    307                     continue
    308                 l.append("%s='%s'"%(k, v))
    309             else:
    310                 where += "%s='%s'"%(k, v)
    311         sql = "UPDATE `%s` SET %s %s"%(self.__tableName, stitch_sequence(l), where)
    312         return self.__exec_update(sql)
    313     
    314     def remove(self, tableName=None, obj={}):
    315         '''删除方法(根据主键删除)
    316         - @param tableName 表名
    317         - @param obj 对象
    318         - @return 影响行数
    319         '''
    320         self.__check_tableName(tableName)
    321         pk = self.__getpk(self.__tableName)
    322         sql = "DELETE FROM `%s` WHERE %s=%s"%(self.__tableName, pk, obj[pk])
    323         print(sql)
    324         # return self.__exec_update(sql)
    325 
    326 class Page(object):
    327     '分页对象'
    328     def __init__(self, pageNum=1, pageSize=10, count=False):
    329         '''
    330         Page 初始化方法
    331         - @param pageNum 页码,默认为1
    332         - @param pageSize 页面大小, 默认为10
    333         - @param count 是否包含 count 查询
    334         '''
    335         self.pageNum = pageNum if pageNum > 0 else 1            # 当前页数
    336         self.pageSize = pageSize if pageSize > 0 else 10        # 分页大小
    337         self.total = 0                                          # 总记录数
    338         self.pages = 1                                          # 总页数
    339         self.startRow = (self.pageNum - 1 ) * self.pageSize     # 起始行(用于 mysql 分页查询)
    340         self.endRow = self.startRow + self.pageSize             # 结束行(用于 mysql 分页查询)
    341 
    342 class QueryUtil(object):
    343     '''
    344     SQL 语句拼接工具类:
    345     - 主方法:querySql(sql, filters)
    346     - 参数说明:   
    347     - @param sql:需要拼接的 SQL 语句
    348     - @param filters:拼接 SQL 的过滤条件 
    
    349     filters 过滤条件说明:
    350     - 支持拼接条件如下:
    351     - 1、等于(如:{"id": 2}, 拼接后为:id=2)
    352     - 2、不等于(如:{"_ne_id": 2}, 拼接后为:id != 2)
    353     - 3、小于(如:{"_lt_id": 2},拼接后为:id < 2)
    354     - 4、小于等于(如:{"_le_id": 2},拼接后为:id <= 2)
    355     - 5、大于(如:{"_gt_id": },拼接后为:id > 2)
    356     - 6、大于等于(如:{"_ge_id": },拼接后为:id >=2)
    357     - 7、in(如:{"_in_id": "1,2,3"},拼接后为:id IN(1,2,3))
    358     - 8、not in(如:{"_nein_id": "4,5,6"},拼接后为:id NOT IN(4,5,6))
    359     - 9、like(如:{"_like_name": },拼接后为:name LIKE '%zhang%')
    360     - 10、like(如:{"_llike_name": },拼接后为:name LIKE '%zhang')
    361     - 11、like(如:{"_rlike_name": },拼接后为:name LIKE 'zhang%')
    362     - 12、分组(如:{"groupby": "status"},拼接后为:GROUP BY status)
    363     - 13、排序(如:{"orderby": "createDate"},拼接后为:ORDER BY createDate)
    364     '''
    365     
    366     NE = "_ne_"                 # 拼接不等于
    367     LT = "_lt_"                 # 拼接小于 
    368     LE = "_le_"                 # 拼接小于等于
    369     GT = "_gt_"                 # 拼接大于
    370     GE = "_ge_"                 # 拼接大于等于 
    371     IN = "_in_"                 # 拼接 in
    372     NE_IN = "_nein_"            # 拼接 not in
    373     LIKE = "_like_"             # 拼接 like
    374     LEFT_LIKE = "_llike_"       # 拼接左 like
    375     RIGHT_LIKE = "_rlike_"      # 拼接右 like
    376     GROUP = "groupby"           # 拼接分组
    377     ORDER = "orderby"           # 拼接排序
    378     ORDER_TYPE = "ordertype"    # 排序类型:asc(升序)、desc(降序)
    379 
    380     @staticmethod
    381     def __filter_params(filters):
    382         '''过滤参数条件'''
    383         s = " WHERE 1=1"
    384         for k, v in filters.items():
    385             if k.startswith(QueryUtil.IN):                  # 拼接 in
    386                 s += " AND %s IN (" %(k[4:])
    387                 values = v.split(",")
    388                 for value in values:
    389                     s += " %s,"%value
    390                 s = s[0:len(s)-1] + ") "
    391             elif k.startswith(QueryUtil.NE_IN):             # 拼接 not in
    392                 s += " AND %s NOT IN (" %(k[4:])
    393                 values = v.split(",")
    394                 for value in values:
    395                     s += " %s,"%value
    396                 s = s[0:len(s)-1] + ") "
    397             elif k.startswith(QueryUtil.LIKE):              # 拼接 like
    398                 s += " AND %s LIKE '%%%s%%' " %(k[len(QueryUtil.LIKE):], v)
    399             elif k.startswith(QueryUtil.LEFT_LIKE):         # 拼接左 like
    400                 s += " AND %s LIKE '%%%s' " %(k[len(QueryUtil.LEFT_LIKE):], v)
    401             elif k.startswith(QueryUtil.RIGHT_LIKE):        # 拼接右 like
    402                 s += " AND %s LIKE '%s%%' " %(k[len(QueryUtil.RIGHT_LIKE):], v)
    403             elif k.startswith(QueryUtil.NE):                # 拼接不等于
    404                 s += " AND %s != '%s' " %(k[4:], v)
    405             elif k.startswith(QueryUtil.LT):                # 拼接小于
    406                 s += " AND %s < '%s' " %(k[4:], v)
    407             elif k.startswith(QueryUtil.LE):                # 拼接小于等于
    408                 s += " AND %s <= '%s' " %(k[4:], v)
    409             elif k.startswith(QueryUtil.GT):                # 拼接大于
    410                 s += " AND %s > '%s' " %(k[4:], v)
    411             elif k.startswith(QueryUtil.GE):                # 拼接大于等于
    412                 s += " AND %s >= '%s' " %(k[4:], v)
    413             else:                                           # 拼接等于
    414                 if isinstance(v, str):
    415                     s += " AND %s='%s' "%(k, v)
    416                 elif isinstance(v, int):
    417                     s += " AND %s=%d "%(k, v)
    418         return s
    419 
    420     @staticmethod
    421     def __filter_group(filters):
    422         '''过滤分组'''
    423         group = filters.pop(QueryUtil.GROUP)
    424         s = " GROUP BY %s"%(group)
    425         return s
    426 
    427     @staticmethod
    428     def __filter_order(filters):
    429         '''过滤排序'''
    430         order = filters.pop(QueryUtil.ORDER)
    431         type = filters.pop(QueryUtil.ORDER_TYPE)
    432         s = " ORDER BY `%s` %s"%(order, type)
    433         return s
    434 
    435     @staticmethod
    436     def __filter_page(filters):
    437         '''过滤 page 对象'''
    438         page = filters.pop("page")
    439         return " LIMIT %d,%d"%(page.startRow, page.endRow)
    440         
    441     @staticmethod
    442     def query_sql(sql=None, filters=dict()):
    443         '''拼接 SQL 查询条件
    444         - @param sql SQL 语句
    445         - @param filters 过滤条件
    446         - @return 返回拼接 SQL
    447         '''
    448         if not filters:
    449             return sql
    450         else:
    451             if not isinstance(filters, dict):
    452                 raise Exception("Parameter [filters] must be dict.")
    453             group = None
    454             order = None
    455             page = None
    456             if filters.get("groupby") != None:
    457                 group = QueryUtil.__filter_group(filters)
    458             if filters.get("orderby") != None:
    459                 order = QueryUtil.__filter_order(filters)
    460             if filters.get("page") != None:
    461                 page = QueryUtil.__filter_page(filters)
    462             sql += QueryUtil.__filter_params(filters)
    463             if group:
    464                 sql += group
    465             if order:
    466                 sql += order
    467             if page:
    468                 sql += page
    469         return sql
    470 
    471     @staticmethod
    472     def query_set(fields, values):
    473         s = " SET "
    474         for f, v in zip(fields, values):
    475             s += '%s="%s", '
    476         pass
    477 
    478 def test():
    479     config = {
    480         # "creator": pymysql,
    481         # "host" : "127.0.0.1", 
    482         "user" : "root", 
    483         "password" : "root",
    484         "database" : "py", 
    485         # "port" : 3306,
    486         # "charset" : 'utf8'
    487         # "tableName" : "fake",
    488     }
    489     base = BaseDao(**config)
    490     ########################################################################
    491     # fake = base.select_one("fake")
    492     # print(fake)
    493     ########################################################################
    494     # users = base.select_all("fake")
    495     # print(users)
    496     ########################################################################
    497     # filter1 = {
    498     #     "status":1,
    499     #     "_in_id":"1,2,3,4,5",
    500     #     "_like_name":"zhang",
    501     #     "_ne_name":"wangwu"
    502     # }
    503     # user_filters = base.select_all("user", filter1)
    504     # print(user_filters)
    505     ########################################################################
    506     # role = base.select_one("role")
    507     # print(role)
    508     ########################################################################
    509     # fake = base.select_pk("fake", 2)
    510     # print(fake)
    511     # base.update("fake", fake)
    512     # base.update_selective("fake", fake)
    513     # base.remove("fake", fake)
    514     ########################################################################
    515     # user_limit = base.select_page("user")
    516     # print(user_limit)
    517     ########################################################################
    518     # fake = {
    519     #     "id": "null",
    520     #     "name": "test",
    521     #     "value": "test"
    522     # }
    523     # flag = base.save("fake", fake)
    524     # print(flag)
    525 
    526 if __name__ == "__main__":
    527     test()
    View Code

      以上更新部分比较多,整体上进行了优化,新增了(save,update,delete 方法)。

      更新:2017-10-26

      1 #!/usr/bin/env python3
      2 # -*- coding=utf-8 -*-
      3 
      4 import json
      5 import logging
      6 import os
      7 import sys
      8 import time
      9 
     10 import pymysql
     11 from DBUtils import PooledDB
     12 
     13 __author__ = "阮程"
     14 
     15 logging.basicConfig(
     16     level=logging.INFO,
     17     datefmt='%Y-%m-%d %H:%M:%S',
     18     format='%(asctime)s [%(levelname)s] %(message)s'
     19 )
     20 
     21 
     22 def get_time(format=None):
     23     '获取时间'
     24     format = format or "%Y-%m-%d %H:%M:%S"
     25     return time.strftime(format, time.localtime())
     26 
     27 
     28 def stitch_sequence(seq=None, suf=None, isField=True):
     29     '如果参数("suf")不为空,则根据特殊的suf拼接列表元素,返回一个字符串。默认使用 ","。'
     30     if seq is None:
     31         raise Exception("Parameter seq is None")
     32     suf = suf or ","
     33     r = str()
     34     for s in seq:
     35         r += '`%s`%s' % (s, suf) if isField else '%s%s' % (s, suf)
     36         # if isField:
     37         #     r += '`%s`%s' % (s, suf)
     38         # else:
     39         #     r += '%s%s' % (s, suf)
     40     return r[:-len(suf)]
     41 
     42 
     43 class BaseDao(object):
     44     """
     45     简便的数据库操作基类,该类所操作的表必须有主键
     46     初始化参数如下:
     47     - :creator: 创建连接对象(默认: pymysql)
     48     - :host: 连接数据库主机地址(默认: localhost)
     49     - :port: 连接数据库端口(默认: 3306)
     50     - :user: 连接数据库用户名(默认: None), 如果为空,则会抛异常
     51     - :password: 连接数据库密码(默认: None), 如果为空,则会抛异常
     52     - :database: 连接数据库(默认: None), 如果为空,则会抛异常
     53     - :chatset: 编码(默认: utf8)
     54     - :tableName: 初始化 BaseDao 对象的数据库表名(默认: None), 如果为空,
     55     则会初始化该数据库下所有表的信息, 如果不为空,则只初始化传入的 tableName 的表
     56     """
     57 
     58     def __init__(self, creator=pymysql, host="localhost", port=3306, user=None, password=None,
     59                  database=None, charset="utf8", tableName=None, *args, **kwargs):
     60         if host is None:
     61             raise ValueError("Parameter [host] is None.")
     62         if port is None:
     63             raise ValueError("Parameter [port] is None.")
     64         if user is None:
     65             raise ValueError("Parameter [user] is None.")
     66         if password is None:
     67             raise ValueError("Parameter [password] is None.")
     68         if database is None:
     69             raise ValueError("Parameter [database] is None.")
     70         if tableName is None:
     71             print(
     72                 "WARNING >>> Parameter [tableName] is None. All tables will be initialized.")
     73         logging.debug(
     74             "[%s] 数据库初始化>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>开始" % (database))
     75         start = time.time()
     76         # 数据库连接配置
     77         self.__config = dict({
     78             "creator": creator, "charset": charset, "host": host, "port": port,
     79             "user": user, "password": password, "database": database
     80         })
     81         self.__database = database                      # 用于存储查询数据库
     82         self.__tableName = tableName                    # 用于临时存储当前查询表名
     83         # 初始化
     84         self.__init_connect()                           # 初始化连接
     85         self.__init_params()                            # 初始化参数
     86         end = time.time()
     87         logging.debug(
     88             "[%s] 数据库初始化>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>结束" % (database))
     89         logging.info("[%s] 数据库初始化成功。耗时:%d ms。" % (database, (end - start)))
     90 
     91     def __del__(self):
     92         '重写类被清除时调用的方法'
     93         if self.__cursor:
     94             self.__cursor.close()
     95         if self.__conn:
     96             self.__conn.close()
     97         logging.debug("[%s] 连接关闭。" % (self.__database))
     98 
     99     def __init_connect(self):
    100         '初始化连接'
    101         self.__conn = PooledDB.connect(**self.__config)
    102         self.__cursor = self.__conn.cursor()
    103 
    104     def __init_params(self):
    105         '初始化参数'
    106         self._table_dict = {}
    107         self.__information_schema_columns = []
    108         self.__table_column_dict_list = {}
    109         if self.__tableName is None:
    110             self.__init_table_dict_list()
    111             self.__init__table_column_dict_list()
    112         else:
    113             self.__init_table_dict(self.__tableName)
    114             self.__init__table_column_dict_list()
    115             self.__column_list = self.__table_column_dict_list[self.__tableName]
    116 
    117     def __init__information_schema_columns(self):
    118         "查询 information_schema.`COLUMNS` 中的列"
    119         sql = """   SELECT COLUMN_NAME 
    120                     FROM information_schema.`COLUMNS`
    121                     WHERE TABLE_SCHEMA='information_schema' AND TABLE_NAME='COLUMNS'
    122                 """
    123         result_tuple = self.execute_query(sql)
    124         column_list = [r[0] for r in result_tuple]
    125         self.__information_schema_columns = column_list
    126 
    127     def __init_table_dict(self, tableName):
    128         '初始化表'
    129         if not self.__information_schema_columns:
    130             self.__init__information_schema_columns()
    131         stitch_str = stitch_sequence(self.__information_schema_columns)
    132         sql = """   SELECT %s FROM information_schema.`COLUMNS`
    133                     WHERE TABLE_SCHEMA='%s' AND TABLE_NAME='%s'
    134                 """ % (stitch_str, self.__database, tableName)
    135         column_tuple = self.execute_query(sql)
    136         column_dict = {}
    137         for vs in column_tuple:
    138             d = {k: v for k, v in zip(self.__information_schema_columns, vs)}
    139             column_dict[d["COLUMN_NAME"]] = d
    140         self._table_dict[tableName] = column_dict
    141 
    142     def __init_table_dict_list(self):
    143         "初始化表字典对象"
    144         if not self.__information_schema_columns:
    145             self.__init__information_schema_columns()
    146         stitch_str = stitch_sequence(self.__information_schema_columns)
    147         sql = """  SELECT TABLE_NAME FROM information_schema.`TABLES` 
    148                     WHERE TABLE_SCHEMA='%s'
    149                 """ % (self.__database)
    150         table_tuple = self.execute_query(sql)
    151         self._table_dict = {t[0]: {} for t in table_tuple}
    152         for table in table_tuple:
    153             self.__init_table_dict(table[0])
    154 
    155     def __init__table_column_dict_list(self):
    156         '''初始化表字段字典列表'''
    157         for table, column_dict in self._table_dict.items():
    158             column_list = [column for column in column_dict.keys()]
    159             self.__table_column_dict_list[table] = column_list
    160 
    161     def __parse_result(self, result):
    162         '用于解析单个查询结果,返回字典对象'
    163         if result is None:
    164             return None
    165         obj = {k: v for k, v in zip(self.__column_list, result)}
    166         return obj
    167 
    168     def __parse_results(self, results):
    169         '用于解析多个查询结果,返回字典列表对象'
    170         if results is None:
    171             return None
    172         objs = [self.__parse_result(result) for result in results]
    173         return objs
    174 
    175     def __getpk(self, tableName):
    176         '获取表对应的主键字段'
    177         if self._table_dict.get(tableName) is None:
    178             raise Exception(tableName, "is not exist.")
    179         for column, column_dict in self._table_dict[tableName].items():
    180             if column_dict["COLUMN_KEY"] == "PRI":
    181                 return column
    182 
    183     def __get_table_column_list(self, tableName=None):
    184         '查询表的字段列表, 将查询出来的字段列表存入 __fields 中'
    185         return self.__table_column_dict_list[tableName]
    186 
    187     def __check_tableName(self, tableName):
    188         '''验证 tableName 参数'''
    189         if tableName is None:
    190             if self.__tableName is None:
    191                 raise Exception("Parameter [tableName] is None.")
    192         else:
    193             if self.__tableName != tableName:
    194                 self.__tableName = tableName
    195                 self.__column_list = self.__table_column_dict_list[self.__tableName]
    196 
    197     def execute_query(self, sql, single=False):
    198         '''执行查询 SQL 语句
    199         - @sql    查询 sql
    200         - @single 是否查询单个结果集,默认False
    201         '''
    202         try:
    203             logging.info("[%s] SQL >>> [%s]" % (self.__database, sql))
    204             self.__cursor.execute(sql)
    205             if single:
    206                 result_tuple = self.__cursor.fetchone()
    207             else:
    208                 result_tuple = self.__cursor.fetchall()
    209             return result_tuple
    210         except Exception as e:
    211             logging.error(e)
    212     
    213     def execute_update(self, sql):
    214         '''执行更新 SQL 语句'''
    215         try:
    216             # 获取数据库游标
    217             logging.info("[%s] SQL >>> [%s]" % (self.__database, sql))
    218             result = self.__cursor.execute(sql)
    219             self.__conn.commit()
    220             return result
    221         except Exception as e:
    222             logging.error(e)
    223             self.__conn.rollback()
    224 
    225     def select_one(self, tableName=None, filters={}):
    226         '''查询单个对象
    227         - @tableName 表名
    228         - @filters 过滤条件
    229         - @return 返回字典集合,集合中以表字段作为 key,字段值作为 value
    230         '''
    231         self.__check_tableName(tableName)
    232         FIELDS = stitch_sequence(
    233             self.__get_table_column_list(self.__tableName))
    234         sql = "SELECT %s FROM %s" % (FIELDS, self.__tableName)
    235         sql = QueryUtil.query_sql(sql, filters)
    236         result = self.execute_query(sql, single=True)
    237         return self.__parse_result(result)
    238 
    239     def select_pk(self, tableName=None, primaryKey=None):
    240         '''按主键查询
    241         - @tableName 表名
    242         - @primaryKey 主键值
    243         '''
    244         self.__check_tableName(tableName)
    245         FIELDS = stitch_sequence(
    246             self.__get_table_column_list(self.__tableName))
    247         sql = "SELECT %s FROM %s" % (FIELDS, self.__tableName)
    248         sql = QueryUtil.query_sql(sql, {self.__getpk(tableName): primaryKey})
    249         result = self.execute_query(sql, single=True)
    250         return self.__parse_result(result)
    251 
    252     def select_all(self, tableName=None, filters={}):
    253         '''查询所有
    254         - @tableName 表名
    255         - @filters 过滤条件
    256         - @return 返回字典集合,集合中以表字段作为 key,字段值作为 value
    257         '''
    258         self.__check_tableName(tableName)
    259         FIELDS = stitch_sequence(
    260             self.__get_table_column_list(self.__tableName))
    261         sql = "SELECT %s FROM %s" % (FIELDS, self.__tableName)
    262         sql = QueryUtil.query_sql(sql, filters)
    263         results = self.execute_query(sql)
    264         return self.__parse_results(results)
    265 
    266     def count(self, tableName=None):
    267         '''统计记录数'''
    268         self.__check_tableName(tableName)
    269         sql = "SELECT count(*) FROM %s" % (self.__tableName)
    270         result = self.execute_query(sql, single=True)
    271         return result[0]
    272 
    273     def select_page(self, tableName=None, page=None, filters={}):
    274         '''分页查询
    275         - @tableName 表名
    276         - @return 返回字典集合,集合中以表字段作为 key,字段值作为 value
    277         '''
    278         self.__check_tableName(tableName)
    279         if page is None:
    280             page = Page()
    281         filters["page"] = page
    282         FIELDS = stitch_sequence(
    283             self.__get_table_column_list(self.__tableName))
    284         sql = "SELECT %s FROM %s" % (FIELDS, self.__tableName)
    285         sql = QueryUtil.query_sql(sql, filters)
    286         result_tuple = self.execute_query(sql)
    287         return self.__parse_results(result_tuple)
    288 
    289     def save(self, tableName=None, obj=dict()):
    290         '''保存方法
    291         - @param tableName 表名
    292         - @param obj 对象
    293         - @return 影响行数
    294         '''
    295         self.__check_tableName(tableName)
    296         pk = self.__getpk(self.__tableName)
    297         if pk not in obj.keys():
    298             obj[pk] = None
    299         FIELDS = stitch_sequence(obj.keys())
    300         VALUES = []
    301         for k, v in obj.items():
    302             if self._table_dict[self.__tableName][k]["COLUMN_KEY"] != "PKI":
    303                 v = "null" if v is None else '"%s"' % v
    304                 # if v is None:
    305                 #     v = "null"
    306                 # else:
    307                 #     v = '"%s"' % v
    308             VALUES.append(v)
    309         VALUES = stitch_sequence(VALUES, isField=False)
    310         sql = 'INSERT INTO `%s` (%s) VALUES(%s)' % (
    311             self.__tableName, FIELDS, VALUES)
    312         return self.execute_update(sql)
    313 
    314     def update_by_primarykey(self, tableName=None, obj={}):
    315         '''更新方法(根据主键更新,包含空值)
    316         - @param tableName 表名
    317         - @param obj 对象
    318         - @return 影响行数
    319         '''
    320         self.__check_tableName(tableName)
    321         pk = self.__getpk(self.__tableName)
    322         if pk not in obj.keys() or obj.get(pk) is None:
    323             raise ValueError("Parameter [obj.%s] is None." % pk)
    324         l = []
    325         where = "WHERE "
    326         for k, v in obj.items():
    327             if self._table_dict[tableName][k]["COLUMN_KEY"] != "PRI":
    328                 if v is None:
    329                     if self._table_dict[tableName][k]["IS_NULLABLE"] == "YES":
    330                         l.append("%s=null" % (k))
    331                     else:
    332                         l.append("%s=''" % (k))
    333                 else:
    334                     l.append("%s='%s'" % (k, v))
    335             else:
    336                 where += "%s='%s'" % (k, v)
    337         sql = "UPDATE `%s` SET %s %s" % (
    338             self.__tableName, stitch_sequence(l, isField=False), where)
    339         return self.execute_update(sql)
    340 
    341     def update_by_primarikey_selective(self, tableName=None, obj={}):
    342         '''更新方法(根据主键更新,不包含空值)
    343         - @param tableName 表名
    344         - @param obj 对象
    345         - @return 影响行数
    346         '''
    347         self.__check_tableName(tableName)
    348         pk = self.__getpk(self.__tableName)
    349         if pk not in obj.keys() or obj.get(pk) is None:
    350             raise ValueError("Parameter [obj.%s] is None." % pk)
    351         where = "WHERE "
    352         l = []
    353         for k, v in obj.items():
    354             if self._table_dict[self.__tableName][k]["COLUMN_KEY"] != "PRI":
    355                 if v is None:
    356                     continue
    357                 l.append("%s='%s'" % (k, v))
    358             else:
    359                 where += "%s='%s'" % (k, v)
    360         sql = "UPDATE `%s` SET %s %s" % (
    361             self.__tableName, stitch_sequence(l, isField=False), where)
    362         return self.execute_update(sql)
    363 
    364     def remove_by_primarykey(self, tableName=None, value=None):
    365         '''删除方法(根据主键删除)
    366         - @param tableName 表名
    367         - @param valuej 主键值
    368         - @return 影响行数
    369         '''
    370         self.__check_tableName(tableName)
    371         if value is None:
    372             raise ValueError("Parameter [value] can not be None.")
    373         pk = self.__getpk(self.__tableName)
    374         sql = "DELETE FROM `%s` WHERE `%s`='%s'" % (
    375             self.__tableName, pk, value)
    376         return self.execute_update(sql)
    377 
    378 
    379 class Page(object):
    380     '分页对象'
    381 
    382     def __init__(self, pageNum=1, pageSize=10, count=False):
    383         '''
    384         Page 初始化方法
    385         - @param pageNum 页码,默认为1
    386         - @param pageSize 页面大小, 默认为10
    387         - @param count 是否包含 count 查询
    388         '''
    389         self.pageNum = pageNum if pageNum > 0 else 1            # 当前页数
    390         self.pageSize = pageSize if pageSize > 0 else 10        # 分页大小
    391         self.total = 0                                          # 总记录数
    392         self.pages = 1                                          # 总页数
    393         self.startRow = (self.pageNum - 1) * 
    394             self.pageSize     # 起始行(用于 mysql 分页查询)
    395         self.endRow = self.startRow + self.pageSize             # 结束行(用于 mysql 分页查询)
    396 
    397 
    398 class QueryUtil(object):
    399     '''
    400     SQL 语句拼接工具类:
    401     - 主方法:querySql(sql, filters)
    402     - 参数说明:   
    403     - @param sql:需要拼接的 SQL 语句
    404     - @param filters:拼接 SQL 的过滤条件 
    
    405     filters 过滤条件说明:
    406     - 支持拼接条件如下:
    407     - 1、等于(如:{"id": 2}, 拼接后为:id=2)
    408     - 2、不等于(如:{"_ne_id": 2}, 拼接后为:id != 2)
    409     - 3、小于(如:{"_lt_id": 2},拼接后为:id < 2)
    410     - 4、小于等于(如:{"_le_id": 2},拼接后为:id <= 2)
    411     - 5、大于(如:{"_gt_id": },拼接后为:id > 2)
    412     - 6、大于等于(如:{"_ge_id": },拼接后为:id >=2)
    413     - 7、in(如:{"_in_id": "1,2,3"},拼接后为:id IN(1,2,3))
    414     - 8、not in(如:{"_nein_id": "4,5,6"},拼接后为:id NOT IN(4,5,6))
    415     - 9、like(如:{"_like_name": },拼接后为:name LIKE '%zhang%')
    416     - 10、like(如:{"_llike_name": },拼接后为:name LIKE '%zhang')
    417     - 11、like(如:{"_rlike_name": },拼接后为:name LIKE 'zhang%')
    418     - 12、分组(如:{"groupby": "status"},拼接后为:GROUP BY status)
    419     - 13、排序(如:{"orderby": "createDate"},拼接后为:ORDER BY createDate)
    420     '''
    421 
    422     NE = "_ne_"                 # 拼接不等于
    423     LT = "_lt_"                 # 拼接小于
    424     LE = "_le_"                 # 拼接小于等于
    425     GT = "_gt_"                 # 拼接大于
    426     GE = "_ge_"                 # 拼接大于等于
    427     IN = "_in_"                 # 拼接 in
    428     NE_IN = "_nein_"            # 拼接 not in
    429     LIKE = "_like_"             # 拼接 like
    430     LEFT_LIKE = "_llike_"       # 拼接左 like
    431     RIGHT_LIKE = "_rlike_"      # 拼接右 like
    432     GROUP = "groupby"           # 拼接分组
    433     ORDER = "orderby"           # 拼接排序
    434     ORDER_TYPE = "ordertype"    # 排序类型:asc(升序)、desc(降序)
    435 
    436     @staticmethod
    437     def __filter_params(filters):
    438         '''过滤参数条件'''
    439         s = " WHERE 1=1"
    440         for k, v in filters.items():
    441             if k.startswith(QueryUtil.IN):                  # 拼接 in
    442                 s += " AND `%s` IN (" % (k[len(QueryUtil.IN):])
    443                 values = v.split(",")
    444                 for value in values:
    445                     s += " %s," % value
    446                 s = s[0:len(s) - 1] + ") "
    447             elif k.startswith(QueryUtil.NE_IN):             # 拼接 not in
    448                 s += " AND `%s` NOT IN (" % (k[len(QueryUtil.NE_IN):])
    449                 values = v.split(",")
    450                 for value in values:
    451                     s += " %s," % value
    452                 s = s[0:len(s) - 1] + ") "
    453             elif k.startswith(QueryUtil.LIKE):              # 拼接 like
    454                 s += " AND `%s` LIKE '%%%s%%' " % (k[len(QueryUtil.LIKE):], v)
    455             elif k.startswith(QueryUtil.LEFT_LIKE):         # 拼接左 like
    456                 s += " AND `%s` LIKE '%%%s' " % (
    457                     k[len(QueryUtil.LEFT_LIKE):], v)
    458             elif k.startswith(QueryUtil.RIGHT_LIKE):        # 拼接右 like
    459                 s += " AND `%s` LIKE '%s%%' " % (
    460                     k[len(QueryUtil.RIGHT_LIKE):], v)
    461             elif k.startswith(QueryUtil.NE):                # 拼接不等于
    462                 s += " AND `%s` != '%s' " % (k[len(QueryUtil.NE):], v)
    463             elif k.startswith(QueryUtil.LT):                # 拼接小于
    464                 s += " AND `%s` < '%s' " % (k[len(QueryUtil.LT):], v)
    465             elif k.startswith(QueryUtil.LE):                # 拼接小于等于
    466                 s += " AND `%s` <= '%s' " % (k[len(QueryUtil.LE):], v)
    467             elif k.startswith(QueryUtil.GT):                # 拼接大于
    468                 s += " AND `%s` > '%s' " % (k[len(QueryUtil.GT):], v)
    469             elif k.startswith(QueryUtil.GE):                # 拼接大于等于
    470                 s += " AND `%s` >= '%s' " % (k[len(QueryUtil.GE):], v)
    471             else:                                           # 拼接等于
    472                 if isinstance(v, str):
    473                     s += " AND `%s`='%s' " % (k, v)
    474                 elif isinstance(v, int):
    475                     s += " AND `%s`=%d " % (k, v)
    476         return s
    477 
    478     @staticmethod
    479     def __filter_group(filters):
    480         '''过滤分组'''
    481         group = filters.pop(QueryUtil.GROUP)
    482         s = " GROUP BY %s" % (group)
    483         return s
    484 
    485     @staticmethod
    486     def __filter_order(filters):
    487         '''过滤排序'''
    488         order = filters.pop(QueryUtil.ORDER)
    489         type = filters.pop(QueryUtil.ORDER_TYPE, "asc")
    490         s = " ORDER BY `%s` %s" % (order, type)
    491         return s
    492 
    493     @staticmethod
    494     def __filter_page(filters):
    495         '''过滤 page 对象'''
    496         page = filters.pop("page")
    497         return " LIMIT %d,%d" % (page.startRow, page.endRow)
    498 
    499     @staticmethod
    500     def query_sql(sql=None, filters=dict()):
    501         '''拼接 SQL 查询条件
    502         - @param sql SQL 语句
    503         - @param filters 过滤条件
    504         - @return 返回拼接 SQL
    505         '''
    506         if not filters:
    507             return sql
    508         else:
    509             if not isinstance(filters, dict):
    510                 raise Exception("Parameter [filters] must be dict.")
    511             group = None
    512             order = None
    513             page = None
    514             if filters.get("groupby") != None:
    515                 group = QueryUtil.__filter_group(filters)
    516             if filters.get("orderby") != None:
    517                 order = QueryUtil.__filter_order(filters)
    518             if filters.get("page") != None:
    519                 page = QueryUtil.__filter_page(filters)
    520             sql += QueryUtil.__filter_params(filters)
    521             if group:
    522                 sql += group
    523             if order:
    524                 sql += order
    525             if page:
    526                 sql += page
    527         return sql
    View Code

      代码下载地址(GitHub): https://github.com/ruancheng77/baseDao

      代码中已经给出了几个具体示例,大家可以参考使用。

      如果有感兴趣一起学习、讨论Python的可以加QQ群:626787819,有啥意见或者建议的可以发我邮箱:410093793@qq.com。

      

  • 相关阅读:
    Legacy和UEFI,MBR和GPT的区别
    如何升级laravel5.4到laravel5.5并使用新特性?
    value toDF is not a member of org.apache.spark.rdd.RDD
    spark能传递外部命名参数给main函数吗?
    spark在idea中本地如何运行?(处理问题NoSuchFieldException: SHUTDOWN_HOOK_PRIORITY)
    工作随笔-20171012
    maven使用实战
    介绍maven构建的生命周期
    python中的pip
    python中的None
  • 原文地址:https://www.cnblogs.com/rcddup/p/7133378.html
Copyright © 2011-2022 走看看