zoukankan      html  css  js  c++  java
  • 15、编写ORM

    概述

    在一个Web App中,所有数据,包括用户信息、发布的日志、评论等,都存储在数据库中。

    Web App里面有很多地方都要访问数据库。访问数据库需要创建数据库连接、游标对象,然后执行SQL语句,最后处理异常,清理资源。这些访问数据库的代码如果分散到各个函数中,势必无法维护,也不利于代码复用。

    所以,我们要首先把常用的SELECT、INSERT、UPDATE和DELETE操作用函数封装起来。

    由于Web框架使用了基于asyncio的aiohttp,这是基于协程的异步模型。在协程中,不能调用普通的同步IO操作,因为所有用户都是由一个线程服务的,协程的执行速度必须非常快,才能处理大量用户的请求。而耗时的IO操作不能在协程中以同步的方式调用,否则,等待一个IO操作时,系统无法响应任何其他用户。

    这就是异步编程的一个原则:一旦决定使用异步,则系统每一层都必须是异步,“开弓没有回头箭”。

    幸运的是aiomysql为MySQL数据库提供了异步IO的驱动。

     简单的orm实现技术原理可参考先前写的博文:12、元类(metaclass)实现精简ORM框架

    一、创建连接池

    我们需要创建一个全局的连接池,每个HTTP请求都可以从连接池中直接获取数据库连接。使用连接池的好处是不必频繁地打开和关闭数据库连接,而是能复用就尽量复用。

    连接池由全局变量__pool存储,缺省情况下将编码设置为utf8,自动提交事务:

    @asyncio.coroutine
    def create_pool(loop, **kwargs):
        logging.info('create database connection pool...')
        global __pool 
        __pool = yield from aiomysql.create_pool(
            host=kwargs.get('host', 'localhost'),
            port=kwargs.get('port', 3306),
            user=kwargs['user'],
            password=kwargs['password'],
            db=kwargs['db'],
            charset=kwargs.get('charset', 'utf8'),  
            autocommit=kwargs.get('autocommit', True), 
            maxsize=kwargs.get('maxsize', 10),
            minsize=kwargs.get('minsize', 1),
            loop=loop
        )

     关于aiomysql.create_pool的详细讲述,请参考博文:16、【翻译】aiomysql-Pool

     create_pool方法中的kwargs是关键字参数,保存着连接数据库所必须的host、port、user、password等信息,这些关键字参数在函数内部自动组装为一个dict。

    二、封装select语句

    # 该协程封装的是查询事务,第一个参数为sql语句,第二个为sql语句中占位符的参数列表,第三个参数是要查询数据的数量
    @asyncio.coroutine
    def select(sql, args, size=None):
        log(sql, args)  #显示sql语句和参数
        global __pool   #引用全局变量
        with (yield from __pool) as conn:   # 以上下文方式打开conn连接,无需再调用conn.close()  或写成 with await __pool as conn:
            cur = yield from conn.cursor(aiomysql.DictCursor)   # 创建一个DictCursor类指针,返回dict形式的结果集
            yield from cur.execute(sql.replace('?', '%s'), args or ())  # 替换占位符,SQL语句占位符为?,MySQL为%s。
            if size:
                rs = yield from cur.fetchmany(size) #接收size条返回结果行.
            else:
                rs = yield from cur.fetchall()  #接收全部的返回结果行.
            yield from cur.close()  #关闭游标
            logging.info('rows returned: %s' % len(rs)) #打印返回结果行数
            return rs   #返回结果

    SQL语句的占位符是?,而MySQL的占位符是%sselect()函数在内部自动替换。注意要始终坚持使用带参数的SQL,而不是自己拼接SQL字符串,这样可以防止SQL注入攻击。

    注意到yield from将调用一个子协程(也就是在一个协程中调用另一个协程)并直接获得子协程的返回结果。

    如果传入size参数,就通过fetchmany()获取最多指定数量的记录,否则,通过fetchall()获取所有记录。

    三、封装INSERT、UPDATE、DELETE语句

    #执行update,insert,delete语句,可以统一用一个execute函数执行,
    # 因为它们所需参数都一样,而且都只返回一个整数表示影响的行数。
    @asyncio.coroutine
    def execute(sql, args, autocommit=True):
        log(sql)
        with (yield from __pool) as conn:
            if not autocommit:
                yield from conn.begin()
            try:
                 cur = yield from conn.cursor()
                 yield from cur.execute(sql.replace('?', '%s'), args)
                 affected = cur.rowcount
                 yield from cur.close()
                 if not autocommit:
                     yield from conn.commit()
            except BaseException as e:  #如果事务处理出现错误,则回退
                if not autocommit:
                    yield from conn.rollback()
                raise
            return affected

    execute()函数和select()函数所不同的是,cursor对象不返回结果集,而是通过rowcount返回结果数。

    四、ORM

    设计ORM需要从上层调用者角度来设计。

    我们先考虑如何定义一个User对象,然后把数据库表users和它关联起来。

    from orm import Model, StringField, IntegerField
    
    class User(Model):
        __table__ = 'users'
    
        id = IntegerField(primary_key=True)
        name = StringField()

    注意到定义在User类中的__table__idname是类的属性,不是实例的属性,类的所有示例都可以访问!!!所以,在类级别上定义的属性用来描述User对象和表的映射关系,而实例属性用来描述数据库表中的一行数据,必须通过__init__()方法去初始化,所以两者互不干扰:

    # 创建实例:
    user = User(id=123, name='Michael')
    # 存入数据库:
    user.insert()
    # 查询所有User对象:
    users = User.findAll()

    五、Field以及各种Field子类

    用来描述数据库中表字段的属性(字段名、类型、是否主键等等)。

    首先定义基类Field:

    class Field(object):
    
        def __init__(self, name, column_type, primary_key, default):
            self.name = name
            self.column_type = column_type
            self.primary_key = primary_key
            self.default = default
    
        def __str__(self):
            return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)

    __str__()是Python中有特殊用途的函数,用来定制类。当我们print(Field或Field子类对象)时,会打印该对象(字段)的类名,字段类别以及字段名称。

     然后在Field的基础上,进一步定义各种类型的Field:

    # 字符串类型字段,继承自父类Field
    class StringField(Field):
        #如果一个函数的参数中含有默认参数,则这个默认参数后的所有参数都必须是默认参数 ,
        # 否则会抛出:SyntaxError: non-default argument follows default argument的异常。
        def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
            super(StringField, self).__init__(name, ddl, primary_key, default)
    
    # 布尔值类型字段,继承自父类Field
    class BooleanField(Field):
    
        def __init__(self, name=None, default=False):
            super(BooleanField, self).__init__(name, 'boolean', False, default)
    
    # 整数类型字段,继承自父类Field
    class IntegerField(Field):
    
        def __init__(self, name=None, primary_key=False, default=0):
            super(IntegerField, self).__init__(name, 'bigint', primary_key, default)
    
    # 浮点数类型字段,继承自父类Field
    class FloatField(Field):
    
        def __init__(self, name=None, primary_key=False, default=0.0):
            super(FloatField, self).__init__(name, 'real', primary_key, default)
    
    # 文本类型字段,继承自父类Field
    class TextField(Field):
    
        def __init__(self, name=None, default=None):
            super(TextField, self).__init__(name, 'text', False, default)

    上述子类生成对象时,均会调用父类的Init方法初始化。

    可见,数据库表字段共4个属性:字段名、字段类型、是否主键、默认值。

    六、编写元类—ModelMetaclass

     1 class ModelMetaclass(type):
     2 
     3     def __new__(cls, name, bases, attrs):
     4         # 排除Model类本身:
     5         if name=='Model':
     6             return type.__new__(cls, name, bases, attrs)
     7         # 获取table名称:
     8         tableName = attrs.get('__table__', None) or name
     9         logging.info('found model: %s (table: %s)' % (name, tableName))
    10         # 获取所有的Field和主键名:
    11         mappings = dict()
    12         fields = []
    13         primaryKey = None
    14         for k, v in attrs.items():
    15             if isinstance(v, Field):
    16                 logging.info('  found mapping: %s ==> %s' % (k, v))
    17                 mappings[k] = v
    18                 if v.primary_key:
    19                     # 找到主键:
    20                     if primaryKey:
    21                         raise RuntimeError('Duplicate primary key for field: %s' % k)
    22                     primaryKey = k
    23                 else:
    24                     fields.append(k)
    25         if not primaryKey:
    26             raise RuntimeError('Primary')
    27         for k in mappings.keys():
    28             attrs.pop(k)
    29         escaped_fields = list(map(lambda f: '`%s`' % f, fields))
    30         attrs['__mappings__'] = mappings    # 保存属性和列的映射关系
    31         attrs['__table__'] = tableName
    32         attrs['__primary_key__'] = primaryKey   # 主键属性名
    33         attrs['__fields__'] = fields    # 除主键外的属性名
    34         # 构造默认的SELECT, INSERT, UPDATE和DELETE语句:
    35         ##以下四种方法保存了默认了增删改查操作,其中添加的反引号``,是为了避免与sql关键字冲突的,否则sql语句会执行出错
    36         attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)
    37         attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
    38         attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
    39         attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
    40         return type.__new__(cls, name, bases, attrs)

    1、首先进行判断,如果将要创建的类是Model,无需做个性化定制,直接通过type创建,排除对Model类的修改;

    2、获取table名称,即类名;

    3、mappings保存类属性和表字段的映射关系primaryKey保存映射表主键的类属性;fields保存映射其余表字段的类属性;

    4、注意匿名函数【 lambda f: '`%s`' % f, fields 】 的用法,实际上第29行代码就是这个意思: 

    fields = ['one', 'two', 'three']
    
    def fun(f):
        return '`%s`' % f
    
    escaped_fields = list(map(fun, fields))

     使用匿名函数lambda,针对fields中每个元素,如 name,加上反引号后:`name`后返回;

    为何要加上反引号?它是为了区分MYSQL的保留字与普通字符而引入的符号。

     5、下面是一系列为定制类动态添加的属性:

    (1) attrs['__mappings__'] = mappings    -》 保存类属性和表字段的映射关系;

    (2)attrs['__table__'] = tableName    -》 保存该类对应的表名;

    (3)attrs['__primary_key__'] = primaryKey    -》 保存映射表中主键字段的类属性;

    (4)attrs['__fields__'] = fields    -》  保存映射非主键字段的类属性;

    接着是SQL语句模板,届时调用时只需要将参数传递给Mysql占位符?即可:

    (5)attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)

      例:

    (6)attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))

      例:

    (7)attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)

      例:

    (8)attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)

       例:

    6、在模块加载时使用type动态地定制类。

    7、注意:

    (1)以上属性都是类的属性,属类所有,所有实例对象共享一个类属性。实例属性属于各个实例所有,互不干扰;在编写程序的时候,千万不要对实例属性和类属性使用相同的名字,因为相同名称的实例属性将屏蔽掉类属性,但是当你删除实例属性后,再使用相同的名称,访问到的将是类属性。

    (2)表的字段名使用类属性名,即名字相同!

    七、编写基类——Model

     1 class Model(dict, metaclass=ModelMetaclass):
     2 
     3     def __init__(self, **kwargs):
     4         super(Model, self).__init__(**kwargs)
     5 
     6     def __getattr__(self, key):
     7         try:
     8             return self[key]
     9         except KeyError:
    10             raise AttributeError(r"'Model' object has no attribute '%s'" % key)
    11 
    12     def __setattr__(self, key, value):
    13         self[key] = value
    14 
    15     def getValue(self, key):
    16         return getattr(self, key, None)
    17 
    18     def getValueOrDefault(self, key):
    19         value = getattr(self, key, None)
    20         if value is None:
    21             field = self.__mappings__[key]
    22             if field.default is not None:
    23                 value = field.default() if callable(field.default) else field.default
    24                 logging.debug('using default value for %s: %s' % (key, str(value)))
    25                 setattr(self, key, value)
    26         return value
    27 
    28     @classmethod
    29     @asyncio.coroutine
    30     def findAll(cls, where=None, args=None, **kwargs):
    31         'find objects by where clause'
    32         sql = [cls.__select__]  #sql是list类型,元素是定制类的类属性——select查询语句模板
    33         if where:
    34             sql.append('where')
    35             sql.append(where)
    36         if args is None:
    37             args = []
    38         orderBy = kwargs.get('orderBy', None)
    39         if orderBy:
    40             sql.append('order by')
    41             sql.append(orderBy)
    42         limit = kwargs.get('limit', None)
    43         if limit is not None:
    44             sql.append('limit')
    45             if isinstance(limit, int):
    46                 sql.append('?')
    47                 args.append(limit)
    48             elif isinstance(limit, tuple) and len(limit) == 2:
    49                 sql.append('?, ?')
    50                 args.extend(limit)
    51             else:
    52                 raise ValueError('Invalid limit value: %s' % str(limit))
    53         rs = yield from select(' '.join(sql), args)  #传入sql语句及参数,调用select语句获取查询结果
    54         return [cls(**r) for r in rs]
    55 
    56     @classmethod
    57     @asyncio.coroutine
    58     def findNumber(cls, selectField, where=None, args=None):
    59         'find number by select and where.'
    60         sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]
    61         if where:
    62             sql.append('where')
    63             sql.append(where)
    64             rs = yield from select(' '.join(sql), args, 1)
    65             if len(rs) == 0:
    66                 return None
    67             return rs[0]['_num_']
    68 
    69     @classmethod
    70     @asyncio.coroutine
    71     def find(cls, pk):
    72         'find object by primary key.'
    73         rs = yield from select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)
    74         if len(rs) == 0:
    75             return None
    76         return cls(**rs[0])
    77 
    78     @asyncio.coroutine
    79     def save(self):
    80         args = list(map(self.getValueOrDefault, self.__fields__))
    81         args.append(self.getValueOrDefault(self.__primary_key__))
    82         rows = yield from execute(self.__insert__, args)
    83         if rows != 1:
    84             logging.warn('failed to insert record: affected rows: %s' % rows)
    85 
    86     @asyncio.coroutine
    87     def update(self):
    88         args = list(map(self.getValue, self.__fields__))
    89         args.append(self.getValue(self.__primary_key__))
    90         rows = yield from execute(self.__update__, args)
    91         if rows != 1:
    92             logging.warn('failed to update by primary key: affected rows: %s' % rows)
    93 
    94     @asyncio.coroutine
    95     def remove(self):
    96         args = [self.getValue(self.__primary_key__)]
    97         rows = yield from execute(self.__delete__, args)
    98         if rows != 1:
    99             logging.warn('failed to remove by primary key: affected rows: %s' % rows)

    1、__getattr__为内置方法,当使用点号获取实例属性,例如 stu.score 时,如果属性score不存在就自动调用__getattr__方法。注意:已有的属性,比如name,不会在__getattr__中查找;

    2、__setattr__当设置实例属性时自动调用,如 stu.score=5时,就会调用__setattr__方法  self.[score]=5;

    3、getValueOrDefault()  ->   获取属性值,如果为空,则取默认值;

    4、@classmethod装饰的方法是类方法,直接使用类名调用,所有子类都可以调用类方法。不需要实例化,不需要 self 参数,第一个参数是表示自身类的 cls 参数。

    5、分析findAll() 方法:

    (1)第53行语句: rs = yield from select(' '.join(sql), args),调试可见返回结果是list类型,元素是dict类型的每行表数据:

    (2)第54行语句: return [cls(**r) for r in rs],不太能理解,故编写语句 result = User.findAll() 来将返回值保存在result参数中,调试可得:

    由此可得出结论,[cls(**r) for r in rs] 是将查询数据库表得到的每行结果,生成cls类的对象。

    6、分析save() 方法

    我们编写下列语句调用save方法:

    @asyncio.coroutine
        def save(self):
            args = list(map(self.getValueOrDefault, self.__fields__))
            args.append(self.getValueOrDefault(self.__primary_key__))
            rows = yield from execute(self.__insert__, args)
            if rows != 1:
                logging.warn('failed to insert record: affected rows: %s' % rows)

    经分析可得,获取调用save函数的类属性__friends__及__primary_key__的值,有默认值就传入默认值,作为sql语句参数执行。

     八、定义映射数据库表的类

    def next_id():
        return '%015d%s000' % (int(time.time() * 1000), uuid.uuid4().hex)
    
    class User(Model):
        __table__ = 'users'
        id = StringField(primary_key=True, default=next_id, ddl='varchar(50)')
        email = StringField(ddl='varchar(50)')
        passwd = StringField(ddl='varchar(50)')
        admin = BooleanField()
        name = StringField(ddl='varchar(50)')
        image = StringField(ddl='varchar(500)')
        created_at = FloatField(default=time.time)

    九、编写测试代码

    import orm
    from models import User, Blog, Comment
    import asyncio
    
    loop = asyncio.get_event_loop()
    
    async def test():
        # 创建连接池,里面的host,port,user,password需要替换为自己数据库的信息
        await orm.create_pool(loop=loop, host='127.0.0.1', port=3306, user='root', password='root', db='awesome')
        # 没有设置默认值的一个都不能少
        u = User(name='Test', email='547280745@qq.com', passwd='1234567890', image='about:blank', id="123")
        await u.save()
        result = await User.findAll()
    
    loop.run_until_complete(test())

    在Mysql数据中查询结果可知导入数据成功:

  • 相关阅读:
    Leetcode 16.25 LRU缓存 哈希表与双向链表的组合
    Leetcode437 路径总和 III 双递归与前缀和
    leetcode 0404 二叉树检查平衡性 DFS
    Leetcode 1219 黄金矿工 暴力回溯
    Leetcode1218 最长定差子序列 哈希表优化DP
    Leetcode 91 解码方法
    Leetcode 129 求根到叶子节点数字之和 DFS优化
    Leetcode 125 验证回文串 双指针
    Docker安装Mysql记录
    vmware虚拟机---Liunx配置静态IP
  • 原文地址:https://www.cnblogs.com/zwb8848happy/p/8799044.html
Copyright © 2011-2022 走看看