zoukankan      html  css  js  c++  java
  • python-django rest framework框架之dispatch方法源码分析

    1.Django的 CBV 中在请求到来之后,都要执行dispatch方法,dispatch方法根据请求方式不同触发 get/post/put等方法

    class APIView(View):
        def dispatch(self, request, *args, **kwargs):#1.1 把wsgi的request进行封装
            request = self.initialize_request(request, *args, **kwargs)
            self.request = request     #此时的self.request 是rest_framework的Request对象,它里面比wsgi的request多了一些东西
    
            try:
                #1.2 进行 初始化     :版本,认证,权限,访问频率
                self.initial(request, *args, **kwargs)
    #1.3 反射执行get等方法 if request.method.lower() in self.http_method_names: handler = getattr(self, request.method.lower(), self.http_method_not_allowed) else: handler = self.http_method_not_allowed response = handler(request, *args, **kwargs) except Exception as exc: response = self.handle_exception(exc) #1.4 把返回的response 再进行封装 self.response = self.finalize_response(request, response, *args, **kwargs) #1.5 返回 , dispatch方法一定要有返回值,因为get等方法返回的结果要返回给前端 return self.response

    第1.1步:

    from rest_framework.request import Request
    
    class APIView(View):
        def initialize_request(self, request, *args, **kwargs):
           #返回了一个rest_framework的Request对象
            return Request(
                request,
           #1.1.1 parsers
    =self.get_parsers(),
           #1.1.2 authenticators
    =self.get_authenticators(),
           #1.1.3 negotiator
    =self.get_content_negotiator(), )

    第1.1.1步:

      pass

    第1.1.2步:

    class APIView(View):
        def get_authenticators(self):
         #self.authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES 是从配置文件中取数据,从变量名不难看出有可能是很多类的列表
    return [auth() for auth in self.authentication_classes]
         #self.authenticators = 一个多个对象的列表

    第1.2步:

    class APIView(View):
        def initial(self, request, *args, **kwargs):
         #1.2.0版本相关
         version, scheme = self.determine_version(request, *args, **kwargs)
         request.version, request.versioning_scheme = version, scheme
          # request.version, request.versioning_scheme = 版本,和相应的版本类对象
          # 1.2.0.0
          # 既然 request.version 是版本, 那request.versioning_scheme 代表什么呢?

    #1.2.1认证相关 self.perform_authentication(request) #1.2.2权限相关 self.check_permissions(request) #1.2.3访问频率相关 self.check_throttles(request)

    第1.2.0步:

    获取版本数据有五种接收方式:url上传参,url,子域名,namespace,请求头

    from rest_framework.versioning import QueryParameterVersioning,URLPathVersioning,HostNameVersioning  
       方式一:
        # 基于url传参  http://127.0.0.1:8001/api/users/?version=v1
        # versioning_class = QueryParameterVersioning

        方式二:
        # 基于URL http://127.0.0.1:8001/api/v2/users/
        # versioning_class = URLPathVersioning
         # url 配置: url(r'^api/(?P<version>[v1|v2]+)/', include('api.urls')),
        方式三:
        # 基于子域名 http://v1.luffy.com/users/
        # versioning_class = HostNameVersioning
      
     配置文件:
      REST_FRAMEWORK = {
      'VERSION_PARAM':'version',
      'DEFAULT_VERSION':'v1',
      'ALLOWED_VERSIONS':['v1','v2'],
      # 'DEFAULT_VERSIONING_CLASS':"rest_framework.versioning.HostNameVersioning",
      'DEFAULT_VERSIONING_CLASS':"rest_framework.versioning.URLPathVersioning",
      # 'DEFAULT_VERSIONING_CLASS':"rest_framework.versioning.QueryParameterVersioning",
      }
     
    class APIView(View):
        def determine_version(self, request, *args, **kwargs):
            if self.versioning_class is None:     # self.versioning_class = api_settings.DEFAULT_VERSIONING_CLASS
            #如果没有配置 版本相关的类
                return (None, None)
            scheme = self.versioning_class()      # 实例化版本类对象
            #调用对象的determine_version 方法
            return (scheme.determine_version(request, *args, **kwargs), scheme)

    方式一的 determine_version 方法:

    class QueryParameterVersioning(BaseVersioning):
       invalid_version_message = _('Invalid version in query parameter.')
    def determine_version(self, request, *args, **kwargs): # 相当于 request.GET.get() version = request.query_params.get(self.version_param, self.default_version) #self.version_param = api_settings.VERSION_PARAM #self.default_version = api_settings.DEFAULT_VERSION
         # 见下面
    if not self.is_allowed_version(version): raise exceptions.NotFound(self.invalid_version_message) return version

    class BaseVersioning(object):
        def is_allowed_version(self, version):
         #如果没有配置允许的版本列表,代表没限制
            if not self.allowed_versions: #self.allowed_versions = api_settings.ALLOWED_VERSIONS
                return True
            return ((version is not None and version == self.default_version) or
                    (version in self.allowed_versions))

    方式二的 determine_version 方法:

    class URLPathVersioning(BaseVersioning):
        invalid_version_message = _('Invalid version in URL path.')
        
        def determine_version(self, request, *args, **kwargs):
            #从传过来的字典中获取版本
            version = kwargs.get(self.version_param, self.default_version)
            
            if not self.is_allowed_version(version):
                raise exceptions.NotFound(self.invalid_version_message)
            return version
            
            
    class BaseVersioning(object):
        def is_allowed_version(self, version):
            if not self.allowed_versions:
                return True
            return ((version is not None and version == self.default_version) or
                    (version in self.allowed_versions))

    方式三的 determine_version 方法:

    class HostNameVersioning(BaseVersioning):
        hostname_regex = re.compile(r'^([a-zA-Z0-9]+).[a-zA-Z0-9]+.[a-zA-Z0-9]+$')
        invalid_version_message = _('Invalid version in hostname.')
        
        def determine_version(self, request, *args, **kwargs):
            hostname, separator, port = request.get_host().partition(':')
            match = self.hostname_regex.match(hostname)
            if not match:
                return self.default_version
            version = match.group(1)
            if not self.is_allowed_version(version):
                raise exceptions.NotFound(self.invalid_version_message)
            return version
            
    class BaseVersioning(object):
        def is_allowed_version(self, version):
            if not self.allowed_versions:
                return True
            return ((version is not None and version == self.default_version) or
                    (version in self.allowed_versions))

    第1.2.0.0步:     request.versioning_scheme  是用来反向生成url 的

            urlpatterns = [
                url(r'^api/(?P<version>[v1|v2]+)/', include('api.urls')),
                url(r'^api/', include('api.urls')),
            ]
        
            urlpatterns = [
                url(r'^users/', views.UsersView.as_view(),name='u'),
            ]
            
            
            # 当前版本一样的URL
            # url = request.versioning_scheme.reverse(viewname='u',request=request)    #不用传版本的参数
            # print(url)
    
            # 当前版本不一样的URL
            # from django.urls import reverse
            # url = reverse(viewname='u',kwargs={'version':'v2'})
            # print(url)

     方式一的 reverse方法

    class QueryParameterVersioning(BaseVersioning):
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            #调用基类的 reverse方法
            url = super(QueryParameterVersioning, self).reverse(
                viewname, args, kwargs, request, format, **extra
            )
            # 之所以在反向生成的时候不用传版本参数, 是因为这步帮你处理了
            if request.version is not None:
                return replace_query_param(url, self.version_param, request.version)
            return url
            
    class BaseVersioning(object):    
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            return _reverse(viewname, args, kwargs, request, format, **extra)
            
    def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
    
        if format is not None:
            kwargs = kwargs or {}
            kwargs['format'] = format
        # 最后还是调用了 django的 reverse方法 生成url
        url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
        if request:
            return request.build_absolute_uri(url)
        return url

     方式二的 reverse方法

    class URLPathVersioning(BaseVersioning):
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            if request.version is not None:
                kwargs = {} if (kwargs is None) else kwargs
                #给 version赋值
                kwargs[self.version_param] = request.version
    
            #调用基类的reverse方法
            return super(URLPathVersioning, self).reverse(
                viewname, args, kwargs, request, format, **extra
            )
            
    class BaseVersioning(object):
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            return _reverse(viewname, args, kwargs, request, format, **extra)
            
            
    def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
    
        if format is not None:
            kwargs = kwargs or {}
            kwargs['format'] = format
        #最后还是调用了django的reverse方法生成url
        url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
        if request:
            return request.build_absolute_uri(url)
        return url

    方式三的 reverse方法

    方式三 HostNameVersioning 类中没有 reverse方法,所以直接去基类中找
    class BaseVersioning(object):
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            return _reverse(viewname, args, kwargs, request, format, **extra)
             
    def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
        if format is not None:
            kwargs = kwargs or {}
            kwargs['format'] = format
        #最后还是调用了django的reverse方法生成url
        url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
        if request:
            return request.build_absolute_uri(url)
        return url

    第1.2.1步: 认证相关

    class APIView(View):
        def perform_authentication(self, request):
         #1.2.1.1
            request.user

    第1.2.1.1步:

    class Request(object):    
        @property
        def user(self):       #此时的self 是 rest_framework的request对象
            if not hasattr(self, '_user'):
                with wrap_attributeerrors():
              #1.2.1.1.1 self._authenticate()
    return self._user

    1.2.1.1.1步:

    class Request(object):    
        def _authenticate(self):
            #此时的self 是 rest_framework的request对象
    for authenticator in self.authenticators: #self.authenticators = 一个多个对象的列表 try:
              #执行每个对象的authenticate方法 user_auth_tuple
    = authenticator.authenticate(self)   #从变量的名不难看出 返回了一个元组 except exceptions.APIException: self._not_authenticated() raise if user_auth_tuple is not None: self._authenticator = authenticator
              #赋值, request.user和request.auth 并返回 self.user, self.auth
    = user_auth_tuple return self._not_authenticated()

    第1.3步反射执行get等方法

    我们可以自定义一个简单的用户认证

    class MyAuth(object):
        def authenticate(self,request):
            return "1111","222"
        
    class Host(APIView):
        authentication_classes=[MyAuth]
        def get(self,request):
            print(request.user)    #1111
            print(request.auth)   #222
            return HttpResponse("666")

    认证

    - 认证
                    - 局部  : 只是一个类内的一些接口用
                        class MyAuthentication(BaseAuthentication):
    
                            def authenticate(self, request):
                     '''
                     有三种返回值: None 表示我不管,交给下一个进行认证, 元组表示认证成功, 抛出异常
                     '''
    # return None ,我不管,交给下一个进行认证 token = request.query_params.get('token') obj = models.UserInfo.objects.filter(token=token).first() if obj: return (obj.username,obj) raise APIException('用户认证失败') #  认证失败时 需要注册restframework 进行友好的展示错误信息 class AuthView(APIView): authentication_classes=[MyAuthentication,]
                   ....

    - 全局 : 多个类都需要用到认证的时候 就需要在配置文件中配置了 REST_FRAMEWORK = { 'UNAUTHENTICATED_USER': None, 'UNAUTHENTICATED_TOKEN': None, "DEFAULT_AUTHENTICATION_CLASSES": [ "app02.utils.MyAuthentication", ], } class HostView(APIView):
                   #authentication_classes=[] #如果在类中写了authentication_classes 等于一个空列表,那就表示 这个类内的接口不需要认证
    def get(self,request,*args,**kwargs): return HttpResponse('主机列表')

             - 类的继承:
              ********utils.py

              from rest_framework.authentication import BaseAuthentication
              from rest_framework import exceptions

              class LuffyTokenAuthentication(BaseAuthentication):
                  keyword = 'Token'

                  def authenticate(self, request):
                      """
                      Authenticate the request and return a two-tuple of (user, token).
                      """

                      token = request.query_params.get('token')
                      if not token:
                          raise exceptions.AuthenticationFailed('验证失败')
                      return self.authenticate_credentials(token)

                  def authenticate_credentials(self, token):
                      from luffy.models import UserAuthToken
                      try:
                          token_obj = UserAuthToken.objects.select_related('user').get(token=token)
                      except Exception as e:
                          raise exceptions.AuthenticationFailed(_('Invalid token.'))

                      return (token_obj.user, token_obj)

             class AuthAPIView(object):
                 authentication_classes = [LuffyTokenAuthentication,]
        
             *******views.py
        
              from utils import AuthAPIView

              class ShoppingCarView(AuthAPIView,APIView):   #注意继承的顺序

                  def get(self,request,*args,**kwargs):
                      pass

     第1.2.2步: 权限相关

    class APIView(View):
        def check_permissions(self, request):
         #1.2.2.1 permission是每一个权限类的对象
    for permission in self.get_permissions():
           #1.2.2.2
    if not permission.has_permission(request, self): #我猜has_permission方法返回的值是 True/False,True代表有权限
              # 1.2.2.3 如果没有权限执行  self.permission_denied( request, message
    =getattr(permission, 'message', None) #根据这句话可以发现,可以在自定义类中写 message='无权访问' )

    第1.2.2.1步

    class APIView(View):
        def get_permissions(self):
            return [permission() for permission in self.permission_classes]

    第1.2.2.2步: 我们可以在自定义类中写这个方法,通过一些逻辑判断后让它返回True或False

    第1.2.2.3步

    class APIView(View):
        def permission_denied(self, request, message=None):
            # request.authenticators = 一个对象列表
            # 1.2.2.3.1  如果认证成功,不执行此步
            if request.authenticators and not request.successful_authenticator:
                raise exceptions.NotAuthenticated()   #抛出 未进行认证的异常,这里可以传错误信息 detail='xxx'
         # 1.2.2.3.2 抛出异常
    raise exceptions.PermissionDenied(detail=message)

    第1.2.2.3.1步

    class Request(object):        
        @property
        def successful_authenticator(self):
            #self._authenticator  是 最后的那个认证类的对象
            return self._authenticator

    第1.2.2.3.2步

    没有init方法,执行父类的        
    class PermissionDenied(APIException):
        default_detail = _('You do not have permission to perform this action.')
        
    class APIException(Exception):
        default_detail = _('A server error occurred.')

        def __init__(self, detail=None, code=None):
            if detail is None:
                detail = self.default_detail
            
            self.detail = _get_error_details(detail, code)

        def __str__(self):
            return six.text_type(self.detail)

    认证和权限联合使用:

    class MyAuthentication(BaseAuthentication):
    
        def authenticate(self, request):
            token = request.query_params.get('token')
            obj = models.UserInfo.objects.filter(token=token).first()
            if obj:
                return (obj.username,obj)
            return None
    
        def authenticate_header(self, request):
            """
            Return a string to be used as the value of the `WWW-Authenticate`
            header in a `401 Unauthenticated` response, or `None` if the
            authentication scheme should return `403 Permission Denied` responses.
            """
            # return 'Basic realm="api"'
            pass
    
    class MyPermission(object):
        message = "无权访问"
        def has_permission(self,request,view):
            if request.user:
                return True
            return False
    
    class AdminPermission(object):
        message = "无权访问"
        def has_permission(self,request,view):
            if request.user == 'alex':
                return True
            return False
    
    class HostView(APIView):
        """
        匿名用户和用户都能访问
        """
        authentication_classes = [MyAuthentication,]
        permission_classes = []
        def get(self,request,*args,**kwargs):
    
            return Response('主机列表')
    
    class UserView(APIView):
        """
        用户能访问
        """
        authentication_classes = [MyAuthentication, ]
        permission_classes = [MyPermission,]
        def get(self,request,*args,**kwargs):
            return Response('用户列表')
    
    class SalaryView(APIView):
        """
        管理员能访问
        """
        authentication_classes = [MyAuthentication, ]
        permission_classes = [MyPermission,AdminPermission,]
        def get(self,request,*args,**kwargs):
    
            return Response('薪资列表')
    
        #自定义未认证的错误信息
        def permission_denied(self, request, message=None):
            if request.authenticators and not request.successful_authenticator:
                raise exceptions.NotAuthenticated(detail='xxxxxxxx')
            raise exceptions.PermissionDenied(detail=message)

     在全局内使用权限需配置:

                        REST_FRAMEWORK = {
                                "DEFAULT_PERMISSION_CLASSES": [
                                     "app02.utils.MyPermission",
                                ],
                        }

    第1.2.3步: 访问频率相关  

    class APIView(View):
        def check_throttles(self, request):
         #1.2.3.1
    for throttle in self.get_throttles(): if not throttle.allow_request(request, self): #从这句代码可以看出,自定义的限流类 可以 写allow_request 方法,返回值应该是 True 表示通行或 False 表示限制
              # 1.2.3.2 限制的情况下执行 self.throttled(request, throttle.wait()) #从这句代码可以看出,自定义的限流类 中 要写 wait 方法,而且返回值必须是数字类型或者 None

    第1.2.3.1步:

    class APIView(View):
        def get_throttles(self):
            return [throttle() for throttle in self.throttle_classes]   #返回一个 限流 类的对象列表

    第1.2.3.2步:

    class APIView(View):
        def throttled(self, request, wait):
            raise exceptions.Throttled(wait)  #类的实例化  ,抛出异常是一个对象,那在打印的时候一定调用了 __str__方法
    class Throttled(APIException):
        default_detail = _('Request was throttled.')
        extra_detail_plural = 'Expected available in {wait} seconds.'
        def __init__(self, wait=None, detail=None, code=None):
            if detail is None:
                detail = force_text(self.default_detail)
            if wait is not None:
                wait = math.ceil(wait)
                detail = ' '.join((              # 把 wait方法的返回值和 detail 放到了一起,作为新的参数 传给了父类进行初始化
                    detail,
                    force_text(ungettext(self.extra_detail_singular.format(wait=wait),
                                         self.extra_detail_plural.format(wait=wait),
                                         wait))))
            self.wait = wait
            super(Throttled, self).__init__(detail, code)    #传给了父类进行初始化
    class APIException(Exception):
        default_detail = _('A server error occurred.')
        default_code = 'error'
    
        def __init__(self, detail=None, code=None):
            if detail is None:
                detail = self.default_detail
            if code is None:
                code = self.default_code
         #把 传过来的 错误信息 detail 赋值给了 self.detail
            self.detail = _get_error_details(detail, code)   
    
        def __str__(self):
            return six.text_type(self.detail)   #打印错误信息

    自定义的访问频率限制

    class MyThrottle(BaseThrottle):
        def allow_request(self,request,view):
            return False
            
        def wait(self):
            return 22     #表示还需22秒才能访问
        
    class User(APIView):
        throttle_classes=[MyThrottle,]
        
        def get(self,request,*args,**kwargs):
            return Response('333333')

    自定义一个对匿名用户的限流

    RECORD={}
    
    class MyThrottle(BaseThrottle):
        
        def allow_request(self,request,view):
            """
            返回False,限制
            返回True,通行
            """
             a. 对匿名用户进行限制:每个用户1分钟允许访问10次
                - 获取用户IP request 1.1.1
            """
            import time
            ctime = time.time()
            ip = "1.1.1"
            if ip not in RECORD:
                RECORD[ip] = [ctime,]
            else:
                # [4507862389234,3507862389234,2507862389234,1507862389234,]
                time_list = RECORD[ip]
                while True:
                    val = time_list[-1]
                    if (ctime-60) > val:
                        time_list.pop()
                    else:
                        break
                if len(time_list) > 10:
                    return False
                time_list.insert(0,ctime)
            return True
        def wait(self):
            import time
            ctime = time.time()
            first_in_time = RECORD["1.1.1"][-1]
            wt = 60 - (ctime - first_in_time)
            return wt
    class User(APIView):
    
        throttle_classes=[MyThrottle]
        
        def get(self,request,*args,**kwargs):
    
            return Response('333333')
    View Code

    但是这样写觉得很麻烦,故有更加简单的写法:如下 继承SimpleRateThrottle

    class MySimpleRateThrottle(SimpleRateThrottle):
        scope = "wdp"
    
        def get_cache_key(self, request, view):
            return self.get_ident(request)
         # 可以返回 None 表示 不限流
    class LimitView(APIView): authentication_classes = [] permission_classes = [] throttle_classes=[MySimpleRateThrottle,] def get(self,request,*args,**kwargs): return Response('控制访问频率示例') def throttled(self, request, wait): #自定义错误信息 class MyThrottled(exceptions.Throttled): default_detail = '请求被限制.' extra_detail_plural = '还需要再等待{wait}' raise MyThrottled(wait) 需要在配置文件中设置: REST_FRAMEWORK = { 'DEFAULT_THROTTLE_RATES':{ 'wdp':'5/minute', } }

    这种简单写法的源码流程:  首先也是要执行  allow_request 方法 ,它自己类中没有就去基类中找

    class SimpleRateThrottle(BaseThrottle):
        def allow_request(self, request, view):
            #简1   (不执行)
            if self.rate is None:
                return True
            #简2   如果你不重写.get_cache_key() 方法 就抛出异常
            self.key = self.get_cache_key(request, view)
            if self.key is None:
                return True
    
            self.history = self.cache.get(self.key, [])   #通过唯一标识 到 cache中取相当于刚才匿名用户的 一个ip的访问记录列表
                    # cache 可以放在本地,也可以放在缓存中 等
    self.now = self.timer() #如果 记录列表有值并且列表最后面的值 小于当前时间减去限流的周期 就说明这条记录过期了 while self.history and self.history[-1] <= self.now - self.duration: self.history.pop() #pop掉 #判断 访问的次数大不大于 限流的次数 if len(self.history) >= self.num_requests: #大于, return False return self.throttle_failure() #简3 return self.throttle_success()

    第简1步:  self.rate

    class SimpleRateThrottle(BaseThrottle):
        timer = time.time
        scope = None
        THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
    
        def __init__(self):
            if not getattr(self, 'rate', None):   #自定义类中没写 rate字段,执行 简1.1 步
            # 简1.1 self.rate
    = self.get_rate()
           # 简1.2 self.num_requests, self.duration
    = self.parse_rate(self.rate)

    第简1.1步:

    class SimpleRateThrottle(BaseThrottle):
        def get_rate(self):
            #如果 自定义中没有定义 scope 字段 ,抛出异常
            if not getattr(self, 'scope', None):
                msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                       self.__class__.__name__)
                raise ImproperlyConfigured(msg)
            
            try:
                #简1.1.1
                return self.THROTTLE_RATES[self.scope]      # 就是去配置文件中取值  '5/minute'
                           #THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES 
            except KeyError:
                msg = "No default throttle rate set for '%s' scope" % self.scope
                raise ImproperlyConfigured(msg)

    第简1.2 步:

    class SimpleRateThrottle(BaseThrottle):
        def parse_rate(self, rate):
            num, period = rate.split('/')    #  '5/minute'
            num_requests = int(num)          #  限流的次数
            duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]    #限流的周期
            return (num_requests, duration)

    第简2 步:

    class SimpleRateThrottle(BaseThrottle):
        def get_cache_key(self, request, view):
            raise NotImplementedError('.get_cache_key() must be overridden')    #抛出异常  .get_cache_key() 方法 必须被重写
         
    #所以我们要在自定义类中 重写 .get_cache_key()方法
    class MySimpleRateThrottle(SimpleRateThrottle):
      def
    get_cache_key(self, request, view):
           #简2.1   
    return self.get_ident(request)

    第简2.1步: 说白了就是 去request 中获取 唯一标识

    class BaseThrottle(object):
        def get_ident(self, request):
            
            xff = request.META.get('HTTP_X_FORWARDED_FOR')
            remote_addr = request.META.get('REMOTE_ADDR')
            num_proxies = api_settings.NUM_PROXIES
    
            if num_proxies is not None:
                if num_proxies == 0 or xff is None:
                    return remote_addr
                       addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()
    
            return ''.join(xff.split()) if xff else remote_addr

     第简3步:

    class SimpleRateThrottle(BaseThrottle):
        def throttle_success(self):
            #向记录列表中的第一个位置插入数据,再给self.key 辅助
            self.history.insert(0, self.now)
            self.cache.set(self.key, self.history, self.duration)
            return True

    接着我们再看看 wait 方法

    class SimpleRateThrottle(BaseThrottle):
        def wait(self):
         # 如果 记录列表有值
            if self.history:
           # 剩余的周期是 限流周期减去当前时间减去记录列表最后的一个值的时间 remaining_duration
    = self.duration - (self.now - self.history[-1]) else:
           # 如果列表没有值, 说明是第一次访问,剩余周期等于限流的周期 remaining_duration
    = self.duration      #可访问的次数 等于 限流的次数减去 记录列表的长度加上本次访问 available_requests = self.num_requests - len(self.history) + 1 if available_requests <= 0: #如果没有次数了 返回None return None return remaining_duration / float(available_requests)

    全局使用访问频率限制 的配置

    REST_FRAMEWORK = {
                "DEFAULT_THROTTLE_CLASSES":[
                    "app02.utils.AnonThrottle",
                ],
                'DEFAULT_THROTTLE_RATES':{
                    'wdp_anon':'5/minute',
                    'wdp_user':'10/minute',
                }
            }

    --------------------------------------------------------------------------------------------------------------------

    认证+权限+限流 一起使用的代码:   对匿名用户进行限制 每个用户1分钟允许访问5次,登录用户1分钟允许访问10次

         一个是通过ip(如果客户端使用代理就不好限流了),另外一个是通过登录用户的用户名

    from rest_framework.views import APIView
    from rest_framework.response import Response
    from rest_framework.throttling import BaseThrottle,SimpleRateThrottle
    from rest_framework.authentication import BaseAuthentication
    from app02 import models
    
    class MyAuthentication(BaseAuthentication):
        def authenticate(self, request):
            token = request.query_params.get('token')
            obj = models.UserInfo.objects.filter(token=token).first()
            if obj:
                return (obj.username,obj)
            return None
    
        def authenticate_header(self, request):
            pass
    
    class MyPermission(object):
        message = "无权访问"
        def has_permission(self,request,view):
            if request.user:
                return True
            return False
    
    class AdminPermission(object):
        message = "无权访问"
        def has_permission(self,request,view):
            if request.user == 'alex':
                return True
            return False
    
    
    class AnonThrottle(SimpleRateThrottle):
        scope = "wdp_anon"
    
        def get_cache_key(self, request, view):
            # 返回None,表示我不限制
            # 登录用户我不管
            if request.user:
                return None
            # 匿名用户
            return self.get_ident(request)
    
    class UserThrottle(SimpleRateThrottle):
        scope = "wdp_user"
    
        def get_cache_key(self, request, view):
            # 登录用户
            if request.user:
                return request.user
            # 匿名用户我不管
            return None
    
    
    # 无需登录就可以访问
    class IndexView(APIView):
        authentication_classes = [MyAuthentication,]
        permission_classes = []
        throttle_classes=[AnonThrottle,UserThrottle,]
        def get(self,request,*args,**kwargs):
      
            return Response('访问首页')
    
    # 需登录就可以访问
    class ManageView(APIView):
        authentication_classes = [MyAuthentication,]
        permission_classes = [MyPermission,]
        throttle_classes=[AnonThrottle,UserThrottle,]
        def get(self,request,*args,**kwargs):
    
            return Response('访问首页')
    View Code
    REST_FRAMEWORK = {
        'UNAUTHENTICATED_USER': None,
        'UNAUTHENTICATED_TOKEN': None,
        "DEFAULT_AUTHENTICATION_CLASSES": [
            
        ],
        'DEFAULT_PERMISSION_CLASSES':[
    
        ],
        'DEFAULT_THROTTLE_RATES':{
            'wdp_anon':'5/minute',
            'wdp_user':'10/minute',
    
        }
    }
    配置文件
  • 相关阅读:
    免费部署Woocall到您自己的网站上
    服务器控件开发之复杂属性
    删除数据库的所有存储过程、主键、外键、索引等
    怎样在dropdownlist的每一项前加一个或多个空格
    Java的内部类学习
    StringUtils全览 (转)
    Java异常大全
    Java web 开发小问题总结(持续更新中)
    Java常用方法总结(持续更新中)
    Python 常用函数
  • 原文地址:https://www.cnblogs.com/liuwei0824/p/8418998.html
Copyright © 2011-2022 走看看