zoukankan      html  css  js  c++  java
  • drf认证权限频率

    一 认证Authentication

    认证组件:校验用户 - 游客、合法用户、非法用户
    游客:代表校验通过,直接进入下一步校验(权限校验)
    合法用户:代表校验通过,将用户存储在request.user中,再进入下一步校验(权限校验)
    非法用户:代表校验失败,抛出异常,返回403权限异常结果

    认证源码分析

    #1 APIVIew----> dispatch方法----> self.initial(request, *args, **kwargs)---->有认证,权限,频率
    
    #2 只读认证源码: self.perform_authentication(request)
    
    #3 self.perform_authentication(request)中就一句话:request.user,(这里的request对象被dispatch方法中的self.initialize_request(request, *args, **kwargs)方法重新封装成了drf自己的request对象)所以需要去drf的Request对象中找user属性(方法)
    
    @property
        def user(self):
            """
            Returns the user associated with the current request, as authenticated
            by the authentication classes provided to the request.
            """
            if not hasattr(self, '_user'):
                with wrap_attributeerrors():
                    self._authenticate()
            return self._user
    
    
    def initialize_request(self, request, *args, **kwargs):
        """
        Returns the initial request object.
        """
        parser_context = self.get_parser_context(request)
    
        return Request(
            request,
            parsers=self.get_parsers(),
            authenticators=self.get_authenticators(), # 获得一个列表,列表中是一个个类的对象
            negotiator=self.get_content_negotiator(),
            parser_context=parser_context
        )
    
    
    def get_authenticators(self):
        """
        Instantiates and returns the list of authenticators that this view can use.
        """
        # authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES,自己在视图类中配置的认证类(一个列表,列表中是类名)
        return [auth() for auth in self.authentication_classes]
    
    
    #4 Request类中的user方法,刚开始来,没有_user,走 self._authenticate()
    
    #5 核心,就是Request类的 _authenticate(self):
    def _authenticate(self):
        # 遍历拿到一个个认证器,进行认证
        # self.authenticators配置的一堆认证类产生的认证类对象组成的 list
        #self.authenticators 你在视图类中配置的一个个的认证类:authentication_classes=[认证类1,认证类2],对象的列表   
        for authenticator in self.authenticators: # 这里的authenticators就是在self.initialize_request中封装的authenticators,是一个列表,里面是你自定义的认证类的对象
            try:
                # 认证器(对象)调用认证方法authenticate(认证类对象self, request请求对象)
                # 返回值:登陆的用户与认证的信息组成的 tuple
                # 该方法被try包裹,代表该方法会抛异常,抛异常就代表认证失败
                user_auth_tuple = authenticator.authenticate(self) # 注意这self是request对象,authenticate就是你自定义的认证类中重写的authenticate方法
            except exceptions.APIException:
                self._not_authenticated()
                raise
    
            # 返回值的处理
            """
            !!!!如果你配置了多个认证类,一定要在最后一个认证类中返回值,否则后面的认证类执行不了
            """
            if user_auth_tuple is not None:
                self._authenticator = authenticator
                # 如果有返回值,就将 登陆用户 与 登陆认证 分别保存到 request.user、request.auth
                self.user, self.auth = user_auth_tuple
                return
        # 如果返回值user_auth_tuple为空,代表认证通过,但是没有 登陆用户 与 登陆认证信息,代表游客
        self._not_authenticated()
    

    自定义认证方案

    # models.py
    class User(models.Model):
        username=models.CharField(max_length=32)
        password=models.CharField(max_length=32)
        user_type=models.IntegerField(choices=((1,'超级用户'),(2,'普通用户'),(3,'二笔用户')))
    
    class UserToken(models.Model):
        user=models.OneToOneField(to='User')
        token=models.CharField(max_length=64)
    
    
    # auth.py(新建认证类)
    from rest_framework.authentication import BaseAuthentication
    from rest_framework.exceptions import AuthenticationFailed
    from app01.models import UserToken
    
    class TokenAuth():
        # token方在url中
        def authenticate(self, request):
            token = request.GET.get('token')
            token_obj = models.UserToken.objects.filter(token=token).first()
            if token_obj:
                return
            else:
                raise AuthenticationFailed('认证失败')
    
        # token放在请求头中
        # def authenticate(self, request):
        #     #     token = request.META.get('HTTP_TOKEN', None) # 注意token被重新拼接了
        #     #     if not token:
        #     #         raise AuthenticationFailed('非法查看')
        #     #     obj = UserToken.objects.filter(token=token).first()
        #     #     if obj:
        #     #         return obj.user, token
        #     #     raise AuthenticationFailed('认证失败')
    
    
    # views.py
    from rest_framework.views import APIView
    from app01 import models
    
    def get_random(name):
        import hashlib
        import time
        md=hashlib.md5()
        md.update(bytes(str(time.time()),encoding='utf-8'))
        md.update(bytes(name,encoding='utf-8'))
        return md.hexdigest()
        
    class Login(APIView):
        authentication_classes = []
        def post(self,reuquest):
            back_msg={'status':1001,'msg':None}
            try:
                name=reuquest.data.get('name')
                pwd=reuquest.data.get('pwd')
                user=models.User.objects.filter(username=name,password=pwd).first()
                if user:
                    token=get_random(name)
                    models.UserToken.objects.update_or_create(user=user,defaults={'token':token})
                    back_msg['status']='1000'
                    back_msg['msg']='登录成功'
                    back_msg['token']=token
                else:
                    back_msg['msg'] = '用户名或密码错误'
            except Exception as e:
                back_msg['msg']=str(e)
            return Response(back_msg)
    
    
    class Course(APIView):
        authentication_classes = [TokenAuth, ] # 可以放多个认证类
    
        def get(self, request):
            return HttpResponse('get')
    
        def post(self, request):
            return HttpResponse('post')
    

    可以在配置文件中配置全局默认的认证方案

    # drf默认配置
    REST_FRAMEWORK = {
        'DEFAULT_AUTHENTICATION_CLASSES': (
            'rest_framework.authentication.SessionAuthentication',  # session认证
            'rest_framework.authentication.BasicAuthentication',   # 基本认证
        )
    }
    
    # 在项目settings.py中配置
    REST_FRAMEWORK={
        "DEFAULT_AUTHENTICATION_CLASSES":["app01.service.auth.TokenAuth",]
    }
    

    也可以在每个视图类中通过设置authentication_classess属性来设置局部使用

    from rest_framework.authentication import SessionAuthentication, BasicAuthentication
    from rest_framework.views import APIView
    
    class ExampleView(APIView):
        # 类属性
        authentication_classes = [TokenAuth,]
        ...
    

    禁用全局认证

    # 在视图类中
    class MyAuthentication(BaseAuthentication):
        authentication_classes = []
        def authenticate(self, request):
            ...
    

    查找顺序:先在视图类中查找----->项目的setting中找----->drf默认配置找

    认证失败会有两种可能的返回值:

    • 401 Unauthorized 未认证
    • 403 Permission Denied 权限被禁止

    二 权限Permissions

    权限组件:校验用户权限 - 必须登录、所有用户、登录读写游客只读、自定义用户角色
    认证通过:可以进入下一步校验(频率认证)
    认证失败:抛出异常,返回403权限异常结果

    权限控制可以限制用户对于视图的访问和对于具体数据对象的访问。

    • 在执行视图的dispatch()方法前,会先进行视图访问权限的判断
    • 在通过get_object()获取具体对象时,会进行模型对象访问权限的判断

    源码分析

    # APIView---->dispatch---->initial--->self.check_permissions(request)(APIView的对象方法)
    def check_permissions(self, request):
        # 遍历权限对象列表得到一个个权限对象(权限器),进行权限认证
        for permission in self.get_permissions():
            # 权限类一定有一个has_permission权限方法,用来做权限认证的
            # 参数:权限对象self、请求对象request、视图类对象
            # 返回值:有权限返回True,无权限返回False
            if not permission.has_permission(request, self):
                self.permission_denied(
                    request, message=getattr(permission, 'message', None)
                )
    

    使用

    # 写一个类,继承BasePermission,重写has_permission,如果权限通过,就返回True,不通过就返回False
    from rest_framework.permissions import BasePermission
    
    class UserPermission(BasePermission):
        def  has_permission(self, request, view):
            # 不是超级用户,不能访问
            # 由于认证已经过了,request内就有user对象了,当前登录用户
            user=request.user  # 当前登录用户
            # 如果该字段用了choice,通过get_字段名_display()就能取出choice后面的中文
            print(user.get_user_type_display())
            if user.user_type==1:
                return True
            else:
                return False
    

    可以在配置文件中全局设置默认的权限管理类

    REST_FRAMEWORK={
        ...
        "DEFAULT_PERMISSION_CLASSES":["app01.permissions.UserPermission",]
    }
    

    也可以在具体的视图中通过permission_classes属性来设置

    
    from rest_framework.views import APIView
    
    class TestView(APIView):
        permission_classes = (UserPermission,)
        ...
    

    局部禁用

    class TestView(APIView):
        permission_classes = []
        ...
    

    内置提供的权限

    需要和django自带的auth_user表连用

    • AllowAny 允许所有用户
    • IsAuthenticated 仅通过认证的用户
    • IsAdminUser 仅管理员用户
    • IsAuthenticatedOrReadOnly 已经登陆认证的用户可以对数据进行增删改操作,没有登陆认证的只能查看数据。

    可以在配置文件中全局设置默认的权限管理类

    REST_FRAMEWORK = {
        ....
        
        'DEFAULT_PERMISSION_CLASSES': (
            'rest_framework.permissions.IsAuthenticated',
        )
    }
    

    如果未指明,则采用如下默认配置

    'DEFAULT_PERMISSION_CLASSES': (
       'rest_framework.permissions.AllowAny',
    )
    

    也可以在具体的视图中通过permission_classes属性来设置

    from rest_framework.permissions import IsAuthenticated
    from rest_framework.views import APIView
    
    class ExampleView(APIView):
        permission_classes = (IsAuthenticated,)
        ...
    

    举例

    # 创建超级用户,登陆到admin,创建普通用户(注意设置职员状态,也就是能登陆)
    
    # 全局配置IsAuthenticated
    # setting.py
    'DEFAULT_PERMISSION_CLASSES': (
            'rest_framework.permissions.IsAuthenticated',
        )
        
    # urls.py
     path('test/', views.TestView.as_view()),
     
    # views.py
    class TestView(APIView):
        def get(self,request):
            return Response({'msg':'个人中心'})
    # 登陆到admin后台后,直接访问可以,如果没登陆,不能访问
    
    ##注意:如果全局配置了
    rest_framework.permissions.IsAdminUser
    # 就只有管理员能访问,普通用户访问不了
    

    自定义权限

    如需自定义权限,需继承rest_framework.permissions.BasePermission父类,并实现以下两个任何一个方法或全部

    • .has_permission(self, request, view)

      是否可以访问视图, view表示当前视图对象

    • .has_object_permission(self, request, view, obj)

      是否可以访问数据对象, view表示当前视图, obj为数据对象

    例如:

    在当前子应用下,创建一个权限文件permissions.py中声明自定义权限类:

    from rest_framework.permissions import BasePermission
    
    class IsXiaoMingPermission(BasePermission):
        def has_permission(self, request, view):
            if( request.user.username == "xiaoming" ):
                return True
    
    from .permissions import IsXiaoMingPermission
    class StudentViewSet(ModelViewSet):
        queryset = Student.objects.all()
        serializer_class = StudentSerializer
        permission_classes = [IsXiaoMingPermission]
    

    三 限流Throttling

    频率组件:限制视图接口被访问的频率次数 - 限制的条件(IP、id、唯一键)、频率周期时间(s、m、h)、频率的次数(3/s)
    没有达到限次:正常访问接口
    达到限次:限制时间内不能访问,限制时间达到后,可以重新访问

    可以对接口访问的频次进行限制,以减轻服务器压力。

    一般用于付费购买次数,投票等场景使用.

    使用

    可以在配置文件中,使用DEFAULT_THROTTLE_CLASSESDEFAULT_THROTTLE_RATES进行全局配置

    REST_FRAMEWORK = {
        'DEFAULT_THROTTLE_CLASSES': (
            'rest_framework.throttling.AnonRateThrottle',
            'rest_framework.throttling.UserRateThrottle'
        ),
        'DEFAULT_THROTTLE_RATES': {
            'anon': '100/day',
            'user': '1000/day'
        }
    }
    

    DEFAULT_THROTTLE_RATES 可以使用 second, minute, hourday来指明周期。

    也可以在具体视图中通过throttle_classess属性来配置

    from rest_framework.throttling import UserRateThrottle
    from rest_framework.views import APIView
    
    class ExampleView(APIView):
        throttle_classes = (UserRateThrottle,)
        ...
    

    实列

    # 内置的频率限制(限制未登录用户)
    
    # 全局使用  限制未登录用户1分钟访问5次
    REST_FRAMEWORK = {
        'DEFAULT_THROTTLE_CLASSES': (
            'rest_framework.throttling.AnonRateThrottle',
        ),
        'DEFAULT_THROTTLE_RATES': {
            'anon': '5/m',
        }
    }
    
    ## views.py
    from rest_framework.permissions import IsAdminUser
    from rest_framework.authentication import SessionAuthentication,BasicAuthentication
    class TestView1(APIView):
        authentication_classes=[]
        permission_classes = []
        def get(self,request,*args,**kwargs):
            return Response('我是未登录用户')
    
    
    # 局部使用
    from rest_framework.permissions import IsAdminUser
    from rest_framework.authentication import SessionAuthentication,BasicAuthentication
    from rest_framework.throttling import AnonRateThrottle
    class TestView2(APIView):
        authentication_classes=[]
        permission_classes = []
        throttle_classes = [AnonRateThrottle]
        def get(self,request,*args,**kwargs):
            return Response('我是未登录用户')
    

    可选限流类

    1) AnonRateThrottle

    限制所有匿名未认证用户,使用IP区分用户。

    使用DEFAULT_THROTTLE_RATES['anon'] 来设置频次

    2)UserRateThrottle

    限制认证用户,使用User id 来区分。

    使用DEFAULT_THROTTLE_RATES['user'] 来设置频次

    3)ScopedRateThrottle

    限制用户对于每个视图的访问频次,使用ip或user id。

    例如:

    class ContactListView(APIView):
        throttle_scope = 'contacts'
        ...
    
    class ContactDetailView(APIView):
        throttle_scope = 'contacts'
        ...
    
    class UploadView(APIView):
        throttle_scope = 'uploads'
        ...
    REST_FRAMEWORK = {
        'DEFAULT_THROTTLE_CLASSES': (
            'rest_framework.throttling.ScopedRateThrottle',
        ),
        'DEFAULT_THROTTLE_RATES': {
            'contacts': '1000/day',
            'uploads': '20/day'
        }
    }
    

    实例

    全局配置中设置访问频率

        'DEFAULT_THROTTLE_RATES': {
            'anon': '3/minute',
            'user': '10/minute'
        }
    
    from rest_framework.authentication import SessionAuthentication
    from rest_framework.permissions import IsAuthenticated
    from rest_framework.generics import RetrieveAPIView
    from rest_framework.throttling import UserRateThrottle
    
    class StudentAPIView(RetrieveAPIView):
        queryset = Student.objects.all()
        serializer_class = StudentSerializer
        authentication_classes = [SessionAuthentication]
        permission_classes = [IsAuthenticated]
        throttle_classes = (UserRateThrottle,)
    

    自定义频率类

    # 自定义的逻辑
    #(1)取出访问者ip
    #(2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问,在字典里,继续往下走
    #(3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
    #(4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
    #(5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
    class MyThrottles():
        VISIT_RECORD = {}
        def __init__(self):
            self.history=None
        def allow_request(self,request, view):
            #(1)取出访问者ip
            # print(request.META)
            ip=request.META.get('REMOTE_ADDR')
            import time
            ctime=time.time()
            # (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问
            if ip not in self.VISIT_RECORD:
                self.VISIT_RECORD[ip]=[ctime,]
                return True
            self.history=self.VISIT_RECORD.get(ip)
            # (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
            while self.history and ctime-self.history[-1]>60:
                self.history.pop()
            # (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
            # (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
            if len(self.history)<3:
                self.history.insert(0,ctime)
                return True
            else:
                return False
        def wait(self):
            import time
            ctime=time.time()
            return 60-(ctime-self.history[-1])
    

    全局使用

    REST_FRAMEWORK = {
        'DEFAULT_THROTTLE_CLASSES':['app01.utils.MyThrottles',],
    }
    

    局部使用

    #在视图类里使用
    throttle_classes = [MyThrottles,]
    

    继承SimpleRateThrottle自定义频率类

    # 写一个类,继承SimpleRateThrottle,只需要重写get_cache_key 
    from rest_framework.throttling import ScopedRateThrottle,SimpleRateThrottle
    
    #继承SimpleRateThrottle
    class MyThrottle(SimpleRateThrottle):
        scope='luffy'
        def get_cache_key(self, request, view):
            print(request.META.get('REMOTE_ADDR'))
            return request.META.get('REMOTE_ADDR')  # 返回用户ip(返回什么就根据什么限制用户访问频次)
    
    
    # 全局使用
    from utils.throttlings import MyThrottle
    REST_FRAMEWORK={
        'DEFAULT_THROTTLE_CLASSES': (
            'utils.throttling.MyThrottle',
        ),
        'DEFAULT_THROTTLE_RATES': {
            'luffy': '3/m'  # key要跟类中的scop对应
        },
    }
    
    # 局部使用
    # 视图类里加上
    from utils.throttlings import MyThrottle
    class ThView:
        throttle_classes = [MyThrottle,]
        def get(self, request):
            ...
    

    SimpleRateThrottle源码分析

    class SimpleRateThrottle(BaseThrottle):
        
        cache = default_cache
        timer = time.time
        cache_format = 'throttle_%(scope)s_%(ident)s'
        scope = None
        THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
    
        def __init__(self):
            if not getattr(self, 'rate', None):
                self.rate = self.get_rate()
            self.num_requests, self.duration = self.parse_rate(self.rate)
    
        def get_cache_key(self, request, view): 
            raise NotImplementedError('.get_cache_key() must be overridden')
    
        def get_rate(self):
            if not getattr(self, 'scope', None):
                msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                       self.__class__.__name__)
                raise ImproperlyConfigured(msg)
    
            try:
                return self.THROTTLE_RATES[self.scope]
            except KeyError:
                msg = "No default throttle rate set for '%s' scope" % self.scope
                raise ImproperlyConfigured(msg)
    
        def parse_rate(self, rate):
            if rate is None:
                return (None, None)
            num, period = rate.split('/')
            num_requests = int(num)
            duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
            return (num_requests, duration)
    
        def allow_request(self, request, view):
            if self.rate is None:
                return True
            # 当前登录用户的唯一标识
            self.key = self.get_cache_key(request, view)
            if self.key is None:
                return True
            # 用户初次访问缓存为空,self.history为[],是存放时间的列表
            self.history = self.cache.get(self.key, [])
            # 获取一下当前时间,存放到 self.now
            self.now = self.timer()
    
            # 当前访问与第一次访问时间间隔如果大于60s,第一次记录清除,不再算作一次计数
            while self.history and self.history[-1] <= self.now - self.duration:
                self.history.pop()
            # history的长度与限制次数进行比较
            if len(self.history) >= self.num_requests:
                return self.throttle_failure()
            return self.throttle_success()
    
        def throttle_success(self):
            self.history.insert(0, self.now)
            self.cache.set(self.key, self.history, self.duration)
            return True
    
        def throttle_failure(self):
            return False
    
        def wait(self):
            if self.history:
                remaining_duration = self.duration - (self.now - self.history[-1])
            else:
                remaining_duration = self.duration
    
            available_requests = self.num_requests - len(self.history) + 1
            if available_requests <= 0:
                return None
    
            return remaining_duration / float(available_requests)
    
  • 相关阅读:
    金融系列10《发卡行脚本》
    金融系列9《发卡行认证》
    金融系列8《应用密文产生》
    ED/EP系列5《消费指令》
    ED/EP系列4《圈存指令》
    ED/EP系列2《文件结构》
    ED/EP系列1《简介》
    社保系列11《ATR》
    社保系列3《文件结构》
    社保系列2《文件系统》
  • 原文地址:https://www.cnblogs.com/chenwenyin/p/13279128.html
Copyright © 2011-2022 走看看