zoukankan      html  css  js  c++  java
  • 频率(限流)

    问题:控制访问频率,在访问的时候加上一定的次数限制

    基本实现

    views.py

    class VisitThrottle(object):
        def allow_request(self, request, view):
            return True       # 可以继续访问
            # return False    # 访问频率太高, 被限制
    
        def wait(self):
            pass
    

    可以进一步的升级,限制 10s 内只能访问3次

    import time
    VISIT_RECORD = {}
    
    class VisitThrottle(object):
        '''
        10s内只能访问3次
        '''
        def allow_request(self, request, view):
            # 1. 获取用户IP
            remote_addr = request.META.get('REMOTE_ADDR')
            ctime = time.time()
            if remote_addr not in VISIT_RECORD:
                VISIT_RECORD[remote_addr] = [ctime, ]
                return True
            history = VISIT_RECORD.get(remote_addr)
    
            while history and history[-1] < ctime - 10:
                history.pop()
    
            if len(history) < 3:
                history.insert(0, ctime)
                return True
                # return True       # 可以继续访问
                # return False      # 访问频率太高, 被限制
    
        def wait(self):
        '''
        还需要等待的时间
        '''
        ctime = time.time()
        return 60 - (ctime - self.history[-1])
    

    源码流程

    和前面一样,也是从 dispatch 开始,到 initial

    def initial(self, request, *args, **kwargs):
        
        # Ensure that the incoming request is permitted
        self.perform_authentication(request)
        self.check_permissions(request)
        # 控制访问频率
        self.check_throttles(request)
    
    def check_throttles(self, request):
        # get_throttles 里面是一个列表生成式
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                self.throttled(request, throttle.wait())
    
    def get_throttles(self):
        """
        Instantiates and returns the list of throttles that this view uses.
        """
        return [throttle() for throttle in self.throttle_classes]
    

    throttle_classes 默认使用配置文件

    class APIView(View):
        ...
        throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
        ...
    

    可以添加到全局使用,首先在 utils 下新建 throttle.py,将视图文件中的类移至 throttle.py,这里修改了 60s内能访问3次

    # throttle.py
    
    import time
    VISIT_RECORD = {}
    
    class VisitThrottle(object):
        '''
        60s内只能访问3次
        '''
        def __init__(self):
            self.history = None
    
        def allow_request(self, request, view):
            # 1. 获取用户IP
            remote_addr = request.META.get('REMOTE_ADDR')
            ctime = time.time()
            if remote_addr not in VISIT_RECORD:
                VISIT_RECORD[remote_addr] = [ctime, ]
                return True
            history = VISIT_RECORD.get(remote_addr)
            self.history = history
    
            while history and history[-1] < ctime - 60:
                history.pop()
    
            if len(history) < 3:
                history.insert(0, ctime)
                return True
                # return True       # 可以继续访问
                # return False      # 访问频率太高, 被限制
    
        def wait(self):
            '''
            还需要等待的时间
            '''
            ctime = time.time()
            return 60 - (ctime - self.history[-1])
    

    然后在配置文件 settings.py 中添加路径

    REST_FRAMEWORK = {
    	...
        'DEFAULT_THROTTLE_CLASSES': ['api.utils.throttle.VisitThrottle']
    }
    

    最后将视图中的局部配置删除即可。

    回到 check_throttles

    def check_throttles(self, request):
        
        for throttle in self.get_throttles():
            # throttle.allow_request 为 False,走下一步,throttled 抛出异常,表示访问频率过多
            if not throttle.allow_request(request, self):
                self.throttled(request, throttle.wait())
    
    def throttled(self, request, wait):
        """
        If request is throttled, determine what kind of exception to raise.
        """
        raise exceptions.Throttled(wait)
    

    频率的内置类

    在自定义频率的时候,为了更加规范,需要继承,并且父类有获取 IP 的方法(可以在 BaseThrottle 中查看),因此这里直接调用父类的方法即可

    from rest_framework.throttling import BaseThrottle
    
    import time
    VISIT_RECORD = {}
    
    class VisitThrottle(BaseThrottle):
        '''
        60s内只能访问3次
        '''
        def __init__(self):
            self.history = None
    
        def allow_request(self, request, view):
            # 1. 获取用户IP,调用父类的方法
            remote_addr = self.get_ident(request)
            
            ctime = time.time()
            if remote_addr not in VISIT_RECORD:
                VISIT_RECORD[remote_addr] = [ctime, ]
                return True
            history = VISIT_RECORD.get(remote_addr)
            self.history = history
    
            while history and history[-1] < ctime - 60:
                history.pop()
    
            if len(history) < 3:
                history.insert(0, ctime)
                return True
                # return True       # 可以继续访问
                # return False      # 访问频率太高, 被限制
    
        def wait(self):
            '''
            还需要等待的时间
            '''
            ctime = time.time()
            return 60 - (ctime - self.history[-1])
    

    进入 BaseThrottle ,发现在其下方有个 SimpleRateThrottle ,也是继承 BaseThrottle 。首先看 SimpleRateThrottle__init__ 方法

    class SimpleRateThrottle(BaseThrottle):
        ... # 省略的内容
        scope = None
        THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
    
        def __init__(self):
            if not getattr(self, 'rate', None):
                # 这里执行了 get_rate 方法
                self.rate = self.get_rate()
            self.num_requests, self.duration = self.parse_rate(self.rate)
    
    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)
    
        try:
            # scope实际上是一个字典的 key,这里在 THROTTLE_RATES 中取值
            # 在上面的代码中看到 THROTTLE_RATES 是一个配置项,获取用户自定义的配置
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)
    

    至此,就可以在配置文件中写一个 60s内能访问3次 的程序,让它自动完成,无需自定义写

    throttle.py

    class VisitThrottle(SimpleRateThrottle):
        scope = "xi"	# scope作为key使用
    

    settings.py

    REST_FRAMEWORK = {
        ... # 省略
        'DEFAULT_THROTTLE_CLASSES': ['api.utils.throttle.VisitThrottle'],
        'DEFAULT_THROTTLE_RATES' : {
            'xi': '3/m'		# m是分钟,每分钟访问3次
        }
    }
    

    这时,配置了访问次数,就会在 return self.THROTTLE_RATES[self.scope] 中获取到,返回给 get_rate 方法,然后 __init__ 中的 rate 拿到的就是 3/m

    class SimpleRateThrottle(BaseThrottle):
        ... # 省略的内容
        scope = None
        THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
    
        def __init__(self):
            if not getattr(self, 'rate', None):
                # '3/m'
                self.rate = self.get_rate()
            # 将字符串 '3/m' 当做参数传递给 parse_rate
            # 走完 parse_rate,num_requests代表3次,duration代表60s
            self.num_requests, self.duration = self.parse_rate(self.rate)
            
        .... # 省略
        
        def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        
        # rate就是 '3/m'
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)
    

    此时,构造函数走完,接着查看 allow_request

    def allow_request(self, request, view):
        
        if self.rate is None:
            return True
    	
        # 内置提供的访问记录放在了缓存中,通过 get_cache_key 实现
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True
    	
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()
    
    # 来到 get_cache_key,源码上并没有写什么,这表示是让我们自己写的
    
    def get_cache_key(self, request, view):
        
        raise NotImplementedError('.get_cache_key() must be overridden')
    
    # get_cache_key 实际上是表示能够唯一标识的方法,所以返回值可以是获取IP,用来表示谁的访问记录
    # throttle.py
    
    class VisitThrottle(SimpleRateThrottle):
        scope = "xi"
    
        def get_cache_key(self, request, view):
            return self.get_ident(request)	# 获取IP
    

    回到 allow_request

    def allow_request(self, request, view):
        
        if self.rate is None:
            return True
    	
        # 内置提供的访问记录放在了缓存中,通过 get_cache_key 实现
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True
        
    	# 去缓存中取出所有记录
        # cache = default_cache,是django内置的缓存
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()	 # timer() = time.time(),获取当前时间
        
        # Drop any requests from the history which have now passed the
        # throttle duration
        # 这里与上面自定义的相同
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()
    
    
    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        # 如果成功,加到历史记录中
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True
    
    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False
    
    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration
    
        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None
    
        return remaining_duration / float(available_requests)
    

    照样是前三次可以访问,后面再访问需要等一分钟,这是对匿名用户的控制

    也可以对登录的用户进行控制,但在全局的设置中,不能既有匿名的,还有登录的。这时,就可以将登录用户的访问控制设为全局,匿名用户使用局部的设置。

    settings.py

    REST_FRAMEWORK = {
        'DEFAULT_AUTHENTICATION_CLASSES': ['api.utils.auth.FirstAuthentication', 'api.utils.auth.Authentication'],
        # 'DEFAULT_AUTHENTICATION_CLASSES': ['api.utils.auth.FirstAuthentication', ],
        'UNAUTHENTICATED_USER': None,
        'UNAUTHENTICATED_TOKEN': None,
        'DEFAULT_PERMISSION_CLASSES': ['api.utils.permission.SVIPPermission'],
        'DEFAULT_THROTTLE_CLASSES': ['api.utils.throttle.UserThrottle'],	# 登录用户
        'DEFAULT_THROTTLE_RATES' : {
            'xi': '3/m',
            'xiUser': '10/m'
        }
    }
    

    throttle.py

    # 匿名用户
    class VisitThrottle(SimpleRateThrottle):
        scope = "xi"
    
        def get_cache_key(self, request, view):
            return self.get_ident(request)
    
    # 登录用户
    class UserThrottle(SimpleRateThrottle):
        scope = "xiUser"
    
        def get_cache_key(self, request, view):
            return request.user.username
    

    views.py

    from django.shortcuts import render, HttpResponse
    from django.http import JsonResponse
    from rest_framework.views import APIView
    from api import models
    from api.utils.permission import SVIPPermission, MyPermission
    from api.utils.throttle import VisitThrottle
    
    ORDER_DICT = {
        1: {
            'name': 'qiu',
            'age': 18,
            'gender': '男',
            'content': '...'
        },
    
        2: {
            'name': 'xi',
            'age': 19,
            'gender': '男',
            'content': '.....'
        }
    }
    
    def md5(user):
        import hashlib
        import time
    
        ctime = str(time.time())
    
        m = hashlib.md5(bytes(user, encoding='utf-8'))
        m.update(bytes(ctime, encoding='utf-8'))
        return m.hexdigest()
    
    
    class AuthView(APIView):
        authentication_classes = []
        permission_classes = []
        throttle_classes = [VisitThrottle]	# 为匿名用户设置频率控制
    
        def post(self, request, *args, **kwargs):
            ret = {'code': 1000, 'msg': None}
            try:
                user = request._request.POST.get('username')
                pwd = request._request.POST.get('password')
                obj = models.UerInfo.objects.filter(username=user, password=pwd).first()
                if not obj:
                    ret['code'] = 1001
                    ret['msg'] = '用户名或密码错误'
                # 为登录用户创建token
                else:
                    token = md5(user)
                    # 存在就更新, 不存在就创建
                    models.UserToken.objects.update_or_create(user=obj, defaults={'token': token})
                    ret['token'] = token
            except Exception as e:
                ret['code'] = 1002
                ret['msg'] = '请求异常'
    
            return JsonResponse(ret)
    
    class OrderView(APIView):
        '''
        订单相关业务(只有SVIP用户有权限)
        '''
        def get(self, request, *args, **kwargs):
            ret = {'code': 1000, 'msg': None, 'data': None}
            try:
                ret['data'] = ORDER_DICT
            except Exception as e:
                pass
            return JsonResponse(ret)
    
    class UserInfoView(APIView):
        '''
        用户中心(普通用户、VIP有权限)
        '''
        permission_classes = [MyPermission]
        def get(self, request, *args, **kwargs):
            return HttpResponse('用户信息')
    

    总结

    使用

    • 类,继承 BaseThrottle ,实现 allow_requestwait

    • 类,继承 SimpleRateThrottle ,实现 get_cache_keyscope = "xi"(配置文件中的key)

    • 局部:throttle_classes = [VisitThrottle]

    • 全局:配置 settings.py

  • 相关阅读:
    无需数学基础如进行机器学习
    机器学习路线图
    机器学习的最佳学习路线原来只有四步
    机器学习是否需要完整扎实的数学基础?
    可无注解的 SpringBoot API文档生成工具
    JApiDocs是一个无需额外注解、开箱即用的SpringBoot接口文档生成工具
    python 两个文件夹里的文件名对比
    Navicat for MySQL 激活方法
    mysql —— 利用Navicat 导出和导入数据库
    HTTP请求错误码大全(转)
  • 原文地址:https://www.cnblogs.com/qiuxirufeng/p/10458785.html
Copyright © 2011-2022 走看看