zoukankan      html  css  js  c++  java
  • python-django rest framework框架之dispatch方法源码分析

    1.Django的 CBV 中在请求到来之后,都要执行dispatch方法,dispatch方法根据请求方式不同触发 get/post/put等方法

    class APIView(View):
        def dispatch(self, request, *args, **kwargs):#1.1 把wsgi的request进行封装
            request = self.initialize_request(request, *args, **kwargs)
            self.request = request     #此时的self.request 是rest_framework的Request对象,它里面比wsgi的request多了一些东西
    
            try:
                #1.2 进行 初始化     :版本,认证,权限,访问频率
                self.initial(request, *args, **kwargs)
    #1.3 反射执行get等方法 if request.method.lower() in self.http_method_names: handler = getattr(self, request.method.lower(), self.http_method_not_allowed) else: handler = self.http_method_not_allowed response = handler(request, *args, **kwargs) except Exception as exc: response = self.handle_exception(exc) #1.4 把返回的response 再进行封装 self.response = self.finalize_response(request, response, *args, **kwargs) #1.5 返回 , dispatch方法一定要有返回值,因为get等方法返回的结果要返回给前端 return self.response

    第1.1步:

    from rest_framework.request import Request
    
    class APIView(View):
        def initialize_request(self, request, *args, **kwargs):
           #返回了一个rest_framework的Request对象
            return Request(
                request,
           #1.1.1 parsers
    =self.get_parsers(),
           #1.1.2 authenticators
    =self.get_authenticators(),
           #1.1.3 negotiator
    =self.get_content_negotiator(), )

    第1.1.1步:

      pass

    第1.1.2步:

    class APIView(View):
        def get_authenticators(self):
         #self.authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES 是从配置文件中取数据,从变量名不难看出有可能是很多类的列表
    return [auth() for auth in self.authentication_classes]
         #self.authenticators = 一个多个对象的列表

    第1.2步:

    class APIView(View):
        def initial(self, request, *args, **kwargs):
         #1.2.0版本相关
         version, scheme = self.determine_version(request, *args, **kwargs)
         request.version, request.versioning_scheme = version, scheme
          # request.version, request.versioning_scheme = 版本,和相应的版本类对象
          # 1.2.0.0
          # 既然 request.version 是版本, 那request.versioning_scheme 代表什么呢?

    #1.2.1认证相关 self.perform_authentication(request) #1.2.2权限相关 self.check_permissions(request) #1.2.3访问频率相关 self.check_throttles(request)

    第1.2.0步:

    获取版本数据有五种接收方式:url上传参,url,子域名,namespace,请求头

    from rest_framework.versioning import QueryParameterVersioning,URLPathVersioning,HostNameVersioning  
       方式一:
        # 基于url传参  http://127.0.0.1:8001/api/users/?version=v1
        # versioning_class = QueryParameterVersioning

        方式二:
        # 基于URL http://127.0.0.1:8001/api/v2/users/
        # versioning_class = URLPathVersioning
         # url 配置: url(r'^api/(?P<version>[v1|v2]+)/', include('api.urls')),
        方式三:
        # 基于子域名 http://v1.luffy.com/users/
        # versioning_class = HostNameVersioning
      
     配置文件:
      REST_FRAMEWORK = {
      'VERSION_PARAM':'version',
      'DEFAULT_VERSION':'v1',
      'ALLOWED_VERSIONS':['v1','v2'],
      # 'DEFAULT_VERSIONING_CLASS':"rest_framework.versioning.HostNameVersioning",
      'DEFAULT_VERSIONING_CLASS':"rest_framework.versioning.URLPathVersioning",
      # 'DEFAULT_VERSIONING_CLASS':"rest_framework.versioning.QueryParameterVersioning",
      }
     
    class APIView(View):
        def determine_version(self, request, *args, **kwargs):
            if self.versioning_class is None:     # self.versioning_class = api_settings.DEFAULT_VERSIONING_CLASS
            #如果没有配置 版本相关的类
                return (None, None)
            scheme = self.versioning_class()      # 实例化版本类对象
            #调用对象的determine_version 方法
            return (scheme.determine_version(request, *args, **kwargs), scheme)

    方式一的 determine_version 方法:

    class QueryParameterVersioning(BaseVersioning):
       invalid_version_message = _('Invalid version in query parameter.')
    def determine_version(self, request, *args, **kwargs): # 相当于 request.GET.get() version = request.query_params.get(self.version_param, self.default_version) #self.version_param = api_settings.VERSION_PARAM #self.default_version = api_settings.DEFAULT_VERSION
         # 见下面
    if not self.is_allowed_version(version): raise exceptions.NotFound(self.invalid_version_message) return version

    class BaseVersioning(object):
        def is_allowed_version(self, version):
         #如果没有配置允许的版本列表,代表没限制
            if not self.allowed_versions: #self.allowed_versions = api_settings.ALLOWED_VERSIONS
                return True
            return ((version is not None and version == self.default_version) or
                    (version in self.allowed_versions))

    方式二的 determine_version 方法:

    class URLPathVersioning(BaseVersioning):
        invalid_version_message = _('Invalid version in URL path.')
        
        def determine_version(self, request, *args, **kwargs):
            #从传过来的字典中获取版本
            version = kwargs.get(self.version_param, self.default_version)
            
            if not self.is_allowed_version(version):
                raise exceptions.NotFound(self.invalid_version_message)
            return version
            
            
    class BaseVersioning(object):
        def is_allowed_version(self, version):
            if not self.allowed_versions:
                return True
            return ((version is not None and version == self.default_version) or
                    (version in self.allowed_versions))

    方式三的 determine_version 方法:

    class HostNameVersioning(BaseVersioning):
        hostname_regex = re.compile(r'^([a-zA-Z0-9]+).[a-zA-Z0-9]+.[a-zA-Z0-9]+$')
        invalid_version_message = _('Invalid version in hostname.')
        
        def determine_version(self, request, *args, **kwargs):
            hostname, separator, port = request.get_host().partition(':')
            match = self.hostname_regex.match(hostname)
            if not match:
                return self.default_version
            version = match.group(1)
            if not self.is_allowed_version(version):
                raise exceptions.NotFound(self.invalid_version_message)
            return version
            
    class BaseVersioning(object):
        def is_allowed_version(self, version):
            if not self.allowed_versions:
                return True
            return ((version is not None and version == self.default_version) or
                    (version in self.allowed_versions))

    第1.2.0.0步:     request.versioning_scheme  是用来反向生成url 的

            urlpatterns = [
                url(r'^api/(?P<version>[v1|v2]+)/', include('api.urls')),
                url(r'^api/', include('api.urls')),
            ]
        
            urlpatterns = [
                url(r'^users/', views.UsersView.as_view(),name='u'),
            ]
            
            
            # 当前版本一样的URL
            # url = request.versioning_scheme.reverse(viewname='u',request=request)    #不用传版本的参数
            # print(url)
    
            # 当前版本不一样的URL
            # from django.urls import reverse
            # url = reverse(viewname='u',kwargs={'version':'v2'})
            # print(url)

     方式一的 reverse方法

    class QueryParameterVersioning(BaseVersioning):
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            #调用基类的 reverse方法
            url = super(QueryParameterVersioning, self).reverse(
                viewname, args, kwargs, request, format, **extra
            )
            # 之所以在反向生成的时候不用传版本参数, 是因为这步帮你处理了
            if request.version is not None:
                return replace_query_param(url, self.version_param, request.version)
            return url
            
    class BaseVersioning(object):    
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            return _reverse(viewname, args, kwargs, request, format, **extra)
            
    def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
    
        if format is not None:
            kwargs = kwargs or {}
            kwargs['format'] = format
        # 最后还是调用了 django的 reverse方法 生成url
        url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
        if request:
            return request.build_absolute_uri(url)
        return url

     方式二的 reverse方法

    class URLPathVersioning(BaseVersioning):
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            if request.version is not None:
                kwargs = {} if (kwargs is None) else kwargs
                #给 version赋值
                kwargs[self.version_param] = request.version
    
            #调用基类的reverse方法
            return super(URLPathVersioning, self).reverse(
                viewname, args, kwargs, request, format, **extra
            )
            
    class BaseVersioning(object):
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            return _reverse(viewname, args, kwargs, request, format, **extra)
            
            
    def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
    
        if format is not None:
            kwargs = kwargs or {}
            kwargs['format'] = format
        #最后还是调用了django的reverse方法生成url
        url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
        if request:
            return request.build_absolute_uri(url)
        return url

    方式三的 reverse方法

    方式三 HostNameVersioning 类中没有 reverse方法,所以直接去基类中找
    class BaseVersioning(object):
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            return _reverse(viewname, args, kwargs, request, format, **extra)
             
    def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
        if format is not None:
            kwargs = kwargs or {}
            kwargs['format'] = format
        #最后还是调用了django的reverse方法生成url
        url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
        if request:
            return request.build_absolute_uri(url)
        return url

    第1.2.1步: 认证相关

    class APIView(View):
        def perform_authentication(self, request):
         #1.2.1.1
            request.user

    第1.2.1.1步:

    class Request(object):    
        @property
        def user(self):       #此时的self 是 rest_framework的request对象
            if not hasattr(self, '_user'):
                with wrap_attributeerrors():
              #1.2.1.1.1 self._authenticate()
    return self._user

    1.2.1.1.1步:

    class Request(object):    
        def _authenticate(self):
            #此时的self 是 rest_framework的request对象
    for authenticator in self.authenticators: #self.authenticators = 一个多个对象的列表 try:
              #执行每个对象的authenticate方法 user_auth_tuple
    = authenticator.authenticate(self)   #从变量的名不难看出 返回了一个元组 except exceptions.APIException: self._not_authenticated() raise if user_auth_tuple is not None: self._authenticator = authenticator
              #赋值, request.user和request.auth 并返回 self.user, self.auth
    = user_auth_tuple return self._not_authenticated()

    第1.3步反射执行get等方法

    我们可以自定义一个简单的用户认证

    class MyAuth(object):
        def authenticate(self,request):
            return "1111","222"
        
    class Host(APIView):
        authentication_classes=[MyAuth]
        def get(self,request):
            print(request.user)    #1111
            print(request.auth)   #222
            return HttpResponse("666")

    认证

    - 认证
                    - 局部  : 只是一个类内的一些接口用
                        class MyAuthentication(BaseAuthentication):
    
                            def authenticate(self, request):
                     '''
                     有三种返回值: None 表示我不管,交给下一个进行认证, 元组表示认证成功, 抛出异常
                     '''
    # return None ,我不管,交给下一个进行认证 token = request.query_params.get('token') obj = models.UserInfo.objects.filter(token=token).first() if obj: return (obj.username,obj) raise APIException('用户认证失败') #  认证失败时 需要注册restframework 进行友好的展示错误信息 class AuthView(APIView): authentication_classes=[MyAuthentication,]
                   ....

    - 全局 : 多个类都需要用到认证的时候 就需要在配置文件中配置了 REST_FRAMEWORK = { 'UNAUTHENTICATED_USER': None, 'UNAUTHENTICATED_TOKEN': None, "DEFAULT_AUTHENTICATION_CLASSES": [ "app02.utils.MyAuthentication", ], } class HostView(APIView):
                   #authentication_classes=[] #如果在类中写了authentication_classes 等于一个空列表,那就表示 这个类内的接口不需要认证
    def get(self,request,*args,**kwargs): return HttpResponse('主机列表')

             - 类的继承:
              ********utils.py

              from rest_framework.authentication import BaseAuthentication
              from rest_framework import exceptions

              class LuffyTokenAuthentication(BaseAuthentication):
                  keyword = 'Token'

                  def authenticate(self, request):
                      """
                      Authenticate the request and return a two-tuple of (user, token).
                      """

                      token = request.query_params.get('token')
                      if not token:
                          raise exceptions.AuthenticationFailed('验证失败')
                      return self.authenticate_credentials(token)

                  def authenticate_credentials(self, token):
                      from luffy.models import UserAuthToken
                      try:
                          token_obj = UserAuthToken.objects.select_related('user').get(token=token)
                      except Exception as e:
                          raise exceptions.AuthenticationFailed(_('Invalid token.'))

                      return (token_obj.user, token_obj)

             class AuthAPIView(object):
                 authentication_classes = [LuffyTokenAuthentication,]
        
             *******views.py
        
              from utils import AuthAPIView

              class ShoppingCarView(AuthAPIView,APIView):   #注意继承的顺序

                  def get(self,request,*args,**kwargs):
                      pass

     第1.2.2步: 权限相关

    class APIView(View):
        def check_permissions(self, request):
         #1.2.2.1 permission是每一个权限类的对象
    for permission in self.get_permissions():
           #1.2.2.2
    if not permission.has_permission(request, self): #我猜has_permission方法返回的值是 True/False,True代表有权限
              # 1.2.2.3 如果没有权限执行  self.permission_denied( request, message
    =getattr(permission, 'message', None) #根据这句话可以发现,可以在自定义类中写 message='无权访问' )

    第1.2.2.1步

    class APIView(View):
        def get_permissions(self):
            return [permission() for permission in self.permission_classes]

    第1.2.2.2步: 我们可以在自定义类中写这个方法,通过一些逻辑判断后让它返回True或False

    第1.2.2.3步

    class APIView(View):
        def permission_denied(self, request, message=None):
            # request.authenticators = 一个对象列表
            # 1.2.2.3.1  如果认证成功,不执行此步
            if request.authenticators and not request.successful_authenticator:
                raise exceptions.NotAuthenticated()   #抛出 未进行认证的异常,这里可以传错误信息 detail='xxx'
         # 1.2.2.3.2 抛出异常
    raise exceptions.PermissionDenied(detail=message)

    第1.2.2.3.1步

    class Request(object):        
        @property
        def successful_authenticator(self):
            #self._authenticator  是 最后的那个认证类的对象
            return self._authenticator

    第1.2.2.3.2步

    没有init方法,执行父类的        
    class PermissionDenied(APIException):
        default_detail = _('You do not have permission to perform this action.')
        
    class APIException(Exception):
        default_detail = _('A server error occurred.')

        def __init__(self, detail=None, code=None):
            if detail is None:
                detail = self.default_detail
            
            self.detail = _get_error_details(detail, code)

        def __str__(self):
            return six.text_type(self.detail)

    认证和权限联合使用:

    class MyAuthentication(BaseAuthentication):
    
        def authenticate(self, request):
            token = request.query_params.get('token')
            obj = models.UserInfo.objects.filter(token=token).first()
            if obj:
                return (obj.username,obj)
            return None
    
        def authenticate_header(self, request):
            """
            Return a string to be used as the value of the `WWW-Authenticate`
            header in a `401 Unauthenticated` response, or `None` if the
            authentication scheme should return `403 Permission Denied` responses.
            """
            # return 'Basic realm="api"'
            pass
    
    class MyPermission(object):
        message = "无权访问"
        def has_permission(self,request,view):
            if request.user:
                return True
            return False
    
    class AdminPermission(object):
        message = "无权访问"
        def has_permission(self,request,view):
            if request.user == 'alex':
                return True
            return False
    
    class HostView(APIView):
        """
        匿名用户和用户都能访问
        """
        authentication_classes = [MyAuthentication,]
        permission_classes = []
        def get(self,request,*args,**kwargs):
    
            return Response('主机列表')
    
    class UserView(APIView):
        """
        用户能访问
        """
        authentication_classes = [MyAuthentication, ]
        permission_classes = [MyPermission,]
        def get(self,request,*args,**kwargs):
            return Response('用户列表')
    
    class SalaryView(APIView):
        """
        管理员能访问
        """
        authentication_classes = [MyAuthentication, ]
        permission_classes = [MyPermission,AdminPermission,]
        def get(self,request,*args,**kwargs):
    
            return Response('薪资列表')
    
        #自定义未认证的错误信息
        def permission_denied(self, request, message=None):
            if request.authenticators and not request.successful_authenticator:
                raise exceptions.NotAuthenticated(detail='xxxxxxxx')
            raise exceptions.PermissionDenied(detail=message)

     在全局内使用权限需配置:

                        REST_FRAMEWORK = {
                                "DEFAULT_PERMISSION_CLASSES": [
                                     "app02.utils.MyPermission",
                                ],
                        }

    第1.2.3步: 访问频率相关  

    class APIView(View):
        def check_throttles(self, request):
         #1.2.3.1
    for throttle in self.get_throttles(): if not throttle.allow_request(request, self): #从这句代码可以看出,自定义的限流类 可以 写allow_request 方法,返回值应该是 True 表示通行或 False 表示限制
              # 1.2.3.2 限制的情况下执行 self.throttled(request, throttle.wait()) #从这句代码可以看出,自定义的限流类 中 要写 wait 方法,而且返回值必须是数字类型或者 None

    第1.2.3.1步:

    class APIView(View):
        def get_throttles(self):
            return [throttle() for throttle in self.throttle_classes]   #返回一个 限流 类的对象列表

    第1.2.3.2步:

    class APIView(View):
        def throttled(self, request, wait):
            raise exceptions.Throttled(wait)  #类的实例化  ,抛出异常是一个对象,那在打印的时候一定调用了 __str__方法
    class Throttled(APIException):
        default_detail = _('Request was throttled.')
        extra_detail_plural = 'Expected available in {wait} seconds.'
        def __init__(self, wait=None, detail=None, code=None):
            if detail is None:
                detail = force_text(self.default_detail)
            if wait is not None:
                wait = math.ceil(wait)
                detail = ' '.join((              # 把 wait方法的返回值和 detail 放到了一起,作为新的参数 传给了父类进行初始化
                    detail,
                    force_text(ungettext(self.extra_detail_singular.format(wait=wait),
                                         self.extra_detail_plural.format(wait=wait),
                                         wait))))
            self.wait = wait
            super(Throttled, self).__init__(detail, code)    #传给了父类进行初始化
    class APIException(Exception):
        default_detail = _('A server error occurred.')
        default_code = 'error'
    
        def __init__(self, detail=None, code=None):
            if detail is None:
                detail = self.default_detail
            if code is None:
                code = self.default_code
         #把 传过来的 错误信息 detail 赋值给了 self.detail
            self.detail = _get_error_details(detail, code)   
    
        def __str__(self):
            return six.text_type(self.detail)   #打印错误信息

    自定义的访问频率限制

    class MyThrottle(BaseThrottle):
        def allow_request(self,request,view):
            return False
            
        def wait(self):
            return 22     #表示还需22秒才能访问
        
    class User(APIView):
        throttle_classes=[MyThrottle,]
        
        def get(self,request,*args,**kwargs):
            return Response('333333')

    自定义一个对匿名用户的限流

    RECORD={}
    
    class MyThrottle(BaseThrottle):
        
        def allow_request(self,request,view):
            """
            返回False,限制
            返回True,通行
            """
             a. 对匿名用户进行限制:每个用户1分钟允许访问10次
                - 获取用户IP request 1.1.1
            """
            import time
            ctime = time.time()
            ip = "1.1.1"
            if ip not in RECORD:
                RECORD[ip] = [ctime,]
            else:
                # [4507862389234,3507862389234,2507862389234,1507862389234,]
                time_list = RECORD[ip]
                while True:
                    val = time_list[-1]
                    if (ctime-60) > val:
                        time_list.pop()
                    else:
                        break
                if len(time_list) > 10:
                    return False
                time_list.insert(0,ctime)
            return True
        def wait(self):
            import time
            ctime = time.time()
            first_in_time = RECORD["1.1.1"][-1]
            wt = 60 - (ctime - first_in_time)
            return wt
    class User(APIView):
    
        throttle_classes=[MyThrottle]
        
        def get(self,request,*args,**kwargs):
    
            return Response('333333')
    View Code

    但是这样写觉得很麻烦,故有更加简单的写法:如下 继承SimpleRateThrottle

    class MySimpleRateThrottle(SimpleRateThrottle):
        scope = "wdp"
    
        def get_cache_key(self, request, view):
            return self.get_ident(request)
         # 可以返回 None 表示 不限流
    class LimitView(APIView): authentication_classes = [] permission_classes = [] throttle_classes=[MySimpleRateThrottle,] def get(self,request,*args,**kwargs): return Response('控制访问频率示例') def throttled(self, request, wait): #自定义错误信息 class MyThrottled(exceptions.Throttled): default_detail = '请求被限制.' extra_detail_plural = '还需要再等待{wait}' raise MyThrottled(wait) 需要在配置文件中设置: REST_FRAMEWORK = { 'DEFAULT_THROTTLE_RATES':{ 'wdp':'5/minute', } }

    这种简单写法的源码流程:  首先也是要执行  allow_request 方法 ,它自己类中没有就去基类中找

    class SimpleRateThrottle(BaseThrottle):
        def allow_request(self, request, view):
            #简1   (不执行)
            if self.rate is None:
                return True
            #简2   如果你不重写.get_cache_key() 方法 就抛出异常
            self.key = self.get_cache_key(request, view)
            if self.key is None:
                return True
    
            self.history = self.cache.get(self.key, [])   #通过唯一标识 到 cache中取相当于刚才匿名用户的 一个ip的访问记录列表
                    # cache 可以放在本地,也可以放在缓存中 等
    self.now = self.timer() #如果 记录列表有值并且列表最后面的值 小于当前时间减去限流的周期 就说明这条记录过期了 while self.history and self.history[-1] <= self.now - self.duration: self.history.pop() #pop掉 #判断 访问的次数大不大于 限流的次数 if len(self.history) >= self.num_requests: #大于, return False return self.throttle_failure() #简3 return self.throttle_success()

    第简1步:  self.rate

    class SimpleRateThrottle(BaseThrottle):
        timer = time.time
        scope = None
        THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
    
        def __init__(self):
            if not getattr(self, 'rate', None):   #自定义类中没写 rate字段,执行 简1.1 步
            # 简1.1 self.rate
    = self.get_rate()
           # 简1.2 self.num_requests, self.duration
    = self.parse_rate(self.rate)

    第简1.1步:

    class SimpleRateThrottle(BaseThrottle):
        def get_rate(self):
            #如果 自定义中没有定义 scope 字段 ,抛出异常
            if not getattr(self, 'scope', None):
                msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                       self.__class__.__name__)
                raise ImproperlyConfigured(msg)
            
            try:
                #简1.1.1
                return self.THROTTLE_RATES[self.scope]      # 就是去配置文件中取值  '5/minute'
                           #THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES 
            except KeyError:
                msg = "No default throttle rate set for '%s' scope" % self.scope
                raise ImproperlyConfigured(msg)

    第简1.2 步:

    class SimpleRateThrottle(BaseThrottle):
        def parse_rate(self, rate):
            num, period = rate.split('/')    #  '5/minute'
            num_requests = int(num)          #  限流的次数
            duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]    #限流的周期
            return (num_requests, duration)

    第简2 步:

    class SimpleRateThrottle(BaseThrottle):
        def get_cache_key(self, request, view):
            raise NotImplementedError('.get_cache_key() must be overridden')    #抛出异常  .get_cache_key() 方法 必须被重写
         
    #所以我们要在自定义类中 重写 .get_cache_key()方法
    class MySimpleRateThrottle(SimpleRateThrottle):
      def
    get_cache_key(self, request, view):
           #简2.1   
    return self.get_ident(request)

    第简2.1步: 说白了就是 去request 中获取 唯一标识

    class BaseThrottle(object):
        def get_ident(self, request):
            
            xff = request.META.get('HTTP_X_FORWARDED_FOR')
            remote_addr = request.META.get('REMOTE_ADDR')
            num_proxies = api_settings.NUM_PROXIES
    
            if num_proxies is not None:
                if num_proxies == 0 or xff is None:
                    return remote_addr
                       addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()
    
            return ''.join(xff.split()) if xff else remote_addr

     第简3步:

    class SimpleRateThrottle(BaseThrottle):
        def throttle_success(self):
            #向记录列表中的第一个位置插入数据,再给self.key 辅助
            self.history.insert(0, self.now)
            self.cache.set(self.key, self.history, self.duration)
            return True

    接着我们再看看 wait 方法

    class SimpleRateThrottle(BaseThrottle):
        def wait(self):
         # 如果 记录列表有值
            if self.history:
           # 剩余的周期是 限流周期减去当前时间减去记录列表最后的一个值的时间 remaining_duration
    = self.duration - (self.now - self.history[-1]) else:
           # 如果列表没有值, 说明是第一次访问,剩余周期等于限流的周期 remaining_duration
    = self.duration      #可访问的次数 等于 限流的次数减去 记录列表的长度加上本次访问 available_requests = self.num_requests - len(self.history) + 1 if available_requests <= 0: #如果没有次数了 返回None return None return remaining_duration / float(available_requests)

    全局使用访问频率限制 的配置

    REST_FRAMEWORK = {
                "DEFAULT_THROTTLE_CLASSES":[
                    "app02.utils.AnonThrottle",
                ],
                'DEFAULT_THROTTLE_RATES':{
                    'wdp_anon':'5/minute',
                    'wdp_user':'10/minute',
                }
            }

    --------------------------------------------------------------------------------------------------------------------

    认证+权限+限流 一起使用的代码:   对匿名用户进行限制 每个用户1分钟允许访问5次,登录用户1分钟允许访问10次

         一个是通过ip(如果客户端使用代理就不好限流了),另外一个是通过登录用户的用户名

    from rest_framework.views import APIView
    from rest_framework.response import Response
    from rest_framework.throttling import BaseThrottle,SimpleRateThrottle
    from rest_framework.authentication import BaseAuthentication
    from app02 import models
    
    class MyAuthentication(BaseAuthentication):
        def authenticate(self, request):
            token = request.query_params.get('token')
            obj = models.UserInfo.objects.filter(token=token).first()
            if obj:
                return (obj.username,obj)
            return None
    
        def authenticate_header(self, request):
            pass
    
    class MyPermission(object):
        message = "无权访问"
        def has_permission(self,request,view):
            if request.user:
                return True
            return False
    
    class AdminPermission(object):
        message = "无权访问"
        def has_permission(self,request,view):
            if request.user == 'alex':
                return True
            return False
    
    
    class AnonThrottle(SimpleRateThrottle):
        scope = "wdp_anon"
    
        def get_cache_key(self, request, view):
            # 返回None,表示我不限制
            # 登录用户我不管
            if request.user:
                return None
            # 匿名用户
            return self.get_ident(request)
    
    class UserThrottle(SimpleRateThrottle):
        scope = "wdp_user"
    
        def get_cache_key(self, request, view):
            # 登录用户
            if request.user:
                return request.user
            # 匿名用户我不管
            return None
    
    
    # 无需登录就可以访问
    class IndexView(APIView):
        authentication_classes = [MyAuthentication,]
        permission_classes = []
        throttle_classes=[AnonThrottle,UserThrottle,]
        def get(self,request,*args,**kwargs):
      
            return Response('访问首页')
    
    # 需登录就可以访问
    class ManageView(APIView):
        authentication_classes = [MyAuthentication,]
        permission_classes = [MyPermission,]
        throttle_classes=[AnonThrottle,UserThrottle,]
        def get(self,request,*args,**kwargs):
    
            return Response('访问首页')
    View Code
    REST_FRAMEWORK = {
        'UNAUTHENTICATED_USER': None,
        'UNAUTHENTICATED_TOKEN': None,
        "DEFAULT_AUTHENTICATION_CLASSES": [
            
        ],
        'DEFAULT_PERMISSION_CLASSES':[
    
        ],
        'DEFAULT_THROTTLE_RATES':{
            'wdp_anon':'5/minute',
            'wdp_user':'10/minute',
    
        }
    }
    配置文件
  • 相关阅读:
    UVA 11174 Stand in a Line,UVA 1436 Counting heaps —— (组合数的好题)
    UVA 1393 Highways,UVA 12075 Counting Triangles —— (组合数,dp)
    【Same Tree】cpp
    【Recover Binary Search Tree】cpp
    【Binary Tree Zigzag Level Order Traversal】cpp
    【Binary Tree Level Order Traversal II 】cpp
    【Binary Tree Level Order Traversal】cpp
    【Binary Tree Post order Traversal】cpp
    【Binary Tree Inorder Traversal】cpp
    【Binary Tree Preorder Traversal】cpp
  • 原文地址:https://www.cnblogs.com/liuwei0824/p/8418998.html
Copyright © 2011-2022 走看看