比如我们有一个用户大转盘抽奖的功能,需要规定用户在一个小时内只能抽奖3次,那此时对接口的访问频率限制就显得尤为重要
其实在restframework中已经为我们提供了频率限制的组件
先捋一下请求到APIview的过程:
as_view-->dispatch -->initialize_request-->initial-->perform_authentication-->check_permissions-->check_throttles(就是在这里实现了频率限制)
1 def initial(self, request, *args, **kwargs): 2 """ 3 Runs anything that needs to occur prior to calling the method handler. 4 """ 5 self.format_kwarg = self.get_format_suffix(**kwargs) 6 7 # Perform content negotiation and store the accepted info on the request 8 neg = self.perform_content_negotiation(request) 9 request.accepted_renderer, request.accepted_media_type = neg 10 11 # Determine the API version, if versioning is in use. 12 version, scheme = self.determine_version(request, *args, **kwargs) 13 request.version, request.versioning_scheme = version, scheme 14 15 # Ensure that the incoming request is permitted 16 17 # 身份验证 18 self.perform_authentication(request) 19 # 权限验证 20 self.check_permissions(request) 21 # 访问频率限制 22 self.check_throttles(request)
那check_throttles到底做了什么呢?
1 def check_throttles(self, request): 2 """ 3 Check if request should be throttled. 4 Raises an appropriate exception if the request is throttled. 5 """ 6 for throttle in self.get_throttles(): 7 if not throttle.allow_request(request, self): 8 self.throttled(request, throttle.wait())
其实和check_permissions很相似,分为以下几个步骤:
1. self.get_throttles() 通过列表推导式拿到了注册的throttle类,并将其实例化返回
1 def get_throttles(self): 2 """ 3 Instantiates and returns the list of throttles that this view uses. 4 """ 5 return [throttle() for throttle in self.throttle_classes]
2. throttle.allow_request说明throttle类中一定要实现allow_request方法,并且返回值为True表示正确允许访问,就执行下次循环,检查下一个频率控制对象
按照之前的套路,throttle组件中应该有个基础的throttle类,找一下:
1 class BaseThrottle(object): 2 """ 3 Rate throttling of requests. 4 """ 5 6 def allow_request(self, request, view): 7 """ 8 Return `True` if the request should be allowed, `False` otherwise. 9 """ 10 raise NotImplementedError('.allow_request() must be overridden') 11 12 def get_ident(self, request): 13 """ 14 Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR 15 if present and number of proxies is > 0. If not use all of 16 HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. 17 """ 18 xff = request.META.get('HTTP_X_FORWARDED_FOR') 19 remote_addr = request.META.get('REMOTE_ADDR') 20 num_proxies = api_settings.NUM_PROXIES 21 22 if num_proxies is not None: 23 if num_proxies == 0 or xff is None: 24 return remote_addr 25 addrs = xff.split(',') 26 client_addr = addrs[-min(num_proxies, len(addrs))] 27 return client_addr.strip() 28 29 return ''.join(xff.split()) if xff else remote_addr 30 31 def wait(self): 32 """ 33 Optionally, return a recommended number of seconds to wait before 34 the next request. 35 """ 36 return None
3. 如果返回值为False,就执行 self.throttled
# 抛出Throttled异常
1 def throttled(self, request, wait): 2 """ 3 If request is throttled, determine what kind of exception to raise. 4 """ 5 raise exceptions.Throttled(wait)
# Throttled异常类
1 class Throttled(APIException): 2 status_code = status.HTTP_429_TOO_MANY_REQUESTS 3 default_detail = _('Request was throttled.') 4 extra_detail_singular = 'Expected available in {wait} second.' 5 extra_detail_plural = 'Expected available in {wait} seconds.' 6 default_code = 'throttled' 7 8 def __init__(self, wait=None, detail=None, code=None): 9 if detail is None: 10 detail = force_text(self.default_detail) 11 if wait is not None: 12 wait = math.ceil(wait) 13 detail = ' '.join(( 14 detail, 15 force_text(ungettext(self.extra_detail_singular.format(wait=wait), 16 self.extra_detail_plural.format(wait=wait), 17 wait)))) 18 self.wait = wait 19 super(Throttled, self).__init__(detail, code)
实现:
一般来说,接口如果不做登录限制,那就会允许匿名用户和已登录用户都能访问。所以这个接口就要考虑能对匿名用户和登录用户都进行访问频率限制:
思路:
已经登录用户可以根据身份做判断,固定时间内,同一个用户的身份只能访问限定次数
未登录用户可通过IP地址判断,对同一个IP的请求进行限制
1 class MyThrottle(BaseThrottle): 2 ctime = time.time 3 4 def get_ident(self, request): 5 6 """ 7 reuqets.user有值,不是匿名用户 8 根据用户IP和代理IP,当做请求者的唯一IP 9 Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR 10 if present and number of proxies is > 0. If not use all of 11 HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. 12 """ 13 14 user = request.user 15 if user: 16 # 有用户身份,直接返回用户 17 return user 18 19 xff = request.META.get('HTTP_X_FORWARDED_FOR') 20 remote_addr = request.META.get('REMOTE_ADDR') 21 num_proxies = api_settings.NUM_PROXIES 22 23 if num_proxies is not None: 24 if num_proxies == 0 or xff is None: 25 return remote_addr 26 addrs = xff.split(',') 27 client_addr = addrs[-min(num_proxies, len(addrs))] 28 return client_addr.strip() 29 30 return ''.join(xff.split()) if xff else remote_addr 31 32 def allow_request(self, request, view): 33 """ 34 是否仍然在允许范围内 35 Return `True` if the request should be allowed, `False` otherwise. 36 :param request: 37 :param view: 38 :return: True,表示可以通过;False表示已超过限制,不允许访问 39 """ 40 # 获取用户唯一标识(如:IP) 41 42 # 允许一分钟访问10次 43 num_request = 10 44 time_request = 60 45 46 now = self.ctime() 47 ident = self.get_ident(request) 48 self.ident = ident 49 if ident not in RECORD: 50 RECORD[ident] = [now, ] 51 return True 52 history = RECORD[ident] 53 while history and history[-1] <= now - time_request: 54 history.pop() 55 if len(history) < num_request: 56 history.insert(0, now) 57 return True 58 59 def wait(self): 60 """ 61 多少秒后可以允许继续访问 62 Optionally, return a recommended number of seconds to wait before 63 the next request. 64 """ 65 last_time = RECORD[self.ident][0] 66 now = self.ctime() 67 return int(60 + last_time - now)
1 class MemberPrograms(APIView): 2 throttle_classes = [MyThrottle, ] 3 4 def get(self, request): 5 programs = MemberProgram.objects.all().values() 6 return JsonResponse(list(programs), safe=False)
测试:
其实restframework已经帮我们实现了一些简单的频率限制类 我们只需要稍加修改,比如SimpleRateThrottle
class SimpleRateThrottle(BaseThrottle):
""" A simple cache implementation, that only requires `.get_cache_key()` to be overridden. The rate (requests / seconds) is set by a `throttle` attribute on the View class. The attribute is a string of the form 'number_of_requests/period'. Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') Previous request information used for throttling is stored in the cache. """ cache = default_cache timer = time.time cache_format = 'throttle_%(scope)s_%(ident)s' scope = None # 频率的key名 # 必须设置 THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES # {scope:rate} 比如:{‘scopre’:10/m} def __init__(self): if not getattr(self, 'rate', None): self.rate = self.get_rate() self.num_requests, self.duration = self.parse_rate(self.rate) def get_cache_key(self, request, view): """ 必须重写,返回一个唯一的身份值作为缓存的key Should return a unique cache-key which can be used for throttling. Must be overridden. May return `None` if the request should not be throttled. """ raise NotImplementedError('.get_cache_key() must be overridden') def get_rate(self): """ Determine the string representation of the allowed request rate. """ if not getattr(self, 'scope', None): # 那不到scope就抛出异常 msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % self.__class__.__name__) raise ImproperlyConfigured(msg) try: # 从{scope:rate}中尝试取rate return self.THROTTLE_RATES[self.scope] except KeyError: # 取不到就抛出异常,所以rate也必须设置 msg = "No default throttle rate set for '%s' scope" % self.scope raise ImproperlyConfigured(msg) def parse_rate(self, rate): """ Given the request rate string, return a two tuple of: <allowed number of requests>, <period of time in seconds> """ if rate is None:
# 未设置频率就说明不做限制 return (None, None) num, period = rate.split('/') num_requests = int(num) duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] return (num_requests, duration) def allow_request(self, request, view): """ Implement the check to see if the request should be throttled. On success calls `throttle_success`. On failure calls `throttle_failure`. """ if self.rate is None:
# 没有频率限制 return True self.key = self.get_cache_key(request, view) if self.key is None:
# 没有key,说明没有访问记录,允许访问 return True
# 获取历史请求时间列表 self.history = self.cache.get(self.key, [])
获取当前时间 self.now = self.timer() # Drop any requests from the history which have now passed the # throttle duration
# 如果历史访问时间列表有记录,并且列表记录中最早的访问时间小于当前时间-限制时间,说明已经过了限制时间
# 例如,假设请求都在同一分钟内比较容易理解:list = [4,10,23] 表示在第4,10,25秒分表访问了一次
# 当前时间56秒 限制是20秒内3次:
# 56-20 = 25 只要列表最后一个元素小于25,那说明已经过了限制时间,距离最早的一次访问,就删除列表中的23: [4,10]
# 持续循环检查,直到[]为空,或者这次请求距离列表最早的请求小于频率限制时间
while self.history and self.history[-1] <= self.now - self.duration: # 删除掉这条记录,pop()删除列表最后一个元素
self.history.pop()
# 如果列表中的访问时间记录次数等于限制次数,说明没有被pop掉的记录,访问请求失败,否则请求成功 if len(self.history) >= self.num_requests: return self.throttle_failure() return self.throttle_success() def throttle_success(self): """ Inserts the current request's timestamp along with the key into the cache. """
# 成功,就将当前时间插入访问记录列表的最前面
self.history.insert(0, self.now) self.cache.set(self.key, self.history, self.duration) return True def throttle_failure(self): """ Called when a request to the API has failed due to throttling. """ return False def wait(self): """ Returns the recommended next request time in seconds. """ if self.history: remaining_duration = self.duration - (self.now - self.history[-1]) else: remaining_duration = self.duration available_requests = self.num_requests - len(self.history) + 1 if available_requests <= 0: return None return remaining_duration / float(available_requests)
通过继承SimpleRateThrottle类实现:
class MyThrottle(SimpleRateThrottle): rate = '10/m' # 每分钟只能访问10次 def get_cache_key(self, request, view): user = request.user if user: return user xff = request.META.get('HTTP_X_FORWARDED_FOR') remote_addr = request.META.get('REMOTE_ADDR') num_proxies = api_settings.NUM_PROXIES if num_proxies is not None: if num_proxies == 0 or xff is None: return remote_addr addrs = xff.split(',') client_addr = addrs[-min(num_proxies, len(addrs))] return client_addr.strip()
在setting.py中配置:
REST_FRAMEWORK = { "DEFAULT_THROTTLE_CLASSES":[ permissions.utils.MyThrottle, ], 'DEFAULT_THROTTLE_RATES':{ 'scope':'10/minute', } }