import pymysql
import time
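'''
Database helper class built on pymysql. Working with MySQL through it feels like
PHP: where PHP code uses associative arrays, this class uses Python dicts.
(基于pymysql的数据库扩展类,操作mysql数据库和PHP有相同感,PHP中使用关联数组处理,python使用dict处理)
'''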
class PDO(object):
    """Convenience wrapper around a single pymysql connection.

    Mirrors the PHP PDO habit of working with associative arrays: rows come
    back as dicts (``pymysql.cursors.DictCursor``) and the helpers take dicts
    for column values and where-conditions.

    Values given to ``insert``/``update``/``delete``/``select``/``count`` are
    sent as bound ``%s`` parameters, so pymysql escapes them. Table and column
    names are backtick-quoted, but the operator text inside where-keys (e.g.
    ``"age >"``, ``"name LIKE"``) is inserted as-is, so keys must come from
    trusted code, never from user input.
    """

    __db = None  # connection object returned by pymysql.connect
    __cursor = None  # cursor of __db (DictCursor: rows are dicts)
    __error_message = ""  # message part of the last error
    __error_code = ""  # code part of the last error
    __error = ""  # full "[Err] code-message" text of the last error
    __sql = ""  # last statement sent, with bound parameters rendered in
    port = 3306  # server port
    __debug = False  # when true, __del__ prints a summary of the last statement
    __start_time = None  # time.time() recorded after connecting (debug mode only)
    __end_time = None  # time.time() recorded in __del__ (debug mode only)
    __affected_rows = 0  # row count reported for the last statement
    __timeout_error_value = 1000  # ms; debug timings at or above this print in red
    __last_id = None  # value of insert_id() after the last insert
    MAX_LIMIT = 100  # maximum query row count (declared; not enforced by this class)
    __NO_LIMIT = 18446744073709551615  # row count meaning "to the end" in LIMIT offset, n

    def __init__(self, host, user, password, db, charset="utf8", port=3306,
                 debug=False, use_unicode=True, autocommit=True):
        """Open the connection and create the shared dict cursor.

        :param host: MySQL server host.
        :param user: user name.
        :param password: password.
        :param db: database (schema) to select.
        :param charset: connection character set (default ``utf8``).
        :param port: server port (default 3306).
        :param debug: when true, ``__del__`` prints the last SQL, row count,
            last insert id and elapsed milliseconds.
        :param use_unicode: forwarded to ``pymysql.connect``.
        :param autocommit: initial autocommit mode (default on).
        :raises Exception: ``"[Err] code-message"`` if connecting fails.
        """
        self.__debug = debug
        self.port = port
        try:
            config = {
                "host": host,
                "user": user,
                "password": password,
                "database": db,  # "db" is a deprecated alias in pymysql
                "charset": charset,
                "use_unicode": use_unicode,
                "port": port,
                "cursorclass": pymysql.cursors.DictCursor,
            }
            self.__db = pymysql.connect(**config)
            self.__cursor = self.__db.cursor()
            self.__db.autocommit(autocommit)
            if self.__debug:
                self.__start_time = time.time()
        except Exception as e:
            self.__exception(e)

    def __exception(self, e):
        """Record *e* as the last error and re-raise it as a plain Exception.

        pymysql errors carry ``(code, message)`` in ``e.args``. Any other
        exception uses its class name as the code and its first arg (if any)
        as the message. The original exception is chained as ``__cause__``.
        """
        args = e.args
        if len(args) >= 2:
            code, message = args[0], args[1]
        else:
            code, message = type(e).__name__, (args[0] if args else "")
        self.__error_code = str(code)
        self.__error_message = str(message)
        self.__error = "[Err] " + self.__error_code + "-" + self.__error_message
        raise Exception(self.__error) from e

    def get_connect(self):
        """Return the underlying pymysql connection (None after ``close()``)."""
        return self.__db

    def get_cursor(self):
        """Return the shared cursor created by ``connect.cursor()``."""
        return self.__cursor

    def close(self):
        """Close the cursor and the connection; safe to call more than once.

        Not called from ``__del__``, because callers may still hold the raw
        connection obtained from ``get_connect()``.
        """
        cursor, self.__cursor = self.__cursor, None
        db, self.__db = self.__db, None
        if cursor is not None:
            cursor.close()
        if db is not None and db.open:
            db.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def insert(self, table, rows):
        """Insert one row (a dict) or several rows (an iterable of dicts).

        Every row must have the same keys as the first one.

        :returns: the connection's ``insert_id()`` after the statement.
        :raises ValueError: if *rows* is empty or the rows' keys differ.
        :raises Exception: ``"[Err] code-message"`` if MySQL rejects it.
        """
        sql, args = self.__build_insert(table, rows)
        return self.__run(sql, args, self.__fetch_last_id)

    def delete(self, table, where, limit=1):
        """Delete up to *limit* rows of *table* matching the *where* dict.

        *limit* may be None for no LIMIT clause. A *where* that yields no
        condition is rejected, so a forgotten filter cannot wipe the table.
        Use ``execute()`` for deliberate full-table deletes.

        :returns: number of affected rows.
        :raises ValueError: if *where* produces no condition.
        """
        sql, args = self.__build_delete(table, where, limit)
        return self.__run(sql, args)

    def update(self, table, update_dict, where, limit=1):
        """Set the columns in *update_dict* on up to *limit* rows matching *where*.

        *limit* may be None for no LIMIT clause.

        :returns: number of affected rows.
        :raises ValueError: if *update_dict* is empty or not a dict, or
            *where* produces no condition.
        """
        sql, args = self.__build_update(table, update_dict, where, limit)
        return self.__run(sql, args)

    def select(self, cols, table, where=None, order=None, limit=None, offset=0):
        """Run a SELECT and return all rows as dicts.

        :param cols: list/tuple of column names; anything else (or empty)
            selects ``*``. The caller's list is not modified.
        :param table: table name.
        :param where: condition dict, e.g. ``{"id": 1}`` -> ``id = 1``
            (key forms are described in ``__build_where``).
        :param order: dict of column -> "ASC"/"DESC", e.g.
            ``{"id": "DESC", "created_at": "ASC"}``.
        :param limit: maximum number of rows; None means no limit.
        :param offset: number of rows to skip.
        :raises ValueError: on an order direction other than ASC/DESC.
        """
        sql, args = self.__build_select(cols, table, where, order, limit, offset)
        return self.__run(sql, args, lambda: self.__cursor.fetchall())

    def query(self, sql, args=None):
        """Run a raw statement (typically SELECT) and return all rows.

        :param args: optional sequence/dict bound to ``%s`` / ``%(name)s``
            placeholders by pymysql; when None, *sql* is sent unchanged.
        """
        return self.__run(sql, args, lambda: self.__cursor.fetchall())

    def execute(self, sql, ret_last_id=False, args=None):
        """Run a raw INSERT/UPDATE/DELETE statement.

        :param ret_last_id: return ``insert_id()`` instead of the row count.
        :param args: optional parameters bound by pymysql (see ``query``).
        :returns: affected row count, or the last insert id.
        """
        result = self.__fetch_last_id if ret_last_id else None
        return self.__run(sql, args, result)

    def count(self, table, where=None):
        """Return ``COUNT(*)`` of *table*, optionally filtered by *where*."""
        sql = "SELECT COUNT(*) AS num FROM " + self.__ident(table)
        where_sql, args = self.__build_where(where)
        if where_sql:
            sql += " WHERE " + where_sql
        return self.__run(sql, args, lambda: self.__cursor.fetchone()["num"])

    def get_sql(self):
        """Return the SQL of the last operation, with parameters rendered in."""
        return self.__sql

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def __run(self, sql, args=None, result=None):
        """Execute *sql* with optional bound *args*; record SQL and row count.

        :param result: zero-argument callable evaluated after execution whose
            value is returned; when None the affected-row count is returned.
        :raises Exception: via ``__exception`` for any failure.
        """
        self.__sql = sql  # keep the template visible if rendering itself fails
        try:
            # mogrify renders the statement the same way execute() does.
            self.__sql = self.__cursor.mogrify(sql, args)
            self.__cursor.execute(sql, args)
            self.__affected_rows = self.__db.affected_rows()
            if result is None:
                return self.__affected_rows
            return result()
        except Exception as e:
            self.__exception(e)

    def __fetch_last_id(self):
        """Remember and return the connection's ``insert_id()``."""
        self.__last_id = self.__db.insert_id()
        return self.__last_id

    @staticmethod
    def __ident(name):
        """Backtick-quote an identifier for a statement that has bound args.

        Embedded backticks are doubled (MySQL's escape inside quoted
        identifiers). ``%`` is doubled because pymysql applies ``%``
        formatting to statements executed with parameters.
        """
        return "`" + str(name).replace("`", "``").replace("%", "%%") + "`"

    def __limit_sql(self, limit, offset=0):
        """Return a ``" LIMIT ..."`` clause, or "" when neither bound is set."""
        offset = int(offset or 0)
        if limit is None:
            if not offset:
                return ""
            # MySQL has no OFFSET without LIMIT; a huge row count means "to the end".
            limit = self.__NO_LIMIT
        if offset:
            return " LIMIT " + str(offset) + ", " + str(int(limit))
        return " LIMIT " + str(int(limit))

    def __build_where(self, where):
        """Translate a where-dict into ``(sql, args)``.

        Key forms (split on whitespace):

        - ``"col"``: ``col = %s``, or ``col IS NULL`` when the value is None.
        - ``"col <operator...>"``: ``col <operator...> %s``, for example
          ``"age >"`` or ``"name LIKE"``.

        A list/tuple value becomes ``col IN (%s, ...)``, or
        ``col <operator...> (...)`` when an operator is given (``"id NOT IN"``).
        An empty list becomes ``1 = 0`` for IN and ``1 = 1`` for NOT IN,
        because ``IN ()`` is not valid SQL.

        Conditions are joined with AND, or with OR when the dict contains the
        key ``"OR"`` (whose value is ignored). Returns ``("", [])`` when no
        condition remains. The caller's dict and lists are not modified.
        """
        where = where or {}
        parts = []
        args = []
        for key, value in where.items():
            if key == "OR":
                continue
            tokens = str(key).split()
            if not tokens:
                raise ValueError("where key must start with a column name")
            column = self.__ident(tokens[0])
            operator = " ".join(tokens[1:]).replace("%", "%%")
            if isinstance(value, (list, tuple)):
                operator = operator or "IN"
                if not value:
                    if operator.upper() == "IN":
                        parts.append("1 = 0")
                    elif operator.upper() == "NOT IN":
                        parts.append("1 = 1")
                    else:
                        raise ValueError(
                            "empty list for where key %r; only IN / NOT IN "
                            "accept an empty list" % (key,))
                    continue
                placeholders = ", ".join(["%s"] * len(value))
                parts.append(column + " " + operator + " (" + placeholders + ")")
                args.extend(value)
            elif operator:
                parts.append(column + " " + operator + " %s")
                args.append(value)
            elif value is None:
                parts.append(column + " IS NULL")
            else:
                parts.append(column + " = %s")
                args.append(value)
        joiner = " OR " if "OR" in where else " AND "
        return joiner.join(parts), args

    def __build_insert(self, table, rows):
        """Return ``(sql, args)`` for a single- or multi-row INSERT."""
        rows = [rows] if isinstance(rows, dict) else list(rows)
        if not rows:
            raise ValueError("insert() needs at least one row")
        columns = list(rows[0])
        expected = set(columns)
        args = []
        for row in rows:
            if set(row) != expected:
                raise ValueError("every row must have the same keys as the first row")
            # Look values up in the first row's column order so rows whose
            # keys are ordered differently still line up with the columns.
            args.extend(row[column] for column in columns)
        placeholders = "(" + ", ".join(["%s"] * len(columns)) + ")"
        sql = ("INSERT INTO " + self.__ident(table)
               + " (" + ", ".join(self.__ident(c) for c in columns) + ")"
               + " VALUES " + ", ".join([placeholders] * len(rows)))
        return sql, args

    def __build_delete(self, table, where, limit):
        """Return ``(sql, args)`` for a DELETE; refuses an empty condition."""
        where_sql, args = self.__build_where(where)
        if not where_sql:
            raise ValueError("delete() requires at least one where condition")
        sql = ("DELETE FROM " + self.__ident(table)
               + " WHERE " + where_sql + self.__limit_sql(limit))
        return sql, args

    def __build_update(self, table, update_dict, where, limit):
        """Return ``(sql, args)`` for an UPDATE; refuses an empty SET or WHERE."""
        if not isinstance(update_dict, dict) or not update_dict:
            raise ValueError("update() needs a non-empty dict of new column values")
        where_sql, where_args = self.__build_where(where)
        if not where_sql:
            raise ValueError("update() requires at least one where condition")
        assignments = ", ".join(self.__ident(k) + " = %s" for k in update_dict)
        sql = ("UPDATE " + self.__ident(table) + " SET " + assignments
               + " WHERE " + where_sql + self.__limit_sql(limit))
        return sql, list(update_dict.values()) + where_args

    def __build_select(self, cols, table, where, order, limit, offset):
        """Return ``(sql, args)`` for a SELECT without modifying *cols*."""
        if isinstance(cols, (list, tuple)) and cols:
            columns = ", ".join(self.__ident(c) for c in cols)
        else:
            columns = "*"
        sql = "SELECT " + columns + " FROM " + self.__ident(table)
        where_sql, args = self.__build_where(where)
        if where_sql:
            sql += " WHERE " + where_sql
        if isinstance(order, dict) and order:
            terms = []
            for col, raw_direction in order.items():
                direction = str(raw_direction).strip().upper()
                # The direction is spliced into SQL, so only allow known keywords.
                if direction not in ("", "ASC", "DESC"):
                    raise ValueError(
                        "order direction must be ASC or DESC, got %r" % (raw_direction,))
                terms.append((self.__ident(col) + " " + direction).rstrip())
            sql += " ORDER BY " + ", ".join(terms)
        sql += self.__limit_sql(limit, offset)
        return sql, args

    def __del__(self):
        """In debug mode, print a summary of the last statement on collection.

        Prints the last SQL, the affected-row count, the last insert id (if
        any), and the milliseconds since connecting. The time is printed in
        red when it reaches the threshold, otherwise in green. The last error,
        if any, is printed in bold red.
        """
        if not self.__debug:
            return
        reset = "\033[0m"
        green = "\033[0;32m"
        red = "\033[0;31m"
        self.__end_time = time.time()
        print(green + "[SQL] " + str(self.__sql) + reset)
        print(green + "[affected_rows] " + str(self.__affected_rows) + reset)
        if self.__last_id:
            print(green + "[last_insert_id] " + str(self.__last_id) + reset)
        # __start_time is only set once the connection succeeded.
        if self.__start_time is not None:
            use_time = int((self.__end_time - self.__start_time) * 1000)
            color = green if use_time < self.__timeout_error_value else red
            print(color + "[time] " + str(use_time) + " ms" + reset)
        if self.__error:
            print("\033[0;1;31m" + self.__error + reset)