zoukankan      html  css  js  c++  java
  • 手写ORM

    一 前言

    1 我在实例化一个user对象的时候,可以user=User(name='lqz',password='123')

    2 也可以 user=User()

        user['name']='lqz'
        user['password']='123'
    3 也可以 user=User()

        user.name='lqz'
        user.password='password'
    前两种,可以通过继承字典dict来实现,第三种,用getattr和setattr

        __getattr__ 拦截点号运算。当对未定义的属性名称和实例进行点号运算时,就会用属性名作为字符串调用这个方法。如果继承树可以找到该属性,则不调用此方法
        __setattr__会拦截所有属性的的赋值语句。如果定义了这个方法,self.arrt = value 就会变成self,__setattr__("attr", value).这个需要注意。当在__setattr__方法内对属性进行赋值是,不可使用self.attr = value,因为他会再次调用self,__setattr__("attr", value),则会形成无穷递归循环,最后导致堆栈溢出异常。应该通过对属性字典做索引运算来赋值任何实例属性,也就是使用self.__dict__['name'] = value

    二 定义Model基类

    复制代码
    class Model(dict):
        def __init__(self, **kw):
            super(Model, self).__init__(**kw)
    
        def __getattr__(self, key):# .访问属性触发
            try:
                return self[key]
            except KeyError:
                raise AttributeError('没有属性:%s' % key)
    
        def __setattr__(self, key, value):
            self[key] = value
    复制代码

    三 定义Field

    数据库中每一列数据,都有:列名,列的数据类型,是否是主键,默认值

    复制代码
    class Field:
        def __init__(self, name, column_type, primary_key, default_value):
            self.name = name
            self.column_type = column_type
            self.primary_key = primary_key
            self.default_value = default_value
    
    
    class StringField(Field):
        def __init__(self, name, column_type='varchar(100)', primary_key=False, default_value=None):
            super().__init__(name, column_type, primary_key, default_value)
    
    
    class IntegerField(Field):
        def __init__(self, name, primary_key=False, default_value=0):
            super().__init__(name, 'int', primary_key, default_value)
    复制代码

    四 定义元类

    元类复习

    数据库中的每个表,都有表名,每一列的列名,以及主键是哪一列

    既然我要用数据库中的表,对应这一个程序中的类,那么我这个类也应该有这些类属性

    但是不同的类这些类属性又不尽相同,所以我应该怎么做?在元类里拦截类的创建过程,然后把这些东西取出来,放到类里面

    所以用到了元类

    复制代码
    class ModelMetaclass(type):
        def __new__(cls, name, bases,attrs):
            if name=='Model':
                return type.__new__(cls,name,bases,attrs)
            table_name=attrs.get('table_name',None)
            if not table:
                table_name=name
            primary_key=None
            mappings=dict()
            for k,v in attrs.items():
                if isinstance(v,Field):#v 是不是Field的对象
                    mappings[k]=v
                    if v.primary_key:
                        #找到主键
                        if primary_key:
                            raise TypeError('主键重复:%s'%k)
                        primary_key=k
    
            for k in mappings.keys():
                attrs.pop(k)
            if not primary_key:
                raise TypeError('没有主键')
            attrs['table_name']=table_name
            attrs['primary_key']=primary_key
            attrs['mappings']=mappingsreturn type.__new__(cls,name,bases,attrs)
    复制代码

    五 继续Model基类

    Model类是所有要对应数据库表类的基类,所以,Model的元类应该是咱上面写的那个,而每个数据库表对应类的对象,都应该有查询,插入,保存,方法

    所以:

    复制代码
    class Model(dict, metaclass=ModelMetaclass):
        def __init__(self, **kw):
            super(Model, self).__init__(**kw)
    
        def __getattr__(self, key):  # .访问属性触发
            try:
                return self[key]
            except KeyError:
                raise AttributeError('没有属性:%s' % key)
    
        def __setattr__(self, key, value):
            self[key] = value
    
        @classmethod
        def select_all(cls, **kwargs):
            ms = mysql_singleton.Mysql().singleton()
            if kwargs:  # 当有参数传入的时候
                key = list(kwargs.keys())[0]
                value = kwargs[key]
                sql = "select * from %s where %s=?" % (cls.table_name, key)
                sql = sql.replace('?', '%s')
                re = ms.select(sql, value)
            else:  # 当无参传入的时候查询所有
                sql = "select * from %s" % cls.table_name
                re = ms.select(sql)
            return [cls(**r) for r in re]
    
        @classmethod
        def select_one(cls, **kwargs):
            # 此处只支持单一条件查询
            key = list(kwargs.keys())[0]
            value = kwargs[key]
            ms = mysql_singleton.Mysql().singleton()
            sql = "select * from %s where %s=?" % (cls.table_name, key)
    
            sql = sql.replace('?', '%s')
            re = ms.select(sql, value)
            if re:
                return cls(**re[0])
            else:
                return None
    
        def save(self):
            ms = mysql_singleton.Mysql().singleton()
            fields = []
            params = []
            args = []
            for k, v in self.mapping.items():
                fields.append(v.name)
                params.append('?')
                args.append(getattr(self, k, v.default))
            sql = "insert into %s (%s) values (%s)" % (self.table_name, ','.join(fields), ','.join(params))
            sql = sql.replace('?', '%s')
            ms.execute(sql, args)
    
        def update(self):
            ms = mysql_singleton.Mysql().singleton()
            fields = []
            args = []
            pr = None
            for k, v in self.mapping.items():
                if v.primary_key:
                    pr = getattr(self, k, v.default)
                else:
                    fields.append(v.name + '=?')
                    args.append(getattr(self, k, v.default))
            sql = "update %s set %s where %s = %s" % (
                self.table_name, ', '.join(fields), self.primary_key, pr)
    
            sql = sql.replace('?', '%s')
            print(sql)
            ms.execute(sql, args)
    复制代码

    六 基于pymsql的数据库操作类(单例)

    复制代码
    from conf import setting
    import pymysql
    
    
    class Mysql:
        __instance = None
        def __init__(self):
            self.conn = pymysql.connect(host=setting.host,
                                        user=setting.user,
                                        password=setting.password,
                                        database=setting.database,
                                        charset=setting.charset,
                                        autocommit=setting.autocommit)
            self.cursor = self.conn.cursor(cursor=pymysql.cursors.DictCursor)
    
        def close_db(self):
            self.conn.close()
    
        def select(self, sql, args=None):
            self.cursor.execute(sql, args)
            rs = self.cursor.fetchall()
            return rs
    
        def execute(self, sql, args):
            try:
                self.cursor.execute(sql, args)
                affected = self.cursor.rowcount
                # self.conn.commit()
            except BaseException as e:
                print(e)
            return affected
    
        @classmethod
        def singleton(cls):
            if not cls.__instance:
                cls.__instance = cls()
            return cls.__instance
    复制代码

    七 数据库连接池版的数据库操作类

    在此之前,要先学习数据库链接池:链接

    db_pool.py

    复制代码
    import pymysql
    from conf import setting
    from DBUtils.PooledDB import PooledDB
    
    POOL = PooledDB(
        creator=pymysql,  # 使用链接数据库的模块
        maxconnections=6,  # 连接池允许的最大连接数,0和None表示不限制连接数
        mincached=6,  # 初始化时,链接池中至少创建的空闲的链接,0表示不创建
        maxcached=5,  # 链接池中最多闲置的链接,0和None不限制
        maxshared=3,
        # 链接池中最多共享的链接数量,0和None表示全部共享。
        blocking=True,  # 连接池中如果没有可用连接后,是否阻塞等待。True,等待;False,不等待然后报错
        maxusage=None,  # 一个链接最多被重复使用的次数,None表示无限制
        setsession=[],  # 开始会话前执行的命令列表。
        ping=0,
        # ping MySQL服务端,检查是否服务可用。
    
        host=setting.host,
        port=setting.port,
        user=setting.user,
        password=setting.password,
        database=setting.database,
        charset=setting.charset,
        autocommit=setting.autocommit
    )
    复制代码

     mysql_pool.py

    复制代码
    import pymysql
    from ormpool import db_pool
    from threading import current_thread
    
    
    class MysqlPool:
        def __init__(self):
            self.conn = db_pool.POOL.connection()
            # print(db_pool.POOL)
            # print(current_thread().getName(), '拿到连接', self.conn)
            # print(current_thread().getName(), '池子里目前有', db_pool.POOL._idle_cache, '
    ')
            self.cursor = self.conn.cursor(cursor=pymysql.cursors.DictCursor)
    
        def close_db(self):
            self.cursor.close()
            self.conn.close()
    
        def select(self, sql, args=None):
            self.cursor.execute(sql, args)
            rs = self.cursor.fetchall()
            return rs
    
        def execute(self, sql, args):
    
            try:
                self.cursor.execute(sql, args)
                affected = self.cursor.rowcount
                # self.conn.commit()
            except BaseException as e:
                print(e)
            finally:
                self.close_db()
            return affected
    复制代码

    setting.py

    复制代码
    host = '127.0.0.1'
    port = 3306
    user = 'root'
    password = '123456'
    database = 'youku2'
    charset = 'utf8'
    autocommit = True
    复制代码
  • 相关阅读:
    微信小程序HTTPS
    微信商城-1简介
    va_list
    Event log c++ sample.
    EVENT LOGGING
    Analyze Program Runtime Stack
    unknow table alarmtemp error when drop database (mysql)
    This application has request the Runtime to terminate it in an unusual way.
    How to check if Visual Studio 2005 SP1 is installed
    SetUnhandledExceptionFilter
  • 原文地址:https://www.cnblogs.com/bubu99/p/14774467.html
Copyright © 2011-2022 走看看