zoukankan      html  css  js  c++  java
  • 修改sqlarchemy源码使其支持jdbc连接mysql

    注意:本文不会将所有完整源码贴出,只是将具体的思路以及部分源码贴出,需要感兴趣的读者自己实验然后实现吆。 

    缘起

      公司最近的项目需要将之前的部分业务的数据库连接方式改为jdbc,但由于之前的项目都使用sqlarchemy作为orm框架,该框架似乎没有支持jdbc,为了能做最小的修改并满足需求,所以需要修改sqlarchemy的源码。

    基本配置介绍

      sqlalchemy 版本:1.1.15

      使用jaydebeapi模块调用jdbc连接mysql

    前提:

      1 学会使用jaydebeapi模块,使用方法具体可以参考:

        https://pypi.python.org/pypi/JayDeBeApi

        介绍的比较详细的可以参考:http://shuaizki.github.io/language_related/2013/06/22/introduction-to-jpype.html

         jaydebeapi是一个基于jpype的在Cpython中可以通过jdbc连接数据库的模块。该模块的python代码很少,基本上可以分为连接部分、游标部分、结果转换部分这三个。一般来说我们可能需要修改的就是结果转换部分,比如说sqlalchemy查询时如果某条记录中含TIME字段,那么该字段一般要表现为timedelta对象。而在jaydebeapi中则返回的是字符串对象,这样在sqlalchemy中会报错的。

    sqlarchemy为我们实现了ORM对象与语句的转换,连接池,session(包括对线程的支持scope_session)等较为上层的逻辑,但这些东西在这里我们不需要考虑(当然创建一个连接,生成curcor还是要考虑的),我们要考虑的仅仅是当sqlarchemy把sql语句以及参数传过来的时候我们该怎么做,以及当sql语句执行后如何对结果进行转换

    所需注意的问题

    1 sql语句以及参数传过来的时候我们该怎么做:

      1.1 对参数进行转义,防止sql注入

    2 执行完sql语句后对结果如何处理:

      2.1 我们知道python的基础sql模块会对结果进行处理,比如说把NUll转换为None,把数据库中的date字段转换为python的date对象等等

      2.2 一些不知道该怎么形容的数据:

        当我们查询时,获取的数据对应字段的元信息

        当我们update或者delete等操作时需要获取影响了多少行

        当我们插入数据后,如果主键是自增字段,我们一般(可以说在sqlarchemy中这是必须)需要获取该记录的主键值   

         实际上就是支持 python DB API 

    3 sqlalchemy增加代码,使其支持我们修改后的jaydebeapi

    如何解决

    1.1解决方案:

      人家pymysql咋搞,我就咋搞!

      在pymysql.corsors文件中Cursor类中有一个叫做mogrify的方法,这个方法不仅对参数转义,而且会将参数放置到sql语句中组成完整的可执行sql语句。所以偷一些代码然后稍加修改就是这样:

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    from functools import partial
    from pymysql.converters import escape_item, escape_string
    import sys
    
    
    PY2 = sys.version_info[0] == 2
    
    if PY2:
        import __builtin__
        range_type = xrange
        text_type = unicode
        long_type = long
        str_type = basestring
        unichr = __builtin__.unichr
    else:
        range_type = range
        text_type = str
        long_type = int
        str_type = str
        unichr = chr
    
    
    def _ensure_bytes(x, encoding="utf8"):
        if isinstance(x, text_type):
            x = x.encode(encoding)
        return x
    
    
    def _escape_args(args, encoding):
        ensure_bytes = partial(_ensure_bytes, encoding=encoding)
    
        if isinstance(args, (tuple, list)):
            if PY2:
                args = tuple(map(ensure_bytes, args))
            return tuple(escape(arg, encoding) for arg in args)
        elif isinstance(args, dict):
            if PY2:
                args = dict((ensure_bytes(key), ensure_bytes(val)) for
                            (key, val) in args.items())
            return dict((key, escape(val, encoding)) for (key, val) in args.items())
    
    
    def escape(obj, charset, mapping=None):
        if isinstance(obj, str_type):
            return "'" + escape_string(obj) + "'"
        return escape_item(obj, charset, mapping=mapping)
    
    
    def mogrify(query, encoding, args=None):
        if PY2:  # Use bytes on Python 2 always
            query = _ensure_bytes(query, encoding=encoding)
        if args is not None:
            # r = _escape_args(args, encoding)
            query = query % _escape_args(args, encoding)
        return query
    
    
    # 调用一下mogrigy函数
    # print(mogrify("select * from ll where a in %s and b = %s", "utf8", [[2, 1], 3]))
    View Code

    2.1解决方案:

      人家pymysql咋搞,我就咋搞!

      在pymysql.converters中有一个名为decoders的字典,这里面存放了mysql字段与python对象的转换关系!大概是这样

    def _convert_second_fraction(s):
        if not s:
            return 0
        # Pad zeros to ensure the fraction length in microseconds
        s = s.ljust(6, '0')
        return int(s[:6])
    
    DATETIME_RE = re.compile(r"(d{1,4})-(d{1,2})-(d{1,2})[T ](d{1,2}):(d{1,2}):(d{1,2})(?:.(d{1,6}))?")
    
    
    def convert_datetime(obj):
        """Returns a DATETIME or TIMESTAMP column value as a datetime object:
    
          >>> datetime_or_None('2007-02-25 23:06:20')
          datetime.datetime(2007, 2, 25, 23, 6, 20)
          >>> datetime_or_None('2007-02-25T23:06:20')
          datetime.datetime(2007, 2, 25, 23, 6, 20)
    
        Illegal values are returned as None:
    
          >>> datetime_or_None('2007-02-31T23:06:20') is None
          True
          >>> datetime_or_None('0000-00-00 00:00:00') is None
          True
    
        """
        if not PY2 and isinstance(obj, (bytes, bytearray)):
            obj = obj.decode('ascii')
    
        m = DATETIME_RE.match(obj)
        if not m:
            return convert_date(obj)
    
        try:
            groups = list(m.groups())
            groups[-1] = _convert_second_fraction(groups[-1])
            return datetime.datetime(*[ int(x) for x in groups ])
        except ValueError:
            return convert_date(obj)
    
    TIMEDELTA_RE = re.compile(r"(-)?(d{1,3}):(d{1,2}):(d{1,2})(?:.(d{1,6}))?")
    
    
    def convert_timedelta(obj):
        """Returns a TIME column as a timedelta object:
    
          >>> timedelta_or_None('25:06:17')
          datetime.timedelta(1, 3977)
          >>> timedelta_or_None('-25:06:17')
          datetime.timedelta(-2, 83177)
    
        Illegal values are returned as None:
    
          >>> timedelta_or_None('random crap') is None
          True
    
        Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
        can accept values as (+|-)DD HH:MM:SS. The latter format will not
        be parsed correctly by this function.
        """
        if not PY2 and isinstance(obj, (bytes, bytearray)):
            obj = obj.decode('ascii')
    
        m = TIMEDELTA_RE.match(obj)
        if not m:
            return None
    
        try:
            groups = list(m.groups())
            groups[-1] = _convert_second_fraction(groups[-1])
            negate = -1 if groups[0] else 1
            hours, minutes, seconds, microseconds = groups[1:]
    
            tdelta = datetime.timedelta(
                hours = int(hours),
                minutes = int(minutes),
                seconds = int(seconds),
                microseconds = int(microseconds)
                ) * negate
            return tdelta
        except ValueError:
            return None
    
    TIME_RE = re.compile(r"(d{1,2}):(d{1,2}):(d{1,2})(?:.(d{1,6}))?")
    
    
    def convert_time(obj):
        """Returns a TIME column as a time object:
    
          >>> time_or_None('15:06:17')
          datetime.time(15, 6, 17)
    
        Illegal values are returned as None:
    
          >>> time_or_None('-25:06:17') is None
          True
          >>> time_or_None('random crap') is None
          True
    
        Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
        can accept values as (+|-)DD HH:MM:SS. The latter format will not
        be parsed correctly by this function.
    
        Also note that MySQL's TIME column corresponds more closely to
        Python's timedelta and not time. However if you want TIME columns
        to be treated as time-of-day and not a time offset, then you can
        use set this function as the converter for FIELD_TYPE.TIME.
        """
        if not PY2 and isinstance(obj, (bytes, bytearray)):
            obj = obj.decode('ascii')
    
        m = TIME_RE.match(obj)
        if not m:
            return None
    
        try:
            groups = list(m.groups())
            groups[-1] = _convert_second_fraction(groups[-1])
            hours, minutes, seconds, microseconds = groups
            return datetime.time(hour=int(hours), minute=int(minutes),
                                 second=int(seconds), microsecond=int(microseconds))
        except ValueError:
            return None
    
    
    def convert_date(obj):
        """Returns a DATE column as a date object:
    
          >>> date_or_None('2007-02-26')
          datetime.date(2007, 2, 26)
    
        Illegal values are returned as None:
    
          >>> date_or_None('2007-02-31') is None
          True
          >>> date_or_None('0000-00-00') is None
          True
    
        """
        if not PY2 and isinstance(obj, (bytes, bytearray)):
            obj = obj.decode('ascii')
        try:
            return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
        except ValueError:
            return None
    
    
    def convert_mysql_timestamp(timestamp):
        """Convert a MySQL TIMESTAMP to a Timestamp object.
    
        MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME:
    
          >>> mysql_timestamp_converter('2007-02-25 22:32:17')
          datetime.datetime(2007, 2, 25, 22, 32, 17)
    
        MySQL < 4.1 uses a big string of numbers:
    
          >>> mysql_timestamp_converter('20070225223217')
          datetime.datetime(2007, 2, 25, 22, 32, 17)
    
        Illegal values are returned as None:
    
          >>> mysql_timestamp_converter('2007-02-31 22:32:17') is None
          True
          >>> mysql_timestamp_converter('00000000000000') is None
          True
    
        """
        if not PY2 and isinstance(timestamp, (bytes, bytearray)):
            timestamp = timestamp.decode('ascii')
        if timestamp[4] == '-':
            return convert_datetime(timestamp)
        timestamp += "0"*(14-len(timestamp)) # padding
        year, month, day, hour, minute, second = 
            int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), 
            int(timestamp[8:10]), int(timestamp[10:12]), int(timestamp[12:14])
        try:
            return datetime.datetime(year, month, day, hour, minute, second)
        except ValueError:
            return None
    
    def convert_set(s):
        if isinstance(s, (bytes, bytearray)):
            return set(s.split(b","))
        return set(s.split(","))
    
    
    def through(x):
        return x
    
    
    #def convert_bit(b):
    #    b = "x00" * (8 - len(b)) + b # pad w/ zeroes
    #    return struct.unpack(">Q", b)[0]
    #
    #     the snippet above is right, but MySQLdb doesn't process bits,
    #     so we shouldn't either
    convert_bit = through
    
    
    def convert_characters(connection, field, data):
        field_charset = charset_by_id(field.charsetnr).name
        encoding = charset_to_encoding(field_charset)
        if field.flags & FLAG.SET:
            return convert_set(data.decode(encoding))
        if field.flags & FLAG.BINARY:
            return data
    
        if connection.use_unicode:
            data = data.decode(encoding)
        elif connection.charset != field_charset:
            data = data.decode(encoding)
            data = data.encode(connection.encoding)
        return data
    
    encoders = {
        bool: escape_bool,
        int: escape_int,
        long_type: escape_int,
        float: escape_float,
        str: escape_str,
        text_type: escape_unicode,
        tuple: escape_sequence,
        list: escape_sequence,
        set: escape_sequence,
        frozenset: escape_sequence,
        dict: escape_dict,
        bytearray: escape_bytes,
        type(None): escape_None,
        datetime.date: escape_date,
        datetime.datetime: escape_datetime,
        datetime.timedelta: escape_timedelta,
        datetime.time: escape_time,
        time.struct_time: escape_struct_time,
        Decimal: escape_object,
    }
    
    if not PY2 or JYTHON or IRONPYTHON:
        encoders[bytes] = escape_bytes
    
    decoders = {
        FIELD_TYPE.BIT: convert_bit,
        FIELD_TYPE.TINY: int,
        FIELD_TYPE.SHORT: int,
        FIELD_TYPE.LONG: int,
        FIELD_TYPE.FLOAT: float,
        FIELD_TYPE.DOUBLE: float,
        FIELD_TYPE.LONGLONG: int,
        FIELD_TYPE.INT24: int,
        FIELD_TYPE.YEAR: int,
        FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp,
        FIELD_TYPE.DATETIME: convert_datetime,
        FIELD_TYPE.TIME: convert_timedelta,
        FIELD_TYPE.DATE: convert_date,
        FIELD_TYPE.SET: convert_set,
        FIELD_TYPE.BLOB: through,
        FIELD_TYPE.TINY_BLOB: through,
        FIELD_TYPE.MEDIUM_BLOB: through,
        FIELD_TYPE.LONG_BLOB: through,
        FIELD_TYPE.STRING: through,
        FIELD_TYPE.VAR_STRING: through,
        FIELD_TYPE.VARCHAR: through,
        FIELD_TYPE.DECIMAL: Decimal,
        FIELD_TYPE.NEWDECIMAL: Decimal,
    }
    原始代码

      而在jaydebeapi中也有一些相似的代码:

    def _to_datetime(rs, col):
        java_val = rs.getTimestamp(col)
        if not java_val:
            return
        d = datetime.datetime.strptime(str(java_val)[:19], "%Y-%m-%d %H:%M:%S")
        d = d.replace(microsecond=int(str(java_val.getNanos())[:6]))
        return str(d)
    
    def _to_time(rs, col):
        java_val = rs.getTime(col)
        if not java_val:
            return
        return str(java_val)
    
    def _to_date(rs, col):
        java_val = rs.getDate(col)
        if not java_val:
            return
        # The following code requires Python 3.3+ on dates before year 1900.
        # d = datetime.datetime.strptime(str(java_val)[:10], "%Y-%m-%d")
        # return d.strftime("%Y-%m-%d")
        # Workaround / simpler soltution (see
        # https://github.com/baztian/jaydebeapi/issues/18):
        return str(java_val)[:10]
    
    def _to_binary(rs, col):
        java_val = rs.getObject(col)
        if java_val is None:
            return
        return str(java_val)
    
    def _java_to_py(java_method):
        def to_py(rs, col):
            java_val = rs.getObject(col)
            if java_val is None:
                return
            if PY2 and isinstance(java_val, (string_type, int, long, float, bool)):
                return java_val
            elif isinstance(java_val, (string_type, int, float, bool)):
                return java_val
            return getattr(java_val, java_method)()
        return to_py
    
    _to_double = _java_to_py('doubleValue')
    
    _to_int = _java_to_py('intValue')
    
    _to_boolean = _java_to_py('booleanValue')
    
    
    _DEFAULT_CONVERTERS = {
        # see
        # http://download.oracle.com/javase/8/docs/api/java/sql/Types.html
        # for possible keys
        'TIMESTAMP': _to_datetime,
        'TIME': _to_time,
        'DATE': _to_date,
        'BINARY': _to_binary,
        'DECIMAL': _to_double,
        'NUMERIC': _to_double,
        'DOUBLE': _to_double,
        'FLOAT': _to_double,
        'TINYINT': _to_int,
        'INTEGER': _to_int,
        'SMALLINT': _to_int,
        'BOOLEAN': _to_boolean,
        'BIT': _to_boolean
    }
    原始代码

      然后我们稍微修改一下即可。 

    2.2解决方案

      在jaydebeapi中的Cursor类中,有一个属性叫做description这个属性,通过他我们就能获取查询时表的字段的元信息

      在jaydebeapi中的Cursor类中,是有rowcount这个属性的,他表示当我们进行插入更新删除操作时受影响的行数。

      而在pymysql的cursors文件中的Cursor类中的_do_get_result方法中不仅仅有受影响的行数rowcount,还有lastrowid这个属性,他表示当我们插入数据且对应主键是自增字段时,最后一条数据的主键值。但是在jaydebeapi中是没有的,而这个属性在sqlalchemy中恰恰是需要的,所以我们要为jaydebeapi的Cursor类加上这个属性。代码如下:

    class Cursor(object):
    
        lastrowid = None
        rowcount = -1
        _meta = None
        _prep = None
        _rs = None
        _description = None
    ...此处省略部分不相关代码...
    def execute(self, operation, parameters=None): if self._connection._closed: raise Error() if not parameters: parameters = () self._close_last() self._prep = self._connection.jconn.prepareStatement(operation) self._set_stmt_parms(self._prep, parameters) try: is_rs = self._prep.execute() # print is_rs except: _handle_sql_exception() # print(dir(self._prep)) # 如果是查询的话 is_rs就是1 if is_rs: self._rs = self._prep.getResultSet() self._meta = self._rs.getMetaData() self.rowcount = -1 self.lastrowid = None # 插入/修改/删除时 is_rs都为0 else: self.rowcount = self._prep.getUpdateCount() self.lastrowid = int(self._prep.lastInsertID)

    注意:上面的代码中红色的代码是我新增的

    3解决方案

        sqlarchemy中底层数据库连接模块都放在dialects这个包中,这个包里面有多个包分别是mysql oracle等数据库的基本数据库连接类,因为公司只使用mysql数据库,所以仅仅做了mysql的jdbc扩展,就放到了mysql包中。

    大体介绍一下我们将要修改的或者用到的类:

      MySQLDialect

        位置:sqlarchemy.dialects.mysql.base 

        描述:它是一个提供了对mysql数据库的连接、语句的执行等操作的基类,所以我们需要新写一个jdbcdialect类并继承它,然后重写某些方法。

        为什么会用到:这个就不用多说了

      ExecutionContext

        位置:sqlarchemy.engine.interface

        描述:通过这个东西我们可以获取当前游标的执行环境,比如说本次sql语句的执行影响了多少行,我们刚插入的一行的自增主键值是多少。他也负责把我们所写的python ORM语句转换为可以被底层数据库模块比如pymysql可以执行的东西。

    创建dialect类:

    我们知道使用sqlalchemy时首先需要创建一个engine,engine的第一个参数是一个URL,就像这样:mysql+pymysql://user:password@host:port/db?charset=utf8

      这段URL主要配置了三项:

        配置1 首先声明了我们要连接mysql数据库

        配置2 然后配置了底层连接数据库的dialect(这个单词翻译过来叫方言,就好比同是汉语(连接mysql),我们可以说山东话(pymysql)也可以说湖南话(mysqldb))模块是pymysql

        配置3 配置了用户名,密码,主机地址,端口,数据库名等信息

      通过查看代码我们可以看到:

        上面中的配置1实际上就是说接下来要在 sqlalchemy.dialects.mysql包中获取提供数据库操作等方法的class了。

        配置2实际上就是说 配置1想要找的的class我定义在了sqlalcehmy.dialects.mysql.pymysql中

        配置3会作为URL类包装解析,然后作为参数传入dialect实例的create_connect_args方法,以获取数据库连接参数。

    然后创建engine时还可以指定许多额外的参数,比如说连接池的配置等,这里面有几个我们需要注意的参数:

      假如我们没有指定module(数据库连接底层模块),默认会调用dialect类的类方法dbapi

      假如我们没有指定creator(与数据库建立连接的方法,一般是个函数)这个参数的话默认建立连接时会调用dialect实例的connect方法,并把create_connect_args返回的连接参数传入。

      当我们第一次与数据库建立连接时,会调用dialect实例的initialize方法,这个方法会做一系列操作,比如说获取当前数据库的版本信息:dialect实例的_get_server_version_info方法;获取当前isolation级别:dialect实例的get_isolation_level方法

    然后就很简单了:在sqlalchemy中找到sqlalchemy.dialects.mysql这个目录,然后新建一个名叫jaydebeapi的文件,并找到该目录下的pymysql文件,你会看到:

    from .mysqldb import MySQLDialect_mysqldb
    from ...util import langhelpers, py3k
    
    
    class MySQLDialect_pymysql(MySQLDialect_mysqldb):
        driver = 'pymysql'
    
        description_encoding = None
    
        # generally, these two values should be both True
        # or both False.   PyMySQL unicode tests pass all the way back
        # to 0.4 either way.  See [ticket:3337]
        supports_unicode_statements = True
        supports_unicode_binds = True
    
        def __init__(self, server_side_cursors=False, **kwargs):
            super(MySQLDialect_pymysql, self).__init__(**kwargs)
            self.server_side_cursors = server_side_cursors
    
        @langhelpers.memoized_property
        def supports_server_side_cursors(self):
            try:
                cursors = __import__('pymysql.cursors').cursors
                self._sscursor = cursors.SSCursor
                return True
            except (ImportError, AttributeError):
                return False
    
        @classmethod
        def dbapi(cls):
            return __import__('pymysql')
    
        if py3k:
            def _extract_error_code(self, exception):
                if isinstance(exception.args[0], Exception):
                    exception = exception.args[0]
                return exception.args[0]
    
    dialect = MySQLDialect_pymysql
    sqlalchemy.dialects.mysql.pymysql源码

    就这一个类,我们只需要继承这个类并重写某些方法就是了。就像这样:

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    import re
    from .pymysql import MySQLDialect_mysqldb
    
    
    class MySQLDialect_jaydebeapi(MySQLDialect_mysqldb):
        driver = 'jaydebeapi'
    
        @classmethod
        def dbapi(cls):
            return __import__('jaydebeapi')
    
        def connect(self, *cargs, **cparams):
            # get_jdbc_conn这个方法就自己写吧,实际上就是用jaydebeapi生成一个连接,但需要注意,连接的autocommit要设置为False
            return get_jdbc_conn(self.dbapi, **cparams)
    
        def _get_server_version_info(self, connection):
            dbapi_con = connection.connection
            cursor = dbapi_con.cursor()
            cursor.execute("select version()")
            version = str(cursor.fetchone()[0])
            cursor.close()
            version_list = []
            r = re.compile(r'[.-]')
            for n in r.split(version):
                try:
                    version_list.append(int(n))
                except ValueError:
                    version_list.append(n)
            return tuple(version_list)
    
        def _detect_charset(self, connection):
            """Sniff out the character set in use for connection results."""
    
            try:
                # note: the SQL here would be
                # "SHOW VARIABLES LIKE 'character_set%%'"
                # print dir(connection.connection)
                cset_name = connection.connection.character_set_name
            except AttributeError:
                return 'utf8'
            else:
                return cset_name()

    个人在修改源码中获取的知识点

    点1:

      com.mysql.jdbc.exceptions.MySQLNonTransientConnectionException: Can’t call rollback when autocommit=true

      1. 当开启autocommit=true时,回滚没有意义,无论成功/失败都已经已经将事务提交
      2. autocommit=false,我们需要运行conn.commit()执行事务, 如果失败则需要conn.rollback()对事务进行回滚;

    点2:

       尝试连接mysql时报错:Unknown system variable 'transaction_isolation'

      这是因为我的MySQLDialect_jaydebeapi类中的_get_server_version_info方法返回写死为5.7.21版本,而在mysql的Mysqldialect类的get_isolation_level中,会判断如果版本大于等于5.7.20的话执行SELECT @@transaction_isolation,反之会执行SELECT @@tx_isolation。

      于是看了看自己的mysql版本是5.7.11 ,遂改变版本号。

  • 相关阅读:
    数据产品—数据仓库
    数据产品-开篇
    os.walk()
    pytest入门
    XML 文件处理
    字符编码
    消息队列
    Pycharm
    AWS入门
    Python配置模块:configparser参数含义
  • 原文地址:https://www.cnblogs.com/MnCu8261/p/8601848.html
Copyright © 2011-2022 走看看