Flask之上下文管理
知识储备之问题情境:
request中的参数:
- 单进程单线程
- 单进程多线程-->reqeust 会因为多个请求,数据发生错乱.--->可以基于threading.local对象
- 单进程单线程(多协程)threading.local对象做不到(因为一个线程下多个协程同享一个线程的资源)
解决办法:
自定义类似threading.local对象(支持协程)---保证多协程下数据的安全
先来看一下下面这段代码(支持多线程):
# -*- coding: utf-8 -*-
"""
1288::{}
"""
from _thread import get_ident
import threading
class Local(object):
def __init__(self):
self.storage = {}
self.get_ident = get_ident
# 设置值
def set(self, k, v):
# 获取线程的唯一标识
ident = self.get_ident()
# 通过唯一标识去字典里面取值
origin = self.storage.get(ident)
if not origin:
origin = {k: v}
else:
origin[k] = v
# 将k,v 保存到 storage中 形式如下
# {
# 1023:{k,v}, # self.storage[ident] = origin 所添加的值
# 1045:{k1,v1} # 原先storage中有的值
# }
self.storage[ident] = origin
# 获取值
def get(self, k):
ident = self.get_ident()
origin = self.storage.get(ident)
if not origin:
return None
return origin.get(k, None)
# 获取一个线程对象
local_obj = Local()
# 获取每一个线程的唯一标识
def task(num):
local_obj.set('name',num)
import time
time.sleep(1)
print(local_obj.get('name'),threading.current_thread().name)
for i in range(20):
th = threading.Thread(target=task, args=(i,), name='线程%s' % i)
th.start()
"""
0 线程0
1 线程1
2 线程2
5 线程5
6 线程6
3 线程3
4 线程4
10 线程10
9 线程9
11 线程11
7 线程7
13 线程13
14 线程14
17 线程17
18 线程18
15 线程15
19 线程19
8 线程8
12 线程12
16 线程16
"""
再进一步,支持协程
# 首先需要安装依赖
pip3 intall gevent
# gevent 依赖安装 greenlet 可以获取协程的唯一标识
# -*- coding: utf-8 -*-
"""
1288::{
}
"""
try:
# 优先用协程的
# 如果是单线程多协程,导入获取协程唯一标识的
from greenlet import getcurrent as get_ident # 协程
except ImportError:
try:
# 如果是多线程导入获取线程唯一标识的
from thread import get_ident
except ImportError:
# 如果是多线程导入获取线程唯一标识的
from _thread import get_ident # 线程
class Local(object):
def __init__(self):
self.storage = {}
self.get_ident = get_ident
# 设置值
def set(self, k, v):
# 获取线程的唯一标识
ident = self.get_ident()
# 通过唯一标识去字典里面取值
origin = self.storage.get(ident)
if not origin:
origin = {k: v}
else:
origin[k] = v
# 将k,v 保存到 storage中 形式如下
# {
# 1023:{k,v}, # self.storage[ident] = origin 所添加的值
# 1045:{k1,v1} # 原先storage中有的值
# }
self.storage[ident] = origin
# 获取值
def get(self, k):
ident = self.get_ident()
origin = self.storage.get(ident)
if not origin:
return None
return origin.get(k, None)
# 获取一个线程对象
local_obj = Local()
# 获取每一个线程的唯一标识
def task(num):
local_obj.set('name', num)
import time
time.sleep(1)
print(local_obj.get('name'), threading.current_thread().name)
for i in range(20):
th = threading.Thread(target=task, args=(i,), name='线/协程%s' % i)
th.start()
flask中实现的方式
flask中运用了面向对象的一些方法重试简化了实现方式
先补充了解面向对象的姿势:
class Foo():
# 在执行 对象.属性 = 值的时候执行,这里可以写赋值操作
def __setattr__(self,key,value):
print(key,value)
# 在执行 对象.属性的时候,执行, 这里可以写获取对象的属性
def __getattr__(self, item):
print(item)
foo = Foo()
foo.x = 123
foo.x
但是还是有点问题 上面写法: 如果在 初始化操作的时候,会出现递归问题
class Foo():
def __init__(self):
self.storage ={}
def __setattr__(self,key,value):
self.storage = {'k':'v'}
print(key,value)
def __getattr__(self, item):
print(item)
foo = Foo()
foo.x = 123
foo.x
"""
上述办法 会在 __setattr__ 这里产生递归
self.storage = {'k':'v'}
[Previous line repeated 327 more times]
RecursionError: maximum recursion depth exceeded
"""
解决办法
class Foo(object):
def __init__(self):
object.__setattr__(self, "storage", {})
# self.storage = {}
def __setattr__(self, key, value):
storage = self.storage
storage['1024'] = {key: value}
print(storage)
def __getattr__(self, item):
print(item)
"""
{'1024': {'x': 123}}
x
"""
上述问题 接近源码的做法实现一个支持协程线程的自定义类似threading.local 对象
# -*- coding: utf-8 -*-
"""
模仿
flask中运用了一些面向对象的方法: __getattr__,__setattr__
"""
import threading
try:
# 优先用协程的
# 如果是单线程多协程,导入获取协程唯一标识的
from greenlet import getcurrent as get_ident # 协程
except ImportError:
try:
# 如果是多线程导入获取线程唯一标识的
from thread import get_ident
except ImportError:
# 如果是多线程导入获取线程唯一标识的
from _thread import get_ident # 线程
class Local(object):
def __init__(self):
object.__setattr__(self, "__storage__", {})
object.__setattr__(self, "__ident_func__", get_ident)
def __getattr__(self, name):
try:
return self.__storage__[self.__ident_func__()][name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name, value):
ident = self.__ident_func__()
storage = self.__storage__
try:
storage[ident][name] = value
except KeyError:
storage[ident] = {name: value}
# 获取一个线程对象
local_obj = Local()
# 获取每一个线程的唯一标识
def task(num):
local_obj.name = num
import time
time.sleep(1)
print(local_obj.name, threading.current_thread().name)
for i in range(20):
th = threading.Thread(target=task, args=(i,), name='线程%s' % i)
th.start()
"""
0 线程0
3 线程3
4 线程4
1 线程1
2 线程2
8 线程8
7 线程7
5 线程5
6 线程6
10 线程10
9 线程9
11 线程11
12 线程12
15 线程15
14 线程14
13 线程13
19 线程19
16 线程16
18 线程18
17 线程17
"""
flask 源码实现方式
try:
from greenlet import getcurrent as get_ident
except ImportError:
try:
from thread import get_ident
except ImportError:
from _thread import get_ident
class Local(object):
def __init__(self):
"""当类 实例化产生函数的时候初始化的时候被调用"""
object.__setattr__(self, "__storage__", {})
object.__setattr__(self, "__ident_func__", get_ident)
def __call__(self, proxy):
"""
当类实例化的对象 被 调用的时候执行该函数
"""
"""Create a proxy for a name."""
return LocalProxy(self, proxy)
def __release_local__(self):
self.__storage__.pop(self.__ident_func__(), None)
def __getattr__(self, name):
"""定义当用户试图获取一个不存在的属性时的行为"""
try:
return self.__storage__[self.__ident_func__()][name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name, value):
"""定义当一个属性被设置时的行为"""
ident = self.__ident_func__()
storage = self.__storage__
try:
storage[ident][name] = value
except KeyError:
storage[ident] = {name: value}
def __delattr__(self, name):
"""定义当一个属性被删除时的行为"""
try:
del self.__storage__[self.__ident_func__()][name]
except KeyError:
raise AttributeError(name)
PS: flask 中保存请求相关 session相关的对象的在并发的时候的不同(保证数据的安全),都是基于这个 threading.local 实现的