  A Simple ORM Based on Torndb

    ============================================================================

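    Original work; reposting is permitted.

    When reposting, you must credit the original source with a hyperlink and include this notice.

    Please note that it is reposted from: http://yunjianfei.iteye.com/blog/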

    ============================================================================

      


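    Recently I have been using tornado to write a REST-based web service. It only provides back-end services; other web server applications access it through URLs in a RESTful way.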

     

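    When developing web applications, it is hard not to think of ORM frameworks, such as hibernate and ibatis, which are common in Java EE, or SQLAlchemy in Python.

    Using an ORM speeds up development to a certain degree.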

     

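    A simple ORM framework only needs to implement the following:

         1. Insert: map a class object to a database record

         2. Query: map a database record to a class object

         3. Update and delete: handled by writing the SQL statements yourself.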

     

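    Python has classes, but it also has the dict type. Wrapping dicts in yet another class layer would be over-engineering and actually quite inflexible. So, boiled down, an ORM framework for Python only needs to do the following:

         1. Insert: map a Python dict to a database record

         2. Query: map database records to Python dicts and lists

         3. Update and delete: handled by writing the SQL statements yourself.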
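
    The three points above can be sketched with Python's built-in sqlite3 module standing in for MySQL. insert_dict and query_dicts are hypothetical helper names used only for illustration, and the table name is assumed to come from code rather than from user input:

```python
import sqlite3

def insert_dict(conn, table, row):
    """Insert: one dict becomes one record (keys are column names)."""
    columns = sorted(row)  # one fixed order for both columns and values
    sql = "INSERT INTO %s (%s) VALUES (%s)" % (
        table, ", ".join(columns), ", ".join(["?"] * len(columns)))
    return conn.execute(sql, [row[c] for c in columns]).lastrowid

def query_dicts(conn, sql, params=()):
    """Query: each record becomes a dict; the result set becomes a list."""
    cur = conn.execute(sql, params)
    names = [d[0] for d in cur.description]
    return [dict(zip(names, r)) for r in cur.fetchall()]

# An in-memory SQLite database stands in for MySQL in this sketch.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE host (hostname TEXT, ip TEXT)")

new_id = insert_dict(conn, "host", {"hostname": "test1", "ip": "10.22.10.90"})
# Update/delete: plain hand-written SQL, no mapping layer needed.
conn.execute("UPDATE host SET ip = ? WHERE hostname = ?", ("10.22.10.91", "test1"))
rows = query_dicts(conn, "SELECT hostname, ip FROM host")
print(rows)
```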

        

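    After some testing and evaluation of the options, I finally settled on torndb. It is very lightweight: the objects returned by a query map directly to Python data types such as dict and list, and the data can be read with an "object.attribute" syntax like in Java.

    This is really nice! First, let's look at a small example.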


    class Row(dict):
        """A dict that allows for object-like property access syntax."""
        def __getattr__(self, name):
            try:
                return self[name]
            except KeyError:
                raise AttributeError(name)
    
    # Store the values as dict keys, the way torndb fills a Row per column
    dic = Row()
    dic['name'] = 'hello'
    dic['num'] = '12334'
    print(type(dic))
    # Read them back as attributes; __getattr__ looks them up in the dict
    print("dic.name: " + dic.name)
    print("dic.num: " + dic.num)

    Output:

     

    <class '__main__.Row'>
    dic.name: hello
    dic.num: 12334
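
    One subtlety: Row only defines __getattr__, so attributes are read from the dict but not written to it. An assignment like dic.name = 'hello' creates an ordinary instance attribute that never becomes a dict key, whereas dic['name'] = 'hello' does. A minimal sketch contrasting torndb's Row with a hypothetical AttrDict (not part of torndb) that also sends attribute writes into the dict:

```python
class Row(dict):
    """torndb's Row: attribute reads fall back to dict keys."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

class AttrDict(Row):
    """Hypothetical variant: attribute writes and deletes also hit the dict."""
    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name)

r = Row()
r.name = "hello"       # stored as an instance attribute, NOT as a dict key
print("name" in r)     # False

a = AttrDict()
a.name = "hello"       # stored as the dict key 'name'
print("name" in a)     # True
print(a.name)          # hello
```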

     

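    This example shows that a Python dict can be read with Java-style "object.attribute" syntax.

    torndb uses exactly this approach, so each column of a query result can be read directly as ".column_name".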
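
    Inside torndb, query() builds each Row by pairing the column names from cursor.description with the values of one result row. A minimal sketch of that mechanism, assuming the standard sqlite3 module as a stand-in for MySQLdb (so the placeholder is ? rather than %s):

```python
import sqlite3

class Row(dict):
    """Same idea as torndb's Row: read columns as attributes."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

def query(conn, sql, *params):
    """Mimics torndb's Connection.query: zip column names with each row."""
    cursor = conn.cursor()
    try:
        cursor.execute(sql, params)
        column_names = [d[0] for d in cursor.description]
        return [Row(zip(column_names, row)) for row in cursor]
    finally:
        cursor.close()

conn = sqlite3.connect(":memory:")  # SQLite stands in for MySQL
conn.execute("CREATE TABLE Host (host_id INTEGER PRIMARY KEY, hostname TEXT)")
conn.execute("INSERT INTO Host (hostname) VALUES ('test1')")

hosts = query(conn, "SELECT host_id, hostname FROM Host WHERE hostname = ?", "test1")
for host in hosts:
    print(host.host_id)   # 1
    print(host.hostname)  # test1
```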

     

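    Queries return dicts or lists directly, but what about inserts? It would be very convenient if, as in Java, we could pass in an object and have the ORM framework turn it into a SQL statement via reflection.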

     

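    The answer is again the dict. When inserting, we store the data as a dict and generate the INSERT statement from it. After going through various references, I distilled the following method (to use it, simply add it to the Connection class in torndb.py):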

        def insert_by_dict(self, tablename, rowdict, replace=False):
            """INSERT (or REPLACE) one row whose column values come from a dict.

            tablename is interpolated into the SQL text, so it must come
            from code, never from user input. Returns the lastrowid.
            """
            cursor = self._cursor()
            try:
                # Ask MySQL for the table's actual column names
                cursor.execute("describe %s" % tablename)
                allowed_keys = set(row[0] for row in cursor.fetchall())
                # A sorted list fixes one order for both columns and values
                keys = sorted(allowed_keys.intersection(rowdict))

                unknown_keys = set(rowdict) - allowed_keys
                if unknown_keys:
                    logging.error("skipping keys: %s", ", ".join(unknown_keys))

                columns = ", ".join("`%s`" % key for key in keys)
                values_template = ", ".join(["%s"] * len(keys))
                verb = "REPLACE" if replace else "INSERT"
                sql = "%s INTO %s (%s) VALUES (%s)" % (
                    verb, tablename, columns, values_template)

                # Values are passed separately so MySQLdb escapes them
                values = tuple(rowdict[key] for key in keys)
                cursor.execute(sql, values)
                return cursor.lastrowid
            finally:
                cursor.close()
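
    The method depends on MySQL's DESCRIBE, so it needs a running server to try out. As a rough analogue only (not torndb code), here is the same filter-then-insert logic against Python's sqlite3. It assumes PRAGMA table_info in place of DESCRIBE and ? in place of %s:

```python
import logging
import sqlite3

def insert_by_dict(conn, tablename, rowdict, replace=False):
    """SQLite analogue of insert_by_dict; tablename must be trusted."""
    cursor = conn.cursor()
    try:
        # Column 1 of each PRAGMA table_info row is the column name.
        cursor.execute("PRAGMA table_info(%s)" % tablename)
        allowed_keys = set(row[1] for row in cursor.fetchall())
        keys = sorted(allowed_keys.intersection(rowdict))
        unknown_keys = set(rowdict) - allowed_keys
        if unknown_keys:
            logging.error("skipping keys: %s", ", ".join(sorted(unknown_keys)))
        verb = "REPLACE" if replace else "INSERT"
        sql = "%s INTO %s (%s) VALUES (%s)" % (
            verb, tablename, ", ".join(keys), ", ".join(["?"] * len(keys)))
        cursor.execute(sql, tuple(rowdict[k] for k in keys))
        return cursor.lastrowid
    finally:
        cursor.close()

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Host (host_id INTEGER PRIMARY KEY,"
             " hostname TEXT UNIQUE, ip TEXT)")
# 'owner' is not a column, so it is logged and skipped.
first = insert_by_dict(conn, "Host",
                       {"hostname": "test1", "ip": "10.22.10.90", "owner": "x"})
# Same hostname with replace=True: the old row is removed, the new one kept.
insert_by_dict(conn, "Host", {"hostname": "test1", "ip": "10.22.10.91"},
               replace=True)
print(conn.execute("SELECT hostname, ip FROM Host").fetchall())
```

    SQLite and MySQL handle REPLACE in a similar way. If the new row collides with an existing row on a primary or unique key, the old row is deleted before the new one is inserted.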

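    Now we no longer need to write tedious SQL statements for inserts; we only need to wrap the data to be inserted in a dict. For example:

    Given a Host table with the fields hostname and ip, the following few lines insert a row into the database: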

        host = {}
        host['hostname'] = 'test1'
        host['ip'] = '10.22.10.90'
        ret = db.insert_by_dict("Host", host)
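
    To see what that call sends to the server, here is a database-free sketch of just the SQL-building step. build_insert_sql is a hypothetical helper, and the allowed_keys list stands in for the result of DESCRIBE Host:

```python
def build_insert_sql(tablename, rowdict, allowed_keys, replace=False):
    """Rebuild only the SQL text; allowed_keys stands in for DESCRIBE output."""
    keys = sorted(set(allowed_keys).intersection(rowdict))
    verb = "REPLACE" if replace else "INSERT"
    sql = "%s INTO %s (%s) VALUES (%s)" % (
        verb, tablename, ", ".join(keys), ", ".join(["%s"] * len(keys)))
    # The values travel separately from the SQL text.
    return sql, tuple(rowdict[k] for k in keys)

host = {"hostname": "test1", "ip": "10.22.10.90"}
sql, values = build_insert_sql("Host", host, ["host_id", "hostname", "ip"])
print(sql)     # INSERT INTO Host (hostname, ip) VALUES (%s, %s)
print(values)  # ('test1', '10.22.10.90')
```

    The values never appear in the SQL string itself. cursor.execute(sql, values) hands them to MySQLdb, which escapes each one as it fills the %s slots.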

    Pretty convenient, isn't it? :) Below is the complete torndb source after my modifications; feel free to download and use it.


    #!/usr/bin/env python
    #
    # Copyright 2009 Facebook
    #
    # Licensed under the Apache License, Version 2.0 (the "License"); you may
    # not use this file except in compliance with the License. You may obtain
    # a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
    # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    # License for the specific language governing permissions and limitations
    # under the License.
    
    """A lightweight wrapper around MySQLdb.
    
    Originally part of the Tornado framework.  The tornado.database module
    is slated for removal in Tornado 3.0, and it is now available separately
    as torndb.
    """
    
    from __future__ import absolute_import, division, with_statement
    
    import copy
    import itertools
    import logging
    import os
    import time
    
    try:
        import MySQLdb.constants
        import MySQLdb.converters
        import MySQLdb.cursors
    except ImportError:
        # If MySQLdb isn't available this module won't actually be useable,
        # but we want it to at least be importable on readthedocs.org,
        # which has limitations on third-party modules.
        if 'READTHEDOCS' in os.environ:
            MySQLdb = None
        else:
            raise
    
    version = "0.2"
    version_info = (0, 2, 0, 0)
    
    class Connection(object):
        """A lightweight wrapper around MySQLdb DB-API connections.
    
        The main value we provide is wrapping rows in a dict/object so that
        columns can be accessed by name. Typical usage::
    
            db = torndb.Connection("localhost", "mydatabase")
            for article in db.query("SELECT * FROM articles"):
                print article.title
    
        Cursors are hidden by the implementation, but other than that, the methods
        are very similar to the DB-API.
    
        We explicitly set the timezone to UTC and assume the character encoding to
        UTF-8 (can be changed) on all connections to avoid time zone and encoding errors.
    
        The sql_mode parameter is set by default to "traditional", which "gives an error instead of a warning"
        (http://dev.mysql.com/doc/refman/5.0/en/server-sql-mode.html). However, it can be set to
        any other mode including blank (None) thereby explicitly clearing the SQL mode.
        """
        def __init__(self, host, database, user=None, password=None,
                     max_idle_time=7 * 3600, connect_timeout=0, 
                     time_zone="+0:00", charset = "utf8", sql_mode="TRADITIONAL"):
            self.host = host
            self.database = database
            self.max_idle_time = float(max_idle_time)
    
            args = dict(conv=CONVERSIONS, use_unicode=True, charset=charset,
                        db=database, init_command=('SET time_zone = "%s"' % time_zone),
                        connect_timeout=connect_timeout, sql_mode=sql_mode)
            if user is not None:
                args["user"] = user
            if password is not None:
                args["passwd"] = password
    
            # We accept a path to a MySQL socket file or a host(:port) string
            if "/" in host:
                args["unix_socket"] = host
            else:
                self.socket = None
                pair = host.split(":")
                if len(pair) == 2:
                    args["host"] = pair[0]
                    args["port"] = int(pair[1])
                else:
                    args["host"] = host
                    args["port"] = 3306
    
            self._db = None
            self._db_args = args
            self._last_use_time = time.time()
            try:
                self.reconnect()
            except Exception:
                logging.error("Cannot connect to MySQL on %s", self.host,
                              exc_info=True)
    
        def __del__(self):
            self.close()
    
        def close(self):
            """Closes this database connection."""
            if getattr(self, "_db", None) is not None:
                self._db.close()
                self._db = None
    
        def reconnect(self):
            """Closes the existing database connection and re-opens it."""
            self.close()
            self._db = MySQLdb.connect(**self._db_args)
            self._db.autocommit(True)
    
        def initClientEncode(self):  
            """mysql client encoding=utf8"""
            curs = self._cursor()
            curs.execute("SET NAMES utf8")
            return curs
    
        def iter(self, query, *parameters, **kwparameters):
            """Returns an iterator for the given query and parameters."""
            self._ensure_connected()
            cursor = MySQLdb.cursors.SSCursor(self._db)
            try:
                self._execute(cursor, query, parameters, kwparameters)
                column_names = [d[0] for d in cursor.description]
                for row in cursor:
                    yield Row(zip(column_names, row))
            finally:
                cursor.close()
    
        def query(self, query, *parameters, **kwparameters):
            """Returns a row list for the given query and parameters."""
            cursor = self._cursor()
            try:
                self._execute(cursor, query, parameters, kwparameters)
                column_names = [d[0] for d in cursor.description]
                return [Row(itertools.izip(column_names, row)) for row in cursor]
            finally:
                cursor.close()
    
        def get(self, query, *parameters, **kwparameters):
            """Returns the (singular) row returned by the given query.
    
            If the query has no results, returns None.  If it has
            more than one result, raises an exception.
            """
            rows = self.query(query, *parameters, **kwparameters)
            if not rows:
                return None
            elif len(rows) > 1:
                raise Exception("Multiple rows returned for Database.get() query")
            else:
                return rows[0]
    
        # rowcount is a more reasonable default return value than lastrowid,
        # but for historical compatibility execute() must return lastrowid.
        def execute(self, query, *parameters, **kwparameters):
            """Executes the given query, returning the lastrowid from the query."""
            return self.execute_lastrowid(query, *parameters, **kwparameters)
    
        def execute_lastrowid(self, query, *parameters, **kwparameters):
            """Executes the given query, returning the lastrowid from the query."""
            cursor = self._cursor()
            try:
                self._execute(cursor, query, parameters, kwparameters)
                return cursor.lastrowid
            finally:
                cursor.close()
    
        def execute_rowcount(self, query, *parameters, **kwparameters):
            """Executes the given query, returning the rowcount from the query."""
            cursor = self._cursor()
            try:
                self._execute(cursor, query, parameters, kwparameters)
                return cursor.rowcount
            finally:
                cursor.close()
    
        def executemany(self, query, parameters):
            """Executes the given query against all the given param sequences.
    
            We return the lastrowid from the query.
            """
            return self.executemany_lastrowid(query, parameters)
    
        def executemany_lastrowid(self, query, parameters):
            """Executes the given query against all the given param sequences.
    
            We return the lastrowid from the query.
            """
            cursor = self._cursor()
            try:
                cursor.executemany(query, parameters)
                return cursor.lastrowid
            finally:
                cursor.close()
    
        def get_fields_str(self, tablename):
            """Return the table's column names joined as 'col1, col2, ...'."""
            return self.get_fields_prefix_str(tablename, "")
    
        def get_fields_prefix_str(self, tablename, prefix):
            """Like get_fields_str, but puts prefix (e.g. 'h.') before each column."""
            cursor = self._cursor()
            try:
                cursor.execute("describe %s" % tablename)
                return ", ".join(prefix + row[0] for row in cursor.fetchall())
            finally:
                cursor.close()
    
        def get_select_sql(self, tablename):
            """Return 'SELECT <all columns> FROM tablename ', ready for a WHERE clause."""
            return "SELECT " + self.get_fields_str(tablename) + " FROM " + tablename + " "
    
    
        def insert_by_dict(self, tablename, rowdict, replace=False):
            """INSERT (or REPLACE) one row whose column values come from a dict.

            tablename is interpolated into the SQL text, so it must come
            from code, never from user input. Returns the lastrowid.
            """
            cursor = self._cursor()
            try:
                # Ask MySQL for the table's actual column names
                cursor.execute("describe %s" % tablename)
                allowed_keys = set(row[0] for row in cursor.fetchall())
                # A sorted list fixes one order for both columns and values
                keys = sorted(allowed_keys.intersection(rowdict))

                unknown_keys = set(rowdict) - allowed_keys
                if unknown_keys:
                    logging.error("skipping keys: %s", ", ".join(unknown_keys))

                columns = ", ".join("`%s`" % key for key in keys)
                values_template = ", ".join(["%s"] * len(keys))
                verb = "REPLACE" if replace else "INSERT"
                sql = "%s INTO %s (%s) VALUES (%s)" % (
                    verb, tablename, columns, values_template)

                # Values are passed separately so MySQLdb escapes them
                values = tuple(rowdict[key] for key in keys)
                cursor.execute(sql, values)
                return cursor.lastrowid
            finally:
                cursor.close()
    
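        def transaction(self, query, *parameters, **kwparameters):
            """Run the SQL statements in `query` as one transaction.

            Commits if all of them succeed; on any error, rolls back and re-raises.
            """
            cursor = self._cursor()  # may reconnect, so do it before BEGIN
            try:
                # The connection runs in autocommit mode, so BEGIN opens the
                # transaction (connection.begin() is a deprecated MySQLdb method)
                cursor.execute("BEGIN")
                for sql in query:
                    cursor.execute(sql, kwparameters or parameters)
                self._db.commit()
                return True
            except Exception:
                # Roll back on any failure, not only OperationalError
                self._db.rollback()
                raise
            finally:
                cursor.close()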
    
    
        def executemany_rowcount(self, query, parameters):
            """Executes the given query against all the given param sequences.
    
            We return the rowcount from the query.
            """
            cursor = self._cursor()
            try:
                cursor.executemany(query, parameters)
                return cursor.rowcount
            finally:
                cursor.close()
    
        update = execute_rowcount
        updatemany = executemany_rowcount
    
        insert = execute_lastrowid
        insertmany = executemany_lastrowid
    
        def _ensure_connected(self):
            # Mysql by default closes client connections that are idle for
            # 8 hours, but the client library does not report this fact until
            # you try to perform a query and it fails.  Protect against this
            # case by preemptively closing and reopening the connection
            # if it has been idle for too long (7 hours by default).
            if (self._db is None or
                (time.time() - self._last_use_time > self.max_idle_time)):
                self.reconnect()
            self._last_use_time = time.time()
    
        def _cursor(self):
            self._ensure_connected()
            return self._db.cursor()
    
        def _execute(self, cursor, query, parameters, kwparameters):
            try:
                return cursor.execute(query, kwparameters or parameters)
            except OperationalError:
                logging.error("Error connecting to MySQL on %s", self.host)
                self.close()
                raise
    
    
    class Row(dict):
        """A dict that allows for object-like property access syntax."""
        def __getattr__(self, name):
            try:
                return self[name]
            except KeyError:
                raise AttributeError(name)
    
    if MySQLdb is not None:
        # Fix the access conversions to properly recognize unicode/binary
        FIELD_TYPE = MySQLdb.constants.FIELD_TYPE
        FLAG = MySQLdb.constants.FLAG
        CONVERSIONS = copy.copy(MySQLdb.converters.conversions)
    
        field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]
        if 'VARCHAR' in vars(FIELD_TYPE):
            field_types.append(FIELD_TYPE.VARCHAR)
    
        for field_type in field_types:
            CONVERSIONS[field_type] = [(FLAG.BINARY, str)] + CONVERSIONS[field_type]
    
        # Alias some common MySQL exceptions
        IntegrityError = MySQLdb.IntegrityError
        OperationalError = MySQLdb.OperationalError
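
    The transaction method above groups several statements so that they all succeed or all fail. Because reconnect() turns autocommit on, an explicit BEGIN is what opens the transaction. Below is a minimal sketch of that commit/rollback pattern, using sqlite3 in autocommit mode as a stand-in for MySQLdb. The transaction function here is a hypothetical illustration, not torndb's code:

```python
import sqlite3

def transaction(conn, statements):
    """Run every (sql, params) pair or none of them."""
    cursor = conn.cursor()
    try:
        cursor.execute("BEGIN")
        for sql, params in statements:
            cursor.execute(sql, params)
        conn.commit()
        return True
    except Exception:
        conn.rollback()  # undo everything since BEGIN
        raise
    finally:
        cursor.close()

# isolation_level=None puts sqlite3 in autocommit mode, like torndb's
# autocommit(True), so only the explicit BEGIN starts a transaction.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE Host (hostname TEXT PRIMARY KEY, ip TEXT)")

transaction(conn, [("INSERT INTO Host VALUES (?, ?)", ("a", "10.0.0.1"))])
try:
    transaction(conn, [
        ("INSERT INTO Host VALUES (?, ?)", ("b", "10.0.0.2")),
        ("INSERT INTO Host VALUES (?, ?)", ("a", "10.0.0.3")),  # duplicate key
    ])
except sqlite3.IntegrityError:
    pass  # the insert of 'b' was rolled back together with the failing one
print(conn.execute("SELECT hostname FROM Host ORDER BY hostname").fetchall())  # [('a',)]
```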
    

    Finally, a small example. For the full version, see the web service framework I published on GitHub: https://github.com/yunfeiflying/tornado-rest-web-service-framwork/


    #!/usr/bin/env python2.7
    # -*- coding:utf-8 -*-
    #
    #   Author  :   YunJianFei
    #   E-mail  :   yunjianfei@126.com
    #   Date    :   2014/02/25
    #   Desc    :   Test db
    #
    
    """ Data Access Object
        This file implements DBI for the table 'Host'
    
    The Host table's create sql is : 
    
    CREATE TABLE IF NOT EXISTS `test`.`Host` (
      `host_id` INT NOT NULL AUTO_INCREMENT,
      `host_type` INT NULL,
      `hostname` VARCHAR(45) NULL,
      `ip` VARCHAR(45) NULL,
      `create_time` VARCHAR(45) NULL,
      `cpu_count` INT NULL,
      `cpu_pcount` INT NULL,
      `memory` INT NULL,
      `os` VARCHAR(200) NULL,
      `comment` VARCHAR(200) NULL,
      PRIMARY KEY (`host_id`))
    ENGINE = InnoDB;
    
    """
    
    import logging
    
    import torndb
    from util.dbconst import TableSelectSql
    
    class HostDao(object):
        def __init__(self, db=None):
            # Use the torndb.Connection passed in; if none is given, open one
            # with the test settings below.
            if db is None:
                db = torndb.Connection(
                    host="192.168.10.11:3306", database="test",
                    user="root", password="")
            self.db = db
    
        def insert_by_dict(self, host, replace=False):
            try:
                return self.db.insert_by_dict("Host", host, replace)
            except Exception as ex:
                logging.error("Insert host failed! Exception: %s   Host: %s", ex, host)
                return None
    
        def if_exist(self, hostname, ip):
            """True if a host with this hostname or this ip already exists."""
            if self.get_by_hostname(hostname) is not None:
                return True
            return self.get_by_ip(ip) is not None
    
        # Lookups pass values as query parameters (%s placeholders), so
        # MySQLdb escapes them instead of pasting them into the SQL text.
        def get_by_ip(self, ip):
            return self.db.get(TableSelectSql.HOST + " WHERE ip=%s", ip)
    
        def get_all(self):
            return self.db.query(TableSelectSql.HOST)
    
        def get_by_hostname(self, hostname):
            return self.db.get(TableSelectSql.HOST + " WHERE hostname=%s", hostname)
    
        def get_by_id(self, host_id):
            return self.db.get(TableSelectSql.HOST + " WHERE host_id=%s", host_id)
    
        def get_id_by_hostname(self, hostname):
            ret = self.get_by_hostname(hostname)
            if ret is not None:
                return ret.host_id
            return None
    
        # UPDATE/DELETE return the number of affected rows (execute_rowcount);
        # execute() would return lastrowid, which means nothing here.
        def update_worker_num_by_hostname(self, hostname, worker_num):
            try:
                return self.db.execute_rowcount(
                    "UPDATE Host SET worker_num=%s WHERE hostname=%s",
                    worker_num, hostname)
            except Exception as ex:
                logging.error("Update Host failed! Exception: %s   hostname: %s , worker_num: %s",
                              ex, hostname, worker_num)
                return None
    
        def update_worker_num_by_id(self, host_id, worker_num):
            try:
                return self.db.execute_rowcount(
                    "UPDATE Host SET worker_num=%s WHERE host_id=%s",
                    worker_num, host_id)
            except Exception as ex:
                logging.error("Update Host failed! Exception: %s   host_id: %s , worker_num: %s",
                              ex, host_id, worker_num)
                return None
    
        def del_by_hostname(self, hostname):
            try:
                return self.db.execute_rowcount(
                    "DELETE FROM Host WHERE hostname=%s", hostname)
            except Exception as ex:
                logging.error("Delete host failed! Exception: %s   hostname: %s", ex, hostname)
                return None
    
        def del_by_id(self, host_id):
            try:
                return self.db.execute_rowcount(
                    "DELETE FROM Host WHERE host_id=%s", host_id)
            except Exception as ex:
                logging.error("Delete host failed! Exception: %s   host_id: %s", ex, host_id)
                return None
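
    Methods such as get_by_hostname look rows up by a value that may come from an HTTP request. The minimal sketch below uses sqlite3 again as a stand-in for MySQL. It shows why such values should be passed as query parameters instead of being concatenated into the SQL string. With torndb, the parameterized form is db.get(sql, hostname) with a %s placeholder:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Host (hostname TEXT, ip TEXT)")
conn.executemany("INSERT INTO Host VALUES (?, ?)",
                 [("web1", "10.0.0.1"), ("db1", "10.0.0.2")])

def get_by_hostname_concat(hostname):
    # String concatenation: the value becomes part of the SQL text itself.
    sql = "SELECT hostname FROM Host WHERE hostname='" + str(hostname) + "'"
    return conn.execute(sql).fetchall()

def get_by_hostname_param(hostname):
    # Parameterized: the driver sends the value separately from the SQL.
    return conn.execute("SELECT hostname FROM Host WHERE hostname=?",
                        (hostname,)).fetchall()

evil = "x' OR '1'='1"
print(len(get_by_hostname_concat(evil)))  # 2: the injected OR matches every row
print(len(get_by_hostname_param(evil)))   # 0: no host has that literal name
```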
    




  Original URL: https://www.cnblogs.com/liguangsunls/p/6898654.html