zoukankan      html  css  js  c++  java
  • Python实现ORM

    ORM（对象关系映射）即把数据库中的一张数据表映射到代码里的一个类上，表的字段对应类的属性；再把增删改查等基本操作封装成该类的方法，从而写出更干净、层次更清晰的代码。

    以查询数据为例，原始的写法需要把 Python 代码和 SQL 语句混在一起，示例代码如下：

     1 import MySQLdb
     2 import os,sys
     3 
     4 def main():
     5     conn=MySQLdb.connect(host="localhost",port=3306,passwd='toor',user='root')
     6     conn.select_db("xdyweb")
     7     cursor=conn.cursor()
     8     count=cursor.execute("select * from users")
     9     result=cursor.fetchmany()
    10     print(isinstance(result,tuple))
    11     print(type(result))
    12     print(len(result))
    13     for i in result:
    14         print(i)
    15         for j in i:
    16             print(j)
    17     print("row count is %s"%count)
    18     cursor.close()
    19     conn.close()
    20 
    21 if __name__=="__main__":
    22     cp=os.path.abspath('.')
    23     sys.path.append(cp)
    24     main()
    View Code

    而我们现在想要实现的是类似下面这样的效果：

    1 #查找:
    2 u=user.get(id=1)
    3 #添加
    4 u=user(name='y',password='y',email='1@q.com')
    5 u.insert()
    View Code

    实现思路是：遍历 Model 的属性，得出要操作的字段，然后根据不同的操作（增、删、改、查）动态生成相应的 SQL 语句。

      1 #coding:utf-8
      2 
      3 #author:xudongyang
      4 
      5 #19:25 2015/4/15
      6 
      7 import  logging,time,sys,os,threading
      8 import test as db
      9 # logging.basicConfig(level=logging.INFO,format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',datefmt='%a, %d %b %Y %H:%M:%S')
     10 logging.basicConfig(level=logging.INFO)
     11 
     12 class Field(object):
     13     #映射数据表中一个字段的属性,包括字段名称,默认值,是否主键,可空,可更新,可插入,字段类型(varchar,text,Integer之类),字段顺序
     14     _count=0#当前定义的字段是类的第几个字段
     15     def __init__(self,**kw):
     16         self.name = kw.get('name', None)
     17         self._default = kw.get('default', None)
     18         self.primary_key = kw.get('primary_key', False)
     19         self.nullable = kw.get('nullable', False)
     20         self.updatable = kw.get('updatable', True)
     21         self.insertable = kw.get('insertable', True)
     22         self.ddl = kw.get('ddl', '')
     23         self._order = Field._count
     24         Field._count = Field._count + 1
     25     @property
     26     def default(self):
     27         d = self._default
     28         return d() if callable(d) else d
     29 
     30 class StringField(Field):
     31     #继承自Field,
     32     def __init__(self, **kw):
     33         if not 'default' in kw:
     34             kw['default'] = ''
     35         if not 'ddl' in kw:
     36             kw['ddl'] = 'varchar(255)'
     37         super(StringField, self).__init__(**kw)
     38 
     39 class IntegerField(Field):
     40 
     41     def __init__(self, **kw):
     42         if not 'default' in kw:
     43             kw['default'] = 0
     44         if not 'ddl' in kw:
     45             kw['ddl'] = 'bigint'
     46         super(IntegerField, self).__init__(**kw)
     47 class FloatField(Field):
     48 
     49     def __init__(self, **kw):
     50         if not 'default' in kw:
     51             kw['default'] = 0.0
     52         if not 'ddl' in kw:
     53             kw['ddl'] = 'real'
     54         super(FloatField, self).__init__(**kw)
     55 
     56 class BooleanField(Field):
     57 
     58     def __init__(self, **kw):
     59         if not 'default' in kw:
     60             kw['default'] = False
     61         if not 'ddl' in kw:
     62             kw['ddl'] = 'bool'
     63         super(BooleanField, self).__init__(**kw)
     64 
     65 class TextField(Field):
     66 
     67     def __init__(self, **kw):
     68         if not 'default' in kw:
     69             kw['default'] = ''
     70         if not 'ddl' in kw:
     71             kw['ddl'] = 'text'
     72         super(TextField, self).__init__(**kw)
     73 
     74 class BlobField(Field):
     75 
     76     def __init__(self, **kw):
     77         if not 'default' in kw:
     78             kw['default'] = ''
     79         if not 'ddl' in kw:
     80             kw['ddl'] = 'blob'
     81         super(BlobField, self).__init__(**kw)
     82 
     83 class VersionField(Field):
     84 
     85     def __init__(self, name=None):
     86         super(VersionField, self).__init__(name=name, default=0, ddl='bigint')
     87 
     88 def _gen_sql(table_name, mappings):
     89     print(__name__+'is called'+str(time.time()))
     90     pk = None
     91     sql = ['-- generating SQL for %s:' % table_name, 'create table `%s` (' % table_name]
     92     for f in sorted(mappings.values(), lambda x, y: cmp(x._order, y._order)):
     93         if not hasattr(f, 'ddl'):
     94             raise StandardError('no ddl in field "%s".' % n)
     95         ddl = f.ddl
     96         nullable = f.nullable
     97         if f.primary_key:
     98             pk = f.name
     99         sql.append(nullable and '  `%s` %s,' % (f.name, ddl) or '  `%s` %s not null,' % (f.name, ddl))
    100     sql.append('  primary key(`%s`)' % pk)
    101     sql.append(');')
    102     sql='
    '.join(sql)
    103     logging.info('sql is :'+sql)
    104     return sql
    105 
    106 class ModelMetaClass(type):
    107     #为什么__new__方法会被调用两次
    108     #为什么attrs.pop(k)要进行这个,而且进行了之后u.name就可以输出yy而不是一个Field对象
    109     def __new__(cls,name,base,attrs):
    110         logging.info("cls is:"+str(cls))
    111         logging.info("name is:"+str(name))
    112         logging.info("base is:"+str(base))
    113         logging.info("attrs is:"+str(attrs))
    114         print('new is called at '+str(cls)+str(time.time()))
    115 
    116         if name =="Model":
    117             return type.__new__(cls,name,base,attrs)
    118         mapping=dict()
    119         primary_key=None
    120         for k,v in attrs.iteritems():
    121             primary_key=None
    122             if isinstance(v,Field):
    123                 if not v.name:
    124                     v.name=k
    125                 mapping[k]=v
    126                 #检测是否是主键
    127                 if v.primary_key:
    128                     if primary_key:
    129                         raise TypeError("There only should be on primary_key")
    130                     if v.updatable:
    131                         logging.warning('primary_key should not be changed')
    132                         v.updatable=False
    133                     if v.nullable:
    134                         logging.warning('pri.. not be.null')
    135                         v.nullable=False
    136                     primary_key=v
    137 
    138         for k in mapping.iterkeys():
    139             attrs.pop(k)
    140 
    141         attrs['__mappings__']=mapping
    142         logging.info('mapping is :'+str(mapping))
    143         attrs['__primary_key__']=primary_key
    144         attrs['__sql__']=lambda self: _gen_sql(attrs['__table__'], mapping)
    145         return type.__new__(cls,name,base,attrs)
    146 class ModelMetaclass(type):
    147     '''
    148     Metaclass for model objects.
    149     '''
    150     def __new__(cls, name, bases, attrs):
    151         # skip base Model class:
    152         if name=='Model':
    153             return type.__new__(cls, name, bases, attrs)
    154 
    155         # store all subclasses info:
    156         if not hasattr(cls, 'subclasses'):
    157             cls.subclasses = {}
    158         if not name in cls.subclasses:
    159             cls.subclasses[name] = name
    160         else:
    161             logging.warning('Redefine class: %s' % name)
    162 
    163         logging.info('Scan ORMapping %s...' % name)
    164         mappings = dict()
    165         primary_key = None
    166         for k, v in attrs.iteritems():
    167             if isinstance(v, Field):
    168                 if not v.name:
    169                     v.name = k
    170                 logging.info('Found mapping: %s => %s' % (k, v))
    171                 # check duplicate primary key:
    172                 if v.primary_key:
    173                     if primary_key:
    174                         raise TypeError('Cannot define more than 1 primary key in class: %s' % name)
    175                     if v.updatable:
    176                         logging.warning('NOTE: change primary key to non-updatable.')
    177                         v.updatable = False
    178                     if v.nullable:
    179                         logging.warning('NOTE: change primary key to non-nullable.')
    180                         v.nullable = False
    181                     primary_key = v
    182                 mappings[k] = v
    183         # check exist of primary key:
    184         if not primary_key:
    185             raise TypeError('Primary key not defined in class: %s' % name)
    186         for k in mappings.iterkeys():
    187             attrs.pop(k)
    188         if not '__table__' in attrs:
    189             attrs['__table__'] = name.lower()
    190         attrs['__mappings__'] = mappings
    191         attrs['__primary_key__'] = primary_key
    192         attrs['__sql__'] = lambda self: _gen_sql(attrs['__table__'], mappings)
    193         # for trigger in _triggers:
    194         #     if not trigger in attrs:
    195         #         attrs[trigger] = None
    196         return type.__new__(cls, name, bases, attrs)
    197 class Model(dict):
    198     __metaclass__ = ModelMetaClass
    199     def __init__(self, **kw):
    200         super(Model, self).__init__(**kw)
    201 
    202     def __getattr__(self, key):
    203         try:
    204             return self[key]
    205         except KeyError:
    206             raise AttributeError(r"'Dict' object has no attribute '%s'" % key)
    207 
    208     def __setattr__(self, key, value):
    209         self[key] = value
    210 
    211     @classmethod
    212     def get(cls, pk):
    213         '''
    214         Get by primary key.
    215         '''
    216         d = db.select_one('select * from %s where %s=?' % (cls.__table__, cls.__primary_key__.name), pk)
    217         return cls(**d) if d else None
    218 
    219     @classmethod
    220     def find_first(cls, where, *args):
    221         '''
    222         Find by where clause and return one result. If multiple results found,
    223         only the first one returned. If no result found, return None.
    224         '''
    225         d = db.select_one('select * from %s %s' % (cls.__table__, where), *args)
    226         return cls(**d) if d else None
    227 
    228     @classmethod
    229     def find_all(cls, *args):
    230         '''
    231         Find all and return list.
    232         '''
    233         L = db.select('select * from `%s`' % cls.__table__)
    234         return [cls(**d) for d in L]
    235 
    236     @classmethod
    237     def find_by(cls, where, *args):
    238         '''
    239         Find by where clause and return list.
    240         '''
    241         L = db.select('select * from `%s` %s' % (cls.__table__, where), *args)
    242         return [cls(**d) for d in L]
    243 
    244     @classmethod
    245     def count_all(cls):
    246         '''
    247         Find by 'select count(pk) from table' and return integer.
    248         '''
    249         return db.select_int('select count(`%s`) from `%s`' % (cls.__primary_key__.name, cls.__table__))
    250 
    251     @classmethod
    252     def count_by(cls, where, *args):
    253         '''
    254         Find by 'select count(pk) from table where ... ' and return int.
    255         '''
    256         return db.select_int('select count(`%s`) from `%s` %s' % (cls.__primary_key__.name, cls.__table__, where), *args)
    257 
    258     def update(self):
    259         self.pre_update and self.pre_update()
    260         L = []
    261         args = []
    262         for k, v in self.__mappings__.iteritems():
    263             if v.updatable:
    264                 if hasattr(self, k):
    265                     arg = getattr(self, k)
    266                 else:
    267                     arg = v.default
    268                     setattr(self, k, arg)
    269                 L.append('`%s`=?' % k)
    270                 args.append(arg)
    271         pk = self.__primary_key__.name
    272         args.append(getattr(self, pk))
    273         db.update('update `%s` set %s where %s=?' % (self.__table__, ','.join(L), pk), *args)
    274         return self
    275 
    276     def delete(self):
    277         self.pre_delete and self.pre_delete()
    278         pk = self.__primary_key__.name
    279         args = (getattr(self, pk), )
    280         db.update('delete from `%s` where `%s`=?' % (self.__table__, pk), *args)
    281         return self
    282 
    283     def insert(self):
    284         self.pre_insert and self.pre_insert()
    285         params = {}
    286         for k, v in self.__mappings__.iteritems():
    287             if v.insertable:
    288                 if not hasattr(self, k):
    289                     setattr(self, k, v.default)
    290                 params[v.name] = getattr(self, k)
    291         db.insert('%s' % self.__table__, **params)
    292         return self
    293 class user(Model):
    294     name=StringField(name='name',primary_key=True)
    295     password=StringField(name='password')
    296 
    297 def main():
    298     u=user(name='yy',password='yyp')
    299 
    300     logging.info(u.__sql__)
    301     logging.info(dir(u.__mappings__.values()))
    302     u.password='xxx'
    303     print(u.password)
    304 
    305 if __name__ == '__main__':
    306     main()
    View Code

    需要注意的是遍历 Model 属性的这部分代码：它利用 Python 2 的 __metaclass__ 机制拦截了 Model 子类的创建过程，从而能在类创建时遍历其属性。具体实现见 ModelMetaClass 的 __new__ 方法（Model 实际使用的是 ModelMetaClass；代码中另一个 ModelMetaclass 是廖老师的原版，并未被使用）。

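    这是模仿廖雪峰老师的代码写的（http://liaoxuefeng.com），在此感谢。代码注释里还留有两个疑问，希望有看明白的人解惑。（补充解答：其一，__new__ 会被调用两次，是因为 Model 本身和它的子类 user 在创建时都会经过这个元类，name == "Model" 的判断就是为了放过基类；其二，Model.__getattr__ 只有在常规属性查找失败时才会被调用，如果不用 attrs.pop(k) 删掉类上的 Field 对象，u.name 会先在类上找到 Field 对象，删掉之后才会落到 __getattr__，从字典中取出 'yy'。）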
      

  • 相关阅读:
    一些tips
    微信小程序之后端处理
    微信小程序之前端代码篇
    微信小程序踩坑之前端问题处理篇
    Vue组件封装之一键复制文本到剪贴板
    读别人的代码之bug的发现
    解析webpack插件html-webpack-plugin
    数组去重方法整理
    如何理解EventLoop--浏览器篇
    axios和vue-axios的关系
  • 原文地址:https://www.cnblogs.com/cncyber/p/4433301.html
Copyright © 2011-2022 走看看