zoukankan      html  css  js  c++  java
  • 04 Django REST Framework 认证、权限和限制

    目前,我们的API对谁可以编辑或删除代码段没有任何限制。我们希望有更高级的行为,以确保:

    • 代码片段始终与创建者相关联。
    • 只有通过身份验证的用户可以创建片段。
    • 只有代码片段的创建者可以更新或删除它。
    • 未经身份验证的请求应具有完全只读访问权限。

    01-认证

    REST framework 提供了一些开箱即用的身份验证方案,并且还允许你实现自定义方案。

    1.1-自定义Token认证

    定义一个用户表和一个保存用户Token的表:

    class UserInfo(models.Model):
        username = models.CharField(max_length=16)
        password = models.CharField(max_length=32)
        type = models.SmallIntegerField(
            choices=((1, '普通用户'), (2, 'VIP用户')),
            default=1
        )
    
    
    class Token(models.Model):
        user = models.OneToOneField(to='UserInfo')
        token_code = models.CharField(max_length=128)

    1.2-定义一个登录视图

    def get_random_token(username):
        """
        根据用户名和时间戳生成随机token
        :param username:
        :return:
        """
        import hashlib, time
        timestamp = str(time.time())
        m = hashlib.md5(bytes(username, encoding="utf8"))
        m.update(bytes(timestamp, encoding="utf8"))
        return m.hexdigest()
    
    
    class LoginView(APIView):
        """
        校验用户名密码是否正确从而生成token的视图
        """
        def post(self, request):
            res = {"code": 0}
            print(request.data)
            username = request.data.get("username")
            password = request.data.get("password")
    
            user = models.UserInfo.objects.filter(username=username, password=password).first()
            if user:
                # 如果用户名密码正确
                token = get_random_token(username)
                models.Token.objects.update_or_create(defaults={"token_code": token}, user=user)
                res["token"] = token
            else:
                res["code"] = 1
                res["error"] = "用户名或密码错误"
            return Response(res)

    1.3-定义一个认证类

    我们自己写的认证类都要继承内置认证类 "BaseAuthentication"

    class BaseAuthentication(object):
        """
        All authentication classes should extend BaseAuthentication.
        """
    
        def authenticate(self, request):
            """
            Authenticate the request and return a two-tuple of (user, token).
            """
            #内置的认证类,authenticate方法,如果不自己写,默认则抛出异常
            raise NotImplementedError(".authenticate() must be overridden.")
    
        def authenticate_header(self, request):
            """
            Return a string to be used as the value of the `WWW-Authenticate`
            header in a `401 Unauthenticated` response, or `None` if the
            authentication scheme should return `403 Permission Denied` responses.
            """
            #authenticate_header方法,作用是当认证失败的时候,返回的响应头
            pass
    BaseAuthentication源码
    from rest_framework.authentication import BaseAuthentication
    from rest_framework.exceptions import AuthenticationFailed
    
    
    class MyAuth(BaseAuthentication):
        def authenticate(self, request):
            if request.method in ["POST", "PUT", "DELETE"]:
                request_token = request.data.get("token", None)
                if not request_token:
                    raise AuthenticationFailed('缺少token')
                token_obj = models.Token.objects.filter(token_code=request_token).first()
                if not token_obj:
                    raise AuthenticationFailed('无效的token')
                return token_obj.user.username, None
            else:
                return None, None
    
        def authenticate_header(self, request):
            pass

    1.4-视图级别认证

    class CommentViewSet(ModelViewSet):
    
        queryset = models.Comment.objects.all()
        serializer_class = app01_serializers.CommentSerializer
        authentication_classes = [MyAuth, ]

    1.5-全局级别认证

    # 在settings.py中配置
    REST_FRAMEWORK = {
        "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ]
    }

    1.6 在settings里面设置的全局认证,所有业务都需要经过认证,如果想让某个不需要认证,只需要在其中添加下面的代码:

    authentication_classes = []    #里面为空,代表不需要认证

    02-权限

    只有VIP用户才能看的内容。premission.py

    2.1-内置权限验证类

    django rest framework 提供了内置的权限验证类,其本质都是定义has_permission()方法对权限进行验证:

    2.2-自定义一个权限类

    # 自定义权限
    class MyPermission(BasePermission):
        message = 'VIP用户才能访问'
    
        def has_permission(self, request, view):
            """
            自定义权限只有VIP用户才能访问
            """
            # 因为在进行权限判断之前已经做了认证判断,所以这里可以直接拿到request.user
            if request.user and request.user.type == 2:  # 如果是VIP用户
                return True
            else:
                return False

    2.3-视图级别配置

    class CommentViewSet(ModelViewSet):
    
        queryset = models.Comment.objects.all()
        serializer_class = app01_serializers.CommentSerializer
        authentication_classes = [MyAuth, ]
        permission_classes = [MyPermission, ]

    2.4-全局级别配置

    # 在settings.py中设置rest framework相关配置项
    REST_FRAMEWORK = {
        "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ],
        "DEFAULT_PERMISSION_CLASSES": ["app01.utils.MyPermission", ]
    }

    在全局配置权限后,因部分视图不想使用全局的权限限制,所以 需要在部分视图里配置:

    permission_classes = [OrdinaryPremission,]    #不用全局的权限配置的话,这里就要写自己的局部权限

    03-限制

    DRF内置了基本的限制类,首先我们自己动手写一个限制类,熟悉下限制组件的执行过程。

    3.1-自定义限制类

    VISIT_RECORD = {}
    # 自定义限制
    class MyThrottle(object):
    
        def __init__(self):
            self.history = None
    
        def allow_request(self, request, view):
            """
            自定义频率限制60秒内只能访问三次
            """
            # 获取用户IP
            ip = request.META.get("REMOTE_ADDR")
            timestamp = time.time()
            if ip not in VISIT_RECORD:
                VISIT_RECORD[ip] = [timestamp, ]
                return True
            history = VISIT_RECORD[ip]
            self.history = history
            history.insert(0, timestamp)
            while history and history[-1] < timestamp - 60:
                history.pop()
            if len(history) > 3:
                return False
            else:
                return True
    
        def wait(self):
            """
            限制时间还剩多少
            """
            timestamp = time.time()
            return 60 - (timestamp - self.history[-1])

    3.2-视图级别配置

    class CommentViewSet(ModelViewSet):
    
        queryset = models.Comment.objects.all()
        serializer_class = app01_serializers.CommentSerializer
        throttle_classes = [MyThrottle, ]

    3.3-全局级别配置

    # 在settings.py中设置rest framework相关配置项
    REST_FRAMEWORK = {
        "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ],
        "DEFAULT_PERMISSION_CLASSES": ["app01.utils.MyPermission", ]
        "DEFAULT_THROTTLE_CLASSES": ["app01.utils.MyThrottle", ]
    }

    04-使用内置限制类

    我们可以通过继承SimpleRateThrottle类,来实现节流,会更加的简单,因为SimpleRateThrottle里面都帮我们写好了。

    class SimpleRateThrottle(BaseThrottle):
        """
        A simple cache implementation, that only requires `.get_cache_key()`
        to be overridden.
    
        The rate (requests / seconds) is set by a `rate` attribute on the View
        class.  The attribute is a string of the form 'number_of_requests/period'.
    
        Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
    
        Previous request information used for throttling is stored in the cache.
        """
        cache = default_cache
        timer = time.time
        cache_format = 'throttle_%(scope)s_%(ident)s'
        scope = None   #这个值自定义,写什么都可以
        THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
    
        def __init__(self):
            if not getattr(self, 'rate', None):
                self.rate = self.get_rate()
            self.num_requests, self.duration = self.parse_rate(self.rate)
    
        def get_cache_key(self, request, view):
            """
            Should return a unique cache-key which can be used for throttling.
            Must be overridden.
    
            May return `None` if the request should not be throttled.
            """
            raise NotImplementedError('.get_cache_key() must be overridden')
    
        def get_rate(self):
            """
            Determine the string representation of the allowed request rate.
            """
            if not getattr(self, 'scope', None):
                msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                       self.__class__.__name__)
                raise ImproperlyConfigured(msg)
    
            try:
                return self.THROTTLE_RATES[self.scope]
            except KeyError:
                msg = "No default throttle rate set for '%s' scope" % self.scope
                raise ImproperlyConfigured(msg)
    
        def parse_rate(self, rate):
            """
            Given the request rate string, return a two tuple of:
            <allowed number of requests>, <period of time in seconds>
            """
            if rate is None:
                return (None, None)
            num, period = rate.split('/')
            num_requests = int(num)
            duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
            return (num_requests, duration)
    
        def allow_request(self, request, view):
            """
            Implement the check to see if the request should be throttled.
    
            On success calls `throttle_success`.
            On failure calls `throttle_failure`.
            """
            if self.rate is None:
                return True
    
            self.key = self.get_cache_key(request, view)
            if self.key is None:
                return True
    
            self.history = self.cache.get(self.key, [])
            self.now = self.timer()
    
            # Drop any requests from the history which have now passed the
            # throttle duration
            while self.history and self.history[-1] <= self.now - self.duration:
                self.history.pop()
            if len(self.history) >= self.num_requests:
                return self.throttle_failure()
            return self.throttle_success()
    
        def throttle_success(self):
            """
            Inserts the current request's timestamp along with the key
            into the cache.
            """
            self.history.insert(0, self.now)
            self.cache.set(self.key, self.history, self.duration)
            return True
    
        def throttle_failure(self):
            """
            Called when a request to the API has failed due to throttling.
            """
            return False
    
        def wait(self):
            """
            Returns the recommended next request time in seconds.
            """
            if self.history:
                remaining_duration = self.duration - (self.now - self.history[-1])
            else:
                remaining_duration = self.duration
    
            available_requests = self.num_requests - len(self.history) + 1
            if available_requests <= 0:
                return None
    
            return remaining_duration / float(available_requests)
    SimpleRateThrottle 源码
    from rest_framework.throttling import SimpleRateThrottle
    
    class VisitThrottle(SimpleRateThrottle):
        '''匿名用户60s只能访问三次(根据ip)'''
        scope = 'NBA'   #这里面的值,自己随便定义,settings里面根据这个值配置Rate
    
        def get_cache_key(self, request, view):
            #通过ip限制节流
            return self.get_ident(request)
    
    class UserThrottle(SimpleRateThrottle):
        '''登录用户60s可以访问10次'''
        scope = 'NBAUser'    #这里面的值,自己随便定义,settings里面根据这个值配置Rate
    
        def get_cache_key(self, request, view):
            return request.user.username

    4.1-全局配置

    # 在settings.py中设置rest framework相关配置项
    REST_FRAMEWORK = {
        "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ],
        # "DEFAULT_PERMISSION_CLASSES": ["app01.utils.MyPermission", ]
        "DEFAULT_THROTTLE_CLASSES": ["app01.utils.VisitThrottle", ],
        # 设置访问频率
        "DEFAULT_THROTTLE_RATES": {
            'NBA':'3/m',         #没登录用户3/m,NBA就是scope定义的值
            'NBAUser':'10/m',    #登录用户10/m,NBAUser就是scope定义的值
        }
    }

    4.2-局部配置

    class AuthView(APIView):
        .
        .    
        .
        # 默认的节流是登录用户(10/m),AuthView不需要登录,这里用匿名用户的节流(3/m)
        throttle_classes = [VisitThrottle,]
       .
        .

     4.3-BaseThrottle

    1. 自己要写allow_request和wait方法
    2. get_ident就是获取ip

    源码如下:

    class BaseThrottle(object):
        """
        Rate throttling of requests.
        """
    
        def allow_request(self, request, view):
            """
            Return `True` if the request should be allowed, `False` otherwise.
            """
            raise NotImplementedError('.allow_request() must be overridden')
    
        def get_ident(self, request):
            """
            Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
            if present and number of proxies is > 0. If not use all of
            HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
            """
            xff = request.META.get('HTTP_X_FORWARDED_FOR')
            remote_addr = request.META.get('REMOTE_ADDR')
            num_proxies = api_settings.NUM_PROXIES
    
            if num_proxies is not None:
                if num_proxies == 0 or xff is None:
                    return remote_addr
                addrs = xff.split(',')
                client_addr = addrs[-min(num_proxies, len(addrs))]
                return client_addr.strip()
    
            return ''.join(xff.split()) if xff else remote_addr
    
        def wait(self):
            """
            Optionally, return a recommended number of seconds to wait before
            the next request.
            """
            return None
  • 相关阅读:
    docx python
    haozip 命令行解压文件
    python 升级2.7版本到3.7
    Pyautogui
    python 库搜索技巧
    sqlserver学习笔记
    vim使用
    三极管工作原理分析
    串口扩展方案+简单自制电平转换电路
    功率二极管使用注意
  • 原文地址:https://www.cnblogs.com/pgxpython/p/10287898.html
Copyright © 2011-2022 走看看