"""Dict-based convenience wrapper around pymysql.

Working with MySQL through this class is meant to feel like PHP: where PHP
code uses associative arrays, this class takes and returns ``dict`` objects
(result rows come back as dicts via ``pymysql.cursors.DictCursor``).

Values passed to the builder methods (``insert``, ``update``, ``delete``,
``select``, ``count``) are sent as bound parameters and escaped by pymysql.
Table and column names are backtick-quoted, but the operator part of a
``where`` key (e.g. the ``>`` in ``"id >"``) is inserted verbatim, so
``where`` keys must come from trusted code, never from user input.
"""
import time

import pymysql


class PDO(object):
    """Thin dict-oriented layer over a single pymysql connection and cursor."""

    __db = None  # connection object returned by pymysql.connect
    __cursor = None  # cursor created from __db; a DictCursor, so rows are dicts
    __error_message = ""  # message of the last error
    __error_code = ""  # code of the last error
    __error = ""  # full text of the last error: "[Err] <code>-<message>"
    __sql = ""  # last SQL statement, with bound parameters interpolated
    port = 3306  # default port
    __debug = ""  # debug-mode flag, set from the constructor argument
    __start_time = ""  # time.time() when the last statement started (connect time in debug mode until then)
    __end_time = ""  # time.time() when the last statement finished
    __affected_rows = 0  # rows affected by the last statement
    __timeout_error_value = 1000  # ms; debug output shows statements at least this slow in red
    __last_id = None  # id generated by the last INSERT
    MAX_LIMIT = 100  # intended maximum rows per query; not enforced by this class

    def __init__(self, host, user, password, db, charset="utf8", port=3306,
                 debug=False, use_unicode=True, autocommit=True):
        """Open a connection and a dict cursor.

        Defaults: charset utf8, port 3306, debug off, autocommit on.

        Raises:
            Exception: "[Err] <code>-<message>" if connecting fails.
        """
        self.__debug = debug
        self.port = port
        try:
            self.__db = pymysql.connect(
                host=host,
                user=user,
                password=password,
                db=db,
                charset=charset,
                # Previously hard-coded to True, silently ignoring the argument.
                use_unicode=use_unicode,
                port=port,
                cursorclass=pymysql.cursors.DictCursor,
            )
            self.__cursor = self.__db.cursor()
            self.__db.autocommit(autocommit)
            if self.__debug:
                self.__start_time = time.time()
        except Exception as e:
            self.__exception(e)

    def __exception(self, e):
        """Record details of ``e`` and re-raise them as a plain Exception.

        pymysql errors carry ``(code, message)`` in ``e.args``; other
        exceptions may carry fewer args, so those are handled too instead of
        failing with IndexError inside the error handler.

        Raises:
            Exception: always, with text "[Err] <code>-<message>", chained
            to ``e``.
        """
        if len(e.args) >= 2:
            code, message = e.args[0], e.args[1]
        elif len(e.args) == 1:
            code, message = type(e).__name__, e.args[0]
        else:
            code, message = type(e).__name__, ""
        self.__error_code = str(code)
        self.__error_message = str(message)
        self.__error = "[Err] " + self.__error_code + "-" + self.__error_message
        raise Exception(self.__error) from e

    def get_connect(self):
        """Return the underlying pymysql connection object."""
        return self.__db

    def get_cursor(self):
        """Return the cursor created by ``connection.cursor()``."""
        return self.__cursor

    def close(self):
        """Close the cursor and the connection. Safe to call more than once."""
        cursor, self.__cursor = self.__cursor, None
        db, self.__db = self.__db, None
        if cursor is not None:
            try:
                cursor.close()
            except pymysql.Error:
                pass  # best effort: the connection is being torn down anyway
        if db is not None:
            try:
                db.close()
            except pymysql.Error:
                pass  # e.g. already closed through get_connect().close()

    def __enter__(self):
        """Support ``with PDO(...) as db:``; the connection is closed on exit."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    @staticmethod
    def __quote_identifier(name):
        """Backtick-quote a table/column name for use in builder SQL.

        Embedded backticks are doubled (MySQL's escape inside quoted
        identifiers), and ``%`` is doubled because builder SQL always goes
        through pymysql's ``%``-style parameter formatting.
        """
        return "`" + str(name).replace("`", "``").replace("%", "%%") + "`"

    def __run(self, sql, params=None):
        """Execute ``sql`` on the shared cursor and record bookkeeping.

        ``params`` is handed to ``cursor.execute`` for ``%s`` binding; when it
        is None the SQL is sent unchanged. The interpolated statement is kept
        for ``get_sql()``/debug output and the execution time is recorded.
        Any failure is re-raised through ``__exception``.
        """
        self.__sql = sql
        self.__start_time = time.time()
        try:
            if self.__cursor is None:
                raise pymysql.InterfaceError(0, "Connection is closed")
            self.__sql = self.__cursor.mogrify(sql, params)
            self.__cursor.execute(sql, params)
            self.__affected_rows = self.__db.affected_rows()
        except Exception as e:
            self.__exception(e)
        finally:
            self.__end_time = time.time()

    def insert(self, table, rows):
        """Insert one row (a dict) or several rows (an iterable of dicts).

        Every row must have the same keys as the first one. Values are bound
        as parameters.

        Returns:
            ``connection.insert_id()`` after the statement (for a multi-row
            insert MySQL reports the id of the first inserted row).

        Raises:
            ValueError: no rows, or rows with differing keys.
            Exception: "[Err] <code>-<message>" if the statement fails.
        """
        values_sql, params = self.__insert_sql(rows)
        self.__run("INSERT INTO " + self.__quote_identifier(table) + " " + values_sql, params)
        self.__last_id = self.__db.insert_id()
        return self.__last_id

    def __insert_sql(self, rows):
        """Build the ``(cols) VALUES (%s,...),...`` fragment and its params.

        Columns come from the first row and every row's values are read in
        that column order, so rows whose keys are ordered differently still
        line up with the column list.
        """
        row_list = [rows] if isinstance(rows, dict) else list(rows)
        if not row_list:
            raise ValueError("insert() needs at least one row")
        first_row = row_list[0]
        columns = list(first_row)
        params = []
        for row in row_list:
            if row.keys() != first_row.keys():
                raise ValueError("all rows passed to insert() must have the same keys")
            params.extend(row[column] for column in columns)
        row_placeholder = "(" + ",".join(["%s"] * len(columns)) + ")"
        sql = ("(" + self.__get_columns(first_row) + ") VALUES "
               + ",".join([row_placeholder] * len(row_list)))
        return sql, params

    def __get_columns(self, row):
        """Return the quoted, comma-separated column list for ``row``'s keys."""
        return ",".join(self.__quote_identifier(key) for key in row)

    def delete(self, table, where, limit=1):
        """Delete at most ``limit`` rows of ``table`` matching ``where``.

        ``where`` uses the format described in ``__where_sql``. An empty
        condition is rejected rather than deleting arbitrary rows.

        Returns:
            Number of affected rows.
        """
        where_sql, params = self.__where_sql(where)
        if not where_sql:
            raise ValueError("delete() requires a non-empty where condition")
        # The old version dropped the closing backtick and the WHERE keyword.
        sql = ("DELETE FROM " + self.__quote_identifier(table) + " WHERE " + where_sql
               + " LIMIT " + str(int(limit)))
        self.__run(sql, params)
        return self.__affected_rows

    def __where_sql(self, where):
        """Translate a ``where`` dict into an SQL condition.

        Returns ``(sql, params)``: ``sql`` uses ``%s`` placeholders and is
        ``""`` when there are no conditions; ``params`` lists the values in
        placeholder order.

        Key forms:
            ``"col"``: ``col = value``; ``col IS NULL`` when value is None;
            ``col IN (...)`` when value is a list/tuple.
            ``"col <op>"``: ``col <op> value``, or ``col <op> (...)`` for a
            list/tuple, e.g. ``"id >"`` or ``"id NOT IN"``. ``<op>`` is inserted
            verbatim, so keys must be trusted.
            ``"OR"``: not a condition; if present (any value), conditions are
            joined with OR instead of AND.

        An empty list matches nothing for IN and everything for NOT IN, since
        MySQL rejects ``IN ()``. The caller's values are never modified.
        """
        if not where:
            return "", []
        parts = []
        params = []
        for key, value in where.items():
            if key == "OR":
                continue
            pieces = str(key).strip().split(None, 1)
            if not pieces:
                raise ValueError("empty column name in where condition")
            column = self.__quote_identifier(pieces[0])
            operator = pieces[1].replace("%", "%%") if len(pieces) > 1 else None
            if isinstance(value, (list, tuple)):
                if not value:
                    negated = operator is not None and operator.upper().split() == ["NOT", "IN"]
                    parts.append("1=1" if negated else "0=1")
                    continue
                placeholders = "(" + ",".join(["%s"] * len(value)) + ")"
                parts.append(column + " " + (operator or "IN") + " " + placeholders)
                params.extend(value)
            elif operator is None and value is None:
                parts.append(column + " IS NULL")
            else:
                parts.append(column + " " + (operator or "=") + " %s")
                params.append(value)
        glue = " OR " if "OR" in where else " AND "
        return glue.join(parts), params

    def update(self, table, update_dict, where, limit=1):
        """Set the columns in ``update_dict`` on at most ``limit`` matching rows.

        ``where`` uses the format described in ``__where_sql``; it must not be
        empty. Values are bound as parameters.

        Returns:
            Number of affected rows.
        """
        set_sql, set_params = self.__update_sql(update_dict)
        if not set_sql:
            raise ValueError("update() requires a non-empty dict of columns to set")
        where_sql, where_params = self.__where_sql(where)
        if not where_sql:
            raise ValueError("update() requires a non-empty where condition")
        sql = ("UPDATE " + self.__quote_identifier(table) + " SET " + set_sql
               + " WHERE " + where_sql + " LIMIT " + str(int(limit)))
        self.__run(sql, set_params + where_params)
        return self.__affected_rows

    def __update_sql(self, update_dict):
        """Build the ``col = %s,...`` SET list and its params ("" and [] if not a dict)."""
        if not isinstance(update_dict, dict):
            return "", []
        set_sql = ",".join(self.__quote_identifier(key) + " = %s" for key in update_dict)
        return set_sql, list(update_dict.values())

    def select(self, cols, table, where=None, order=None, offset=0, limit=None):
        """Run a SELECT and return all rows as dicts.

        Args:
            cols: list/tuple of column names; empty or another type selects
                ``*``. The caller's list is not modified.
            table: table name.
            where: condition dict, e.g. ``{"id": 1}``; see ``__where_sql``.
            order: ``{"id": "DESC", "created_at": "ASC"}``. Directions are
                case-insensitive and must be ASC, DESC or empty.
            offset, limit: when ``limit`` is not None, ``LIMIT offset,limit``
                is appended; by default no LIMIT is applied.
        """
        if isinstance(cols, (list, tuple)) and cols:
            column_sql = ",".join(self.__quote_identifier(col) for col in cols)
        else:
            column_sql = "*"
        sql = "SELECT " + column_sql + " FROM " + self.__quote_identifier(table)
        where_sql, params = self.__where_sql(where)
        if where_sql:
            sql += " WHERE " + where_sql
        if isinstance(order, dict) and order:
            order_items = []
            for column, sort in order.items():
                direction = str(sort).strip().upper()
                if direction not in ("", "ASC", "DESC"):
                    raise ValueError("order direction must be ASC or DESC, got %r" % (sort,))
                order_items.append((self.__quote_identifier(column) + " " + direction).rstrip())
            sql += " ORDER BY " + ",".join(order_items)
        if limit is not None:
            sql += " LIMIT " + str(int(offset)) + "," + str(int(limit))
        self.__run(sql, params)
        return self.__cursor.fetchall()

    def query(self, sql, params=None):
        """Run a raw SELECT and return all rows as dicts.

        ``params`` (optional) are bound by pymysql to ``%s`` placeholders;
        prefer them to formatting values into ``sql`` yourself.
        """
        self.__run(sql, params)
        return self.__cursor.fetchall()

    def execute(self, sql, ret_last_id=False, params=None):
        """Run a raw INSERT/UPDATE/DELETE statement.

        Returns:
            ``connection.insert_id()`` if ``ret_last_id`` is true, otherwise
            the number of affected rows. ``params`` is bound as in ``query``.
        """
        self.__run(sql, params)
        if ret_last_id:
            self.__last_id = self.__db.insert_id()
            return self.__last_id
        return self.__affected_rows

    def count(self, table, where=None):
        """Return ``COUNT(*)`` of ``table`` rows matching ``where`` (all rows if empty)."""
        sql = "SELECT COUNT(*) as num FROM " + self.__quote_identifier(table)
        where_sql, params = self.__where_sql(where)
        if where_sql:
            sql += " WHERE " + where_sql
        self.__run(sql, params)
        return self.__cursor.fetchone()["num"]

    def get_sql(self):
        """Return the last SQL statement, with bound parameters interpolated."""
        return self.__sql

    def __del__(self):
        """Print debug information about the last statement when debug is on.

        Prints the SQL, affected rows, last insert id (if any), execution
        time in ms (red when it reaches ``__timeout_error_value``) and the
        last error (if any).

        The connection is deliberately not closed here: callers may still
        hold the raw connection returned by ``get_connect()``. Use
        ``close()`` or a ``with`` block to release it.
        """
        if not self.__debug:
            return
        # Previously the escapes were garbled (" 33[") and "32;0m" reset the
        # colour immediately, so no colour was ever shown.
        green, red, bold_red, reset = "\033[0;32m", "\033[0;31m", "\033[1;31m", "\033[0m"
        print(green + "[SQL] " + str(self.__sql) + reset)
        print(green + "[affected_rows] " + str(self.__affected_rows) + reset)
        if self.__last_id:
            print(green + "[last_insert_id] " + str(self.__last_id) + reset)
        # __start_time stays "" if the connection never succeeded.
        if isinstance(self.__start_time, float):
            end_time = self.__end_time if isinstance(self.__end_time, float) else time.time()
            use_time = int((end_time - self.__start_time) * 1000)
            color = green if use_time < self.__timeout_error_value else red
            print(color + "[time] " + str(use_time) + " ms" + reset)
        if self.__error:
            print(bold_red + self.__error + reset)