zoukankan      html  css  js  c++  java
  • 我的第一个python web开发框架(32)——定制ORM(八)

      写到这里,基本的ORM功能就完成了,不知大家有没有发现,这个ORM每个方法都是在with中执行的,也就是说每个方法都是一个完整的事务,当它执行完成以后也会将事务提交,那么如果我们想要进行一个复杂的事务时,它并不能做到,所以我们还需要对它进行改造,让它支持sql事务。

      那么应该怎么实现呢?我们都知道要支持事务,就必须让不同的sql语句在同一个事务中执行,也就是说,我们需要在一个with中执行所有的sql语句,失败则回滚,成功再提交事务。

      由于我们的逻辑层各个类都是继承ORM基类来实现的,而事务的开关放在各个类中就不合适,可能会存在问题,所以在执行事务时,直接调用db_helper模块,使用with初始化好数据库链接,然后在方法里编写并执行各个sql语句。

      当前逻辑层基类(ORM模块)的sql语句都是在方法中生成(拼接)的,然后在方法的with模块中执行,所以我们需要再次对整个类进行改造,将所有的sql生成方法提炼出来,成为单独的方法,然后在事务中,我们不直接执行获取结果,而是通过ORM生成对应的sql语句,在with中执行这样语句。(当然还有其他方法也能实现事务,不过在这里不做进一步的探讨,因为当前这种是最简单实现事务的方式之一,多层封装处理,有可能会导致系统变的更加复杂,代码更加难懂)

      代码改造起来很简单,比如说获取记录方法

     1     def get_model(self, wheres):
     2         """通过条件获取一条记录"""
     3         # 如果有条件,则自动添加where
     4         if wheres:
     5             wheres = ' where ' + wheres
     6 
     7         # 合成sql语句
     8         sql = "select %(column_name_list)s from %(table_name)s %(wheres)s" % 
     9               {'column_name_list': self.__column_name_list, 'table_name': self.__table_name, 'wheres': wheres}
    10         # 初化化数据库链接
    11         result = self.select(sql)
    12         if result:
    13             return result[0]
    14         return {}

      我们可以将它拆分成get_model_sql()和get_model()两个方法,一个处理sql组合,一个执行获取结果,前者可以给事务调用,后者直接给对应的程序调用

     1     def get_model_sql(self, wheres):
     2         """通过条件获取一条记录"""
     3         # 如果有条件,则自动添加where
     4         if wheres:
     5             wheres = ' where ' + wheres
     6 
     7         # 合成sql语句
     8         sql = "select %(column_name_list)s from %(table_name)s %(wheres)s" % 
     9               {'column_name_list': self.__column_name_list, 'table_name': self.__table_name, 'wheres': wheres}
    10         return sql
    11 
    12     def get_model(self, wheres):
    13         """通过条件获取一条记录"""
    14         # 生成sql
    15         sql = self.get_model_sql(wheres)
    16         # 初化化数据库链接
    17         result = self.select(sql)
    18         if result:
    19             return result[0]
    20         return {}

      其他代码不一一细述,大家自己看看重构后的结果

      1 #!/usr/bin/env python
      2 # coding=utf-8
      3 
      4 from common import db_helper, cache_helper, encrypt_helper
      5 
      6 
      7 class LogicBase():
      8     """逻辑层基础类"""
      9 
     10     def __init__(self, db, is_output_sql, table_name, column_name_list='*', pk_name='id'):
     11         """类初始化"""
     12         # 数据库参数
     13         self.__db = db
     14         # 是否输出执行的Sql语句到日志中
     15         self.__is_output_sql = is_output_sql
     16         # 表名称
     17         self.__table_name = str(table_name).lower()
     18         # 查询的列字段名称,*表示查询全部字段,多于1个字段时用逗号进行分隔,除了字段名外,也可以是表达式
     19         self.__column_name_list = str(column_name_list).lower()
     20         # 主健名称
     21         self.__pk_name = str(pk_name).lower()
     22         # 缓存列表
     23         self.__cache_list = self.__table_name + '_cache_list'
     24 
     25     #####################################################################
     26     ### 生成Sql ###
     27     def get_model_sql(self, wheres):
     28         """通过条件获取一条记录"""
     29         # 如果有条件,则自动添加where
     30         if wheres:
     31             wheres = ' where ' + wheres
     32 
     33         # 合成sql语句
     34         sql = "select %(column_name_list)s from %(table_name)s %(wheres)s" % 
     35               {'column_name_list': self.__column_name_list, 'table_name': self.__table_name, 'wheres': wheres}
     36         return sql
     37 
     38     def get_model_for_pk_sql(self, pk, wheres=''):
     39         """通过主键值获取数据库记录实体"""
     40         # 组装查询条件
     41         wheres = '%s = %s' % (self.__pk_name, str(pk))
     42         return self.get_model_sql(wheres)
     43 
     44     def get_value_sql(self, column_name, wheres=''):
     45         """
     46         获取指定条件的字段值————多于条记录时,只取第一条记录
     47         :param column_name: 单个字段名,如:id
     48         :param wheres: 查询条件
     49         :return: 7 (指定的字段值)
     50         """
     51         if wheres:
     52             wheres = ' where ' + wheres
     53 
     54         sql = 'select %(column_name)s from %(table_name)s %(wheres)s limit 1' % 
     55               {'column_name': column_name, 'table_name': self.__table_name, 'wheres': wheres}
     56         return sql
     57 
     58     def get_value_list_sql(self, column_name, wheres=''):
     59         """
     60         获取指定条件记录的字段值列表
     61         :param column_name: 单个字段名,如:id
     62         :param wheres: 查询条件
     63         :return: [1,3,4,6,7]
     64         """
     65         if not column_name:
     66             column_name = self.__pk_name
     67         elif wheres:
     68             wheres = ' where ' + wheres
     69 
     70         sql = 'select array_agg(%(column_name)s) as list from %(table_name)s %(wheres)s' % 
     71               {'column_name': column_name, 'table_name': self.__table_name, 'wheres': wheres}
     72         return sql
     73 
     74     def add_model_sql(self, fields, returning=''):
     75         """新增数据库记录"""
     76         ### 拼接sql语句 ###
     77         # 初始化变量
     78         key_list = []
     79         value_list = []
     80         # 将传入的字典参数进行处理,把字段名生成sql插入字段名数组和字典替换数组
     81         # PS:字符串使用字典替换参数时,格式是%(name)s,这里会生成对应的字串
     82         # 比如:
     83         #   传入的字典为: {'id': 1, 'name': '名称'}
     84         #   那么生成的key_list为:'id','name'
     85         #   而value_list为:'%(id)s,%(name)s'
     86         #   最终而value_list为字符串对应名称位置会被替换成相应的值
     87         for key in fields.keys():
     88             key_list.append(key)
     89             value_list.append('%(' + key + ')s')
     90         # 设置sql拼接字典,并将数组(lit)使用join方式进行拼接,生成用逗号分隔的字符串
     91         parameter = {
     92             'table_name': self.__table_name,
     93             'pk_name': self.__pk_name,
     94             'key_list': ','.join(key_list),
     95             'value_list': ','.join(value_list)
     96         }
     97         # 如果有指定返回参数,则添加
     98         if returning:
     99             parameter['returning'] = ', ' + returning
    100         else:
    101             parameter['returning'] = ''
    102 
    103         # 生成可以使用字典替换的字符串
    104         sql = "insert into %(table_name)s (%(key_list)s) values (%(value_list)s) returning %(pk_name)s %(returning)s" % parameter
    105         # 将生成好的字符串替字典参数值,生成最终可执行的sql语句
    106         return sql % fields
    107 
    108     def edit_sql(self, fields, wheres='', returning=''):
    109         """
    110         批量编辑数据库记录
    111         :param fields: 要更新的字段(字段名与值存储在字典中)
    112         :param wheres: 更新条件
    113         :param returning: 更新成功后,返回的字段名
    114         :param is_update_cache: 是否同步更新缓存
    115         :return:
    116         """
    117         ### 拼接sql语句 ###
    118         # 拼接字段与值
    119         field_list = [key + ' = %(' + key + ')s' for key in fields.keys()]
    120         # 设置sql拼接字典
    121         parameter = {
    122             'table_name': self.__table_name,
    123             'pk_name': self.__pk_name,
    124             'field_list': ','.join(field_list)
    125         }
    126         # 如果存在更新条件,则将条件添加到sql拼接更换字典中
    127         if wheres:
    128             parameter['wheres'] = ' where ' + wheres
    129         else:
    130             parameter['wheres'] = ''
    131 
    132         # 如果有指定返回参数,则添加
    133         if returning:
    134             parameter['returning'] = ', ' + returning
    135         else:
    136             parameter['returning'] = ''
    137 
    138         # 生成sql语句
    139         sql = "update %(table_name)s set %(field_list)s %(wheres)s returning %(pk_name)s %(returning)s" % parameter
    140         return sql % fields
    141 
    142     def edit_model_sql(self, pk, fields, wheres='', returning=''):
    143         """编辑单条数据库记录"""
    144         if wheres:
    145             wheres = self.__pk_name + ' = ' + str(pk) + ' and ' + wheres
    146         else:
    147             wheres = self.__pk_name + ' = ' + str(pk)
    148 
    149         return self.edit_sql(fields, wheres, returning)
    150 
    151     def delete_sql(self, wheres='', returning=''):
    152         """
    153         批量删除数据库记录
    154         :param wheres: 删除条件
    155         :param returning: 删除成功后,返回的字段名
    156         :param is_update_cache: 是否同步更新缓存
    157         :return:
    158         """
    159         # 如果存在条件
    160         if wheres:
    161             wheres = ' where ' + wheres
    162 
    163         # 如果有指定返回参数,则添加
    164         if returning:
    165             returning = ', ' + returning
    166 
    167         # 生成sql语句
    168         sql = "delete from %(table_name)s %(wheres)s returning %(pk_name)s %(returning)s" % 
    169               {'table_name': self.__table_name, 'wheres': wheres, 'pk_name': self.__pk_name, 'returning': returning}
    170         return sql
    171 
    172     def delete_model_sql(self, pk, wheres='', returning=''):
    173         """删除单条数据库记录"""
    174         if wheres:
    175             wheres = self.__pk_name + ' = ' + str(pk) + ' and ' + wheres
    176         else:
    177             wheres = self.__pk_name + ' = ' + str(pk)
    178 
    179         return self.delete_sql(wheres, returning)
    180 
    181     def get_list_sql(self, column_name_list='', wheres='', orderby=None, table_name=None):
    182         """
    183         获取指定条件的数据库记录集
    184         :param column_name_list:      查询字段
    185         :param wheres:      查询条件
    186         :param orderby:     排序规则
    187         :param table_name:     查询数据表,多表查询时需要设置
    188         :return:
    189         """
    190         # 初始化查询数据表名称
    191         if not table_name:
    192             table_name = self.__table_name
    193         # 初始化查询字段名
    194         if not column_name_list:
    195             column_name_list = self.__column_name_list
    196         # 初始化查询条件
    197         if wheres:
    198             # 如果是字符串,表示该查询条件已组装好了,直接可以使用
    199             if isinstance(wheres, str):
    200                 wheres = 'where ' + wheres
    201             # 如果是list,则表示查询条件有多个,可以使用join将它们用and方式组合起来使用
    202             elif isinstance(wheres, list):
    203                 wheres = 'where ' + ' and '.join(wheres)
    204         # 初始化排序
    205         if not orderby:
    206             orderby = self.__pk_name + ' desc'
    207         #############################################################
    208 
    209         ### 按条件查询数据库记录
    210         sql = "select %(column_name_list)s from %(table_name)s %(wheres)s order by %(orderby)s " % 
    211               {'column_name_list': column_name_list,
    212                'table_name': table_name,
    213                'wheres': wheres,
    214                'orderby': orderby}
    215         return sql
    216 
    217     def get_count_sql(self, wheres=''):
    218         """获取指定条件记录数量"""
    219         if wheres:
    220             wheres = ' where ' + wheres
    221         sql = 'select count(1) as total from %(table_name)s %(wheres)s ' % 
    222               {'table_name': self.__table_name, 'wheres': wheres}
    223         return sql
    224 
    225     def get_sum_sql(self, fields, wheres):
    226         """获取指定条件记录数量"""
    227         sql = 'select sum(%(fields)s) as total from %(table_name)s where %(wheres)s ' % 
    228               {'table_name': self.__table_name, 'wheres': wheres, 'fields': fields}
    229         return sql
    230 
    231     def get_min_sql(self, fields, wheres):
    232         """获取该列记录最小值"""
    233         sql = 'select min(%(fields)s) as min from %(table_name)s where %(wheres)s ' % 
    234               {'table_name': self.__table_name, 'wheres': wheres, 'fields': fields}
    235         return sql
    236 
    237     def get_max_sql(self, fields, wheres):
    238         """获取该列记录最大值"""
    239         sql = 'select max(%(fields)s) as max from %(table_name)s where %(wheres)s ' % 
    240               {'table_name': self.__table_name, 'wheres': wheres, 'fields': fields}
    241         return sql
    242 
    243     #####################################################################
    244 
    245 
    246     #####################################################################
    247     ### 执行Sql ###
    248 
    249     def select(self, sql):
    250         """执行sql查询语句(select)"""
    251         with db_helper.PgHelper(self.__db, self.__is_output_sql) as db:
    252             # 执行sql语句
    253             result = db.execute(sql)
    254             if not result:
    255                 result = []
    256         return result
    257 
    258     def execute(self, sql):
    259         """执行sql语句,并提交事务"""
    260         with db_helper.PgHelper(self.__db, self.__is_output_sql) as db:
    261             # 执行sql语句
    262             result = db.execute(sql)
    263             if result:
    264                 db.commit()
    265             else:
    266                 result = []
    267         return result
    268 
    269     def copy(self, values, columns):
    270         """批量更新数据"""
    271         with db_helper.PgHelper(self.__db, self.__is_output_sql) as db:
    272             # 执行sql语句
    273             result = db.copy(values, self.__table_name, columns)
    274         return result
    275 
    276     def get_model(self, wheres):
    277         """通过条件获取一条记录"""
    278         # 生成sql
    279         sql = self.get_model_sql(wheres)
    280         # 执行查询操作
    281         result = self.select(sql)
    282         if result:
    283             return result[0]
    284         return {}
    285 
    286     def get_model_for_pk(self, pk, wheres=''):
    287         """通过主键值获取数据库记录实体"""
    288         if not pk:
    289             return {}
    290         # 生成sql
    291         sql = self.get_model_for_pk_sql(pk, wheres)
    292         # 执行查询操作
    293         result = self.select(sql)
    294         if result:
    295             return result[0]
    296         return {}
    297 
    298     def get_value(self, column_name, wheres=''):
    299         """
    300         获取指定条件的字段值————多于条记录时,只取第一条记录
    301         :param column_name: 单个字段名,如:id
    302         :param wheres: 查询条件
    303         :return: 7 (指定的字段值)
    304         """
    305         if not column_name:
    306             return None
    307 
    308         # 生成sql
    309         sql = self.get_value_sql(column_name, wheres)
    310         result = self.select(sql)
    311         # 如果查询成功,则直接返回记录字典
    312         if result:
    313             return result[0].get(column_name)
    314 
    315     def get_value_list(self, column_name, wheres=''):
    316         """
    317         获取指定条件记录的字段值列表
    318         :param column_name: 单个字段名,如:id
    319         :param wheres: 查询条件
    320         :return: [1,3,4,6,7]
    321         """
    322         # 生成sql
    323         sql = self.get_value_list_sql(column_name, wheres)
    324         result = self.select(sql)
    325         # 如果查询失败或不存在指定条件记录,则直接返回初始值
    326         if result and isinstance(result, list):
    327             return result[0].get('list')
    328         else:
    329             return []
    330 
    331     def add_model(self, fields, returning=''):
    332         """新增数据库记录"""
    333         # 生成sql
    334         sql = self.add_model_sql(fields, returning)
    335         result = self.execute(sql)
    336         if result:
    337             return result[0]
    338         return {}
    339 
    340     def edit(self, fields, wheres='', returning='', is_update_cache=True):
    341         """
    342         批量编辑数据库记录
    343         :param fields: 要更新的字段(字段名与值存储在字典中)
    344         :param wheres: 更新条件
    345         :param returning: 更新成功后,返回的字段名
    346         :param is_update_cache: 是否同步更新缓存
    347         :return:
    348         """
    349         # 生成sql
    350         sql = self.edit_sql(fields, wheres, returning)
    351         result = self.execute(sql)
    352         if result:
    353             # 判断是否删除对应的缓存
    354             if is_update_cache:
    355                 # 循环删除更新成功的所有记录对应的缓存
    356                 for model in result:
    357                     self.del_model_for_cache(model.get(self.__pk_name, 0))
    358                 # 同步删除与本表关联的缓存
    359                 self.del_relevance_cache()
    360         return result
    361 
    362     def edit_model(self, pk, fields, wheres='', returning='', is_update_cache=True):
    363         """编辑单条数据库记录"""
    364         if not pk:
    365             return {}
    366         # 生成sql
    367         sql = self.edit_model_sql(pk, fields, wheres, returning)
    368         result = self.execute(sql)
    369         if result:
    370             # 判断是否删除对应的缓存
    371             if is_update_cache:
    372                 # 删除更新成功的所有记录对应的缓存
    373                 self.del_model_for_cache(result[0].get(self.__pk_name, 0))
    374                 # 同步删除与本表关联的缓存
    375                 self.del_relevance_cache()
    376         return result
    377 
    378     def delete(self, wheres='', returning='', is_update_cache=True):
    379         """
    380         批量删除数据库记录
    381         :param wheres: 删除条件
    382         :param returning: 删除成功后,返回的字段名
    383         :param is_update_cache: 是否同步更新缓存
    384         :return:
    385         """
    386         # 生成sql
    387         sql = self.delete_sql(wheres, returning)
    388         result = self.execute(sql)
    389         if result:
    390             # 同步删除对应的缓存
    391             if is_update_cache:
    392                 for model in result:
    393                     self.del_model_for_cache(model.get(self.__pk_name, 0))
    394                 # 同步删除与本表关联的缓存
    395                 self.del_relevance_cache()
    396         return result
    397 
    398     def delete_model(self, pk, wheres='', returning='', is_update_cache=True):
    399         """删除单条数据库记录"""
    400         if not pk:
    401             return {}
    402         # 生成sql
    403         sql = self.delete_model_sql(pk, wheres, returning)
    404         result = self.execute(sql)
    405         if result:
    406             # 同步删除对应的缓存
    407             if is_update_cache:
    408                 self.del_model_for_cache(result[0].get(self.__pk_name, 0))
    409                 # 同步删除与本表关联的缓存
    410                 self.del_relevance_cache()
    411         return result
    412 
    413     def get_list(self, column_name_list='', wheres='', page_number=None, page_size=None, orderby=None, table_name=None):
    414         """
    415         获取指定条件的数据库记录集
    416         :param column_name_list:      查询字段
    417         :param wheres:      查询条件
    418         :param page_number:   分页索引值
    419         :param page_size:    分页大小, 存在值时才会执行分页
    420         :param orderby:     排序规则
    421         :param table_name:     查询数据表,多表查询时需要设置
    422         :return: 返回记录集总数量与分页记录集
    423             {'records': 0, 'total': 0, 'page': 0, 'rows': []}
    424         """
    425         # 初始化输出参数:总记录数量与列表集
    426         data = {
    427             'records': 0,  # 总记录数
    428             'total': 0,  # 总页数
    429             'page': 1,  # 当前页面索引
    430             'rows': [],  # 查询结果(记录列表)
    431         }
    432         # 初始化查询数据表名称
    433         if not table_name:
    434             table_name = self.__table_name
    435         # 初始化查询字段名
    436         if not column_name_list:
    437             column_name_list = self.__column_name_list
    438         # 初始化查询条件
    439         if wheres:
    440             # 如果是字符串,表示该查询条件已组装好了,直接可以使用
    441             if isinstance(wheres, str):
    442                 wheres = 'where ' + wheres
    443             # 如果是list,则表示查询条件有多个,可以使用join将它们用and方式组合起来使用
    444             elif isinstance(wheres, list):
    445                 wheres = 'where ' + ' and '.join(wheres)
    446         # 初始化排序
    447         if not orderby:
    448             orderby = self.__pk_name + ' desc'
    449         # 初始化分页查询的记录区间
    450         paging = ''
    451 
    452         with db_helper.PgHelper(self.__db, self.__is_output_sql) as db:
    453             #############################################################
    454             # 判断是否需要进行分页
    455             if not page_size is None:
    456                 ### 执行sql,获取指定条件的记录总数量
    457                 sql = 'select count(1) as records from %(table_name)s %(wheres)s ' % 
    458                       {'table_name': table_name, 'wheres': wheres}
    459                 result = db.execute(sql)
    460                 # 如果查询失败或不存在指定条件记录,则直接返回初始值
    461                 if not result or result[0]['records'] == 0:
    462                     return data
    463 
    464                 # 设置记录总数量
    465                 data['records'] = result[0].get('records')
    466 
    467                 #########################################################
    468                 ### 设置分页索引与页面大小 ###
    469                 if page_size <= 0:
    470                     page_size = 10
    471                 # 计算总分页数量:通过总记录数除于每页显示数量来计算总分页数量
    472                 if data['records'] % page_size == 0:
    473                     page_total = data['records'] // page_size
    474                 else:
    475                     page_total = data['records'] // page_size + 1
    476                 # 判断页码是否超出限制,超出限制查询时会出现异常,所以将页面索引设置为最后一页
    477                 if page_number < 1 or page_number > page_total:
    478                     page_number = page_total
    479                 # 记录总页面数量
    480                 data['total'] = page_total
    481                 # 记录当前页面值
    482                 data['page'] = page_number
    483                 # 计算当前页面要显示的记录起始位置(limit指定的位置)
    484                 record_number = (page_number - 1) * page_size
    485                 # 设置查询分页条件
    486                 paging = ' limit ' + str(page_size) + ' offset ' + str(record_number)
    487             #############################################################
    488 
    489             ### 按条件查询数据库记录
    490             sql = "select %(column_name_list)s from %(table_name)s %(wheres)s order by %(orderby)s %(paging)s" % 
    491                   {'column_name_list': column_name_list,
    492                    'table_name': table_name,
    493                    'wheres': wheres,
    494                    'orderby': orderby,
    495                    'paging': paging}
    496             result = db.execute(sql)
    497             if result:
    498                 data['rows'] = result
    499                 # 不需要分页查询时,直接在这里设置总记录数
    500                 if page_size is None:
    501                     data['records'] = len(result)
    502 
    503         return data
    504 
    505     def get_count(self, wheres=''):
    506         """获取指定条件记录数量"""
    507         # 生成sql
    508         sql = self.get_count_sql(wheres)
    509         result = self.select(sql)
    510         # 如果查询存在记录,则返回true
    511         if result:
    512             return result[0].get('total')
    513         return 0
    514 
    515     def get_sum(self, fields, wheres):
    516         """获取指定条件记录数量"""
    517         # 生成sql
    518         sql = self.get_sum_sql(fields, wheres)
    519         result = self.select(sql)
    520         # 如果查询存在记录,则返回true
    521         if result and result[0].get('total'):
    522             return result[0].get('total')
    523         return 0
    524 
    525     def get_min(self, fields, wheres):
    526         """获取该列记录最小值"""
    527         # 生成sql
    528         sql = self.get_min_sql(fields, wheres)
    529         result = self.select(sql)
    530         # 如果查询存在记录,则返回true
    531         if result and result[0].get('min'):
    532             return result[0].get('min')
    533 
    534     def get_max(self, fields, wheres):
    535         """获取该列记录最大值"""
    536         # 生成sql
    537         sql = self.get_max_sql(fields, wheres)
    538         result = self.select(sql)
    539         # 如果查询存在记录,则返回true
    540         if result and result[0].get('max'):
    541             return result[0].get('max')
    542 
    543     #####################################################################
    544 
    545 
    546     #####################################################################
    547     ### 缓存操作方法 ###
    548 
    549     def get_cache_key(self, pk):
    550         """获取缓存key值"""
    551         return ''.join((self.__table_name, '_', str(pk)))
    552 
    553     def set_model_for_cache(self, pk, value, time=43200):
    554         """更新存储在缓存中的数据库记录,缓存过期时间为12小时"""
    555         # 生成缓存key
    556         key = self.get_cache_key(pk)
    557         # 存储到nosql缓存中
    558         cache_helper.set(key, value, time)
    559 
    560     def get_model_for_cache(self, pk):
    561         """从缓存中读取数据库记录"""
    562         # 生成缓存key
    563         key = self.get_cache_key(pk)
    564         # 从缓存中读取数据库记录
    565         result = cache_helper.get(key)
    566         # 缓存中不存在记录,则从数据库获取
    567         if not result:
    568             result = self.get_model_for_pk(pk)
    569             self.set_model_for_cache(pk, result)
    570         if result:
    571             return result
    572         else:
    573             return {}
    574 
    575     def get_model_for_cache_of_where(self, where):
    576         """
    577         通过条件获取记录实体(我们经常需要使用key、编码或指定条件来获取记录,这时可以通过当前方法来获取)
    578         :param where: 查询条件
    579         :return: 记录实体
    580         """
    581         # 生成实体缓存key
    582         model_cache_key = self.__table_name + encrypt_helper.md5(where)
    583         # 通过条件从缓存中获取记录id
    584         pk = cache_helper.get(model_cache_key)
    585         # 如果主键id存在,则直接从缓存中读取记录
    586         if pk:
    587             return self.get_model_for_cache(pk)
    588 
    589         # 否则从数据库中获取
    590         result = self.get_model(where)
    591         if result:
    592             # 存储条件对应的主键id值到缓存中
    593             cache_helper.set(model_cache_key, result.get(self.__pk_name))
    594             # 存储记录实体到缓存中
    595             self.set_model_for_cache(result.get(self.__pk_name), result)
    596             return result
    597 
    598     def get_value_for_cache(self, pk, column_name):
    599         """获取指定记录的字段值"""
    600         return self.get_model_for_cache(pk).get(column_name)
    601 
    602     def del_model_for_cache(self, pk):
    603         """删除缓存中指定数据"""
    604         # 生成缓存key
    605         key = self.get_cache_key(pk)
    606         # log_helper.info(key)
    607         # 存储到nosql缓存中
    608         cache_helper.delete(key)
    609 
    610     def add_relevance_cache_in_list(self, key):
    611         """将缓存名称存储到列表里————主要存储与记录变更关联的"""
    612         # 从nosql中读取全局缓存列表
    613         cache_list = cache_helper.get(self.__cache_list)
    614         # 判断缓存列表是否有值,有则进行添加操作
    615         if cache_list:
    616             # 判断是否已存储列表中,不存在则执行添加操作
    617             if not key in cache_list:
    618                 cache_list.append(key)
    619                 cache_helper.set(self.__cache_list, cache_list)
    620         # 无则直接创建全局缓存列表,并存储到nosql中
    621         else:
    622             cache_list = [key]
    623             cache_helper.set(self.__cache_list, cache_list)
    624 
    625     def del_relevance_cache(self):
    626         """删除关联缓存————将和数据表记录关联的,个性化缓存全部删除"""
    627         # 从nosql中读取全局缓存列表
    628         cache_list = cache_helper.get(self.__cache_list)
    629         # 清除已删除缓存列表
    630         cache_helper.delete(self.__cache_list)
    631         if cache_list:
    632             # 执行删除操作
    633             for cache in cache_list:
    634                 cache_helper.delete(cache)
    635 
    636     #####################################################################
    View Code

      从完整代码可以看到,重构后的类多了很多sql生成方法,它们其实是从原方法中分享出sql合成代码,将它们独立出来而已。

      接下来我们编写单元测试代码,执行一下事务看看效果

     1 #!/usr/bin/evn python
     2 # coding=utf-8
     3 
     4 import unittest
     5 from common import db_helper
     6 from common.string_helper import string
     7 from config import db_config
     8 from logic import product_logic, product_class_logic
     9 
    10 
    11 class DbHelperTest(unittest.TestCase):
    12     """数据库操作包测试类"""
    13 
    14     def setUp(self):
    15         """初始化测试环境"""
    16         print('------ini------')
    17 
    18     def tearDown(self):
    19         """清理测试环境"""
    20         print('------clear------')
    21 
    22     def test(self):
    23         ##############################################
    24         # 只需要看这里,其他代码是测试用例的模板代码 #
    25         ##############################################
    26         # 测试事务
    27         # 使用with方法,初始化数据库链接
    28         with db_helper.PgHelper(db_config.DB, db_config.IS_OUTPUT_SQL) as db:
    29             # 实例化product表操作类ProductLogic
    30             _product_logic = product_logic.ProductLogic()
    31             # 实例化product_class表操作类product_class_logic
    32             _product_class_logic = product_class_logic.ProductClassLogic()
    33             # 初始化产品分类主键id
    34             id = 1
    35 
    36             # 获取产品分类信息(为了查看效果,所以加了这段获取分类信息)
    37             sql = _product_class_logic.get_model_for_pk_sql(id)
    38             print(sql)
    39             # 执行sql语句
    40             result = db.execute(sql)
    41             if not result:
    42                 print('不存在指定的产品分类')
    43                 return
    44             print('----产品分类实体----')
    45             print(result)
    46             print('-------------------')
    47 
    48             # 禁用产品分类
    49             fields = {
    50                 'is_enable': 0
    51             }
    52             sql = _product_class_logic.edit_model_sql(id, fields, returning='is_enable')
    53             print(sql)
    54             # 执行sql语句
    55             result = db.execute(sql)
    56             if not result:
    57                 # 执行失败,执行回滚操作
    58                 db.rollback()
    59                 print('禁用产品分类失败')
    60                 return
    61             # 执行缓存清除操作
    62             _product_class_logic.del_model_for_cache(id)
    63             _product_class_logic.del_relevance_cache()
    64             print('----执行成功后的产品分类实体----')
    65             print(result)
    66             print('-------------------------------')
    67 
    68             # 同步禁用产品分类对应的所有产品
    69             sql = _product_logic.edit_sql(fields, 'product_class_id=' + str(id), returning='is_enable')
    70             print(sql)
    71             # 执行sql语句
    72             result = db.execute(sql)
    73             if not result:
    74                 # 执行失败,执行回滚操作
    75                 db.rollback()
    76                 print('同步禁用产品分类对应的所有产品失败')
    77                 return
    78             # 执行缓存清除操作
    79             for model in result:
    80                 _product_class_logic.del_model_for_cache(model.get('id'))
    81             _product_class_logic.del_relevance_cache()
    82             print('----执行成功后的产品实体----')
    83             print(result)
    84             print('---------------------------')
    85 
    86             db.commit()
    87             print('执行成功')
    88         ##############################################
    89 
    90 if __name__ == '__main__':
    91     unittest.main()

      细心的朋友可能会发现,在事务处理中,进行编辑操作以后,会执行缓存的清除操作,这是因为我们在ORM中所绑定的缓存自动清除操作,是在对应的执行方法中,而不是sql生成方法里,所以在进行事务时,如果你使用了缓存的方法,在这里就需要手动添加清除缓存操作,不然就会产生脏数据。

      执行结果:

     1 ------ini------
     2 select * from product_class  where id = 1
     3 ----产品分类实体----
     4 [{'add_time': datetime.datetime(2018, 8, 17, 16, 14, 54), 'id': 1, 'is_enable': 1, 'name': '饼干'}]
     5 -------------------
     6 update product_class set is_enable = 0  where id = 1 returning id , is_enable
     7 ----执行成功后的产品分类实体----
     8 [{'id': 1, 'is_enable': 0}]
     9 -------------------------------
    10 update product set is_enable = 0  where product_class_id=1 returning id , is_enable
    11 ----执行成功后的产品实体----
    12 [{'id': 2, 'is_enable': 0}, {'id': 7, 'is_enable': 0}, {'id': 14, 'is_enable': 0}, {'id': 15, 'is_enable': 0}]
    13 ---------------------------
    14 执行成功
    15 ------clear------

      本文对应的源码下载(一些接口进行了重构,有些还没有处理,所以源码可能直接运行不了,下一章节会讲到所有代码使用ORM模块重构内容)

     

    版权声明:本文原创发表于 博客园,作者为 AllEmpty 本文欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则视为侵权。

    python开发QQ群:669058475(本群已满)、733466321(可以加2群)    作者博客:http://www.cnblogs.com/EmptyFS/

     

  • 相关阅读:
    2018软工实践之团队选题报告
    2018软工实践作业五之结对作业2
    2018软工实践作业四之团队展示
    2018软工实践作业四之团队展示
    2018软工实践作业三
    职场老鸟项目经理多年感悟
    项目冲突管理
    项目变更如何控制
    项目管理基础
    成功项目管理与PMP认证2017
  • 原文地址:https://www.cnblogs.com/EmptyFS/p/9484686.html
Copyright © 2011-2022 走看看