zoukankan      html  css  js  c++  java
  • Django REST framework 自定义(认证、权限、访问频率)组件

    本篇随笔在 "Django REST framework 初识" 基础上扩展

    一、认证组件

    # models.py
    class Account(models.Model):
        """用户表"""
        username = models.CharField(verbose_name="用户名", max_length=64, unique=True)
        password = models.CharField(verbose_name="密码", max_length=64)
    
    class UserToken(models.Model):
        """用户Token表"""
        user = models.OneToOneField(to="Account")
        token = models.CharField(max_length=64, unique=True)

    当然也可以使用django自带的 auth_user 表来保存用户信息,Token表一对一关联这张表或者继承这张表:

    from django.contrib.auth.models import User
    class Token(models.Model):
        user = models.OneToOneField(User)
        token = models.CharField(max_length=64)
    
    from django.contrib.auth.models import AbstractUser
    class Token(AbstractUser):
        token = models.CharField(max_length=64)

    auth.py

    from rest_framework import authentication
    from rest_framework import exceptions
    from api import models
    
    class UserTokenAuth(authentication.BaseAuthentication):
        """用户身份认证"""
        def authenticate(self, request):
            token = request.query_params.get("token")
            obj = models.UserToken.objects.filter(token=token).first()
            if not obj:
                raise exceptions.AuthenticationFailed({"code": 200, "error": "用户身份认证失败!"})
            else:
                return (obj.user.username, obj)

    Views.py

    import time
    import hashlib
    from rest_framework import viewsets
    from rest_framework.views import APIView
    from rest_framework.response import Response
    from django.core.exceptions import ObjectDoesNotExist
    from api import models
    from appxx import serializers
    from appxx.auth.auth import UserTokenAuth
    
    class LoginView(APIView):
        """
        用户认证接口
        """
        def post(self, request, *args, **kwargs):
            rep = {"code": 1000}
            username = request.data.get("username")
            password = request.data.get("password")
            try:
                user = models.Account.objects.get(username=username, password=password)
                token = self.get_token(user.password)
                rep["token"] = token
                models.UserToken.objects.update_or_create(user=user, defaults={"token": token})
            except ObjectDoesNotExist as e:
                rep["code"] = 1001
                rep["error"] = "用户名或密码错误"
            except Exception as e:
                rep["code"] = 1002
                rep["error"] = "发生错误,请重试"
            return Response(rep)
    
        @staticmethod
        def get_token(password):
            timestamp = str(time.time())
            md5 = hashlib.md5(bytes(password, encoding="utf-8"))
            md5.update(bytes(timestamp, encoding="utf-8"))
            return md5.hexdigest()
    
    class BookViewSet(viewsets.ModelViewSet):
        authentication_classes = [utils.AuthToken]
        queryset = models.Book.objects.all()
        serializer_class = serializers.BookSerializer

    urls.py

    from django.conf.urls import url, include
    from rest_framework.routers import DefaultRouter
    from appxx import views
    
    router = DefaultRouter()
    router.register(r"books", views.BookViewSet)
    router.register(r"publishers", views.PublisherViewSet)
    
    urlpatterns = [
        url(r"^login/$", views.LoginView.as_view(), name="login"),
        url(r"", include(router.urls)),
    ]

    局部认证(哪个视图类需要认证就在哪加上)

    如果需要每条URL都加上身份认证,那么是不是views.py中每个对应的类视图都加上authentication_classes呢?那多麻烦,有没有更简便的方法?请看下面如何设置全局的认证。

    全局认证

    在settings.py中设置:

    REST_FRAMEWORK = {
        "DEFAULT_AUTHENTICATION_CLASSES": ["appxx.utils.TokenAuthentication",],
        # "UNAUTHENTICATED_USER": None,   # 匿名,request.user = None
        # "UNAUTHENTICATED_TOKEN": None,  # 匿名,request.auth = None
    }

    可以看到,AuthToken 就是 BookViewSet 用到的 authentication_classes,这样views.py中的每个类视图都不需要加 authentication_classes 了;每条URL都必须经过此认证才能访问。

    class BookViewSet(viewsets.ModelViewSet):
        queryset = models.Book.objects.all()
        serializer_class = serializers.BookSerializer
    
    class PublisherViewSet(viewsets.ModelViewSet):
        queryset = models.Publisher.objects.all()
        serializer_class = serializers.PublisherSerializer

    二、权限组件

    修改模型表,给用户加上用户类型字段:

    class UserProfile(models.Model):
        username = models.CharField(verbose_name="用户名", max_length=16)
        password = models.CharField(verbose_name="密码", max_length=64)
        user_type_choices = ((1, "管理员"), (2, "普通用户"), (3, "VIP"))
        user_type = models.SmallIntegerField(choices=user_type_choices, default=2)
    class UserTypePermission(permissions.BasePermission):
        """权限认证"""
        message = "只有管理员才能访问"
    
        def has_permission(self, request, view):
            user = request.user
            try:
                user_type = models.UserProfile.objects.filter(username=user).first().user_type
            except AttributeError:
                return False
            if user_type == 1:
                return True
            else:
                return False

    局部权限

    class BookViewSet(viewsets.ModelViewSet):
        permission_classes = [utils.UserTypePermission]
        queryset = models.Book.objects.all()
        serializer_class = serializers.BookSerializer

    全局权限

    REST_FRAMEWORK = {"DEFAULT_PERMISSION_CLASSES": ["appxx.utils.UserTypePermission",],
    }

    三、访问频率组件

    import time
    
    visit_record = {}  # 可以放在redis中
    class IpRateThrottle(object):
        """60s内只能访问3次"""
        def __init__(self):
            self.history = None
    
        def allow_request(self, request, view):
            ip = request.META.get("REMOTE_ADDR")  # 获取用户IP
            current_time = time.time()
            if ip not in visit_record:  # 用户第一次访问
                visit_record[ip] = [current_time]
                return True
    
            history = visit_record.get(ip)
            self.history = history
    
            while history and history[-1] < current_time - 60:
                history.pop()
    
            if len(history) < 3:
                history.insert(0, current_time)
                return True
            # return True    # 表示可以继续访问
            # return False   # 表示访问频率太高,被限制
    
        def wait(self):
            """还需要等多久才能访问"""
            current_time = time.time()
            return 60 - (current_time - self.history[-1])

    局部节流

    class BookViewSet(viewsets.ModelViewSet):
        throttle_classes = [IpRateThrottle]
        queryset = models.Book.objects.all()
        serializer_class = serializers.BookSerializer

    全局节流

    REST_FRAMEWORK = {
        "DEFAULT_THROTTLE_CLASSES": ["appxx.utils.IpRateThrottle",],
    }

    PS:

    匿名用户:无法控制,因为用户可以换代理IP
    登录用户:如果有很多账号,也无法限制

  • 相关阅读:
    MongoDB慢查询性能分析
    redis的LRU算法(二)
    Skynet服务热点火焰图分析
    内存爆灯
    时区问题
    与机器共生
    bug狩猎
    Lesson Learned
    下划线引起的血案
    Intel的CPU漏洞:Spectre
  • 原文地址:https://www.cnblogs.com/believepd/p/10196971.html
Copyright © 2011-2022 走看看