zoukankan      html  css  js  c++  java
  • Django REST framework 自定义(认证、权限、访问频率)组件

    本篇随笔在 "Django REST framework 初识" 基础上扩展

    一、认证组件

    # models.py
    class Account(models.Model):
        """用户表"""
        username = models.CharField(verbose_name="用户名", max_length=64, unique=True)
        password = models.CharField(verbose_name="密码", max_length=64)
    
    class UserToken(models.Model):
        """用户Token表"""
        user = models.OneToOneField(to="Account")
        token = models.CharField(max_length=64, unique=True)

    当然也可以使用django自带的 auth_user 表来保存用户信息,Token表一对一关联这张表或者继承这张表:

    from django.contrib.auth.models import User
    class Token(models.Model):
        user = models.OneToOneField(User)
        token = models.CharField(max_length=64)
    
    from django.contrib.auth.models import AbstractUser
    class Token(AbstractUser):
        token = models.CharField(max_length=64)

    auth.py

    from rest_framework import authentication
    from rest_framework import exceptions
    from api import models
    
    class UserTokenAuth(authentication.BaseAuthentication):
        """用户身份认证"""
        def authenticate(self, request):
            token = request.query_params.get("token")
            obj = models.UserToken.objects.filter(token=token).first()
            if not obj:
                raise exceptions.AuthenticationFailed({"code": 200, "error": "用户身份认证失败!"})
            else:
                return (obj.user.username, obj)

    Views.py

    import time
    import hashlib
    from rest_framework import viewsets
    from rest_framework.views import APIView
    from rest_framework.response import Response
    from django.core.exceptions import ObjectDoesNotExist
    from api import models
    from appxx import serializers
    from appxx.auth.auth import UserTokenAuth
    
    class LoginView(APIView):
        """
        用户认证接口
        """
        def post(self, request, *args, **kwargs):
            rep = {"code": 1000}
            username = request.data.get("username")
            password = request.data.get("password")
            try:
                user = models.Account.objects.get(username=username, password=password)
                token = self.get_token(user.password)
                rep["token"] = token
                models.UserToken.objects.update_or_create(user=user, defaults={"token": token})
            except ObjectDoesNotExist as e:
                rep["code"] = 1001
                rep["error"] = "用户名或密码错误"
            except Exception as e:
                rep["code"] = 1002
                rep["error"] = "发生错误,请重试"
            return Response(rep)
    
        @staticmethod
        def get_token(password):
            timestamp = str(time.time())
            md5 = hashlib.md5(bytes(password, encoding="utf-8"))
            md5.update(bytes(timestamp, encoding="utf-8"))
            return md5.hexdigest()
    
    class BookViewSet(viewsets.ModelViewSet):
        authentication_classes = [utils.AuthToken]
        queryset = models.Book.objects.all()
        serializer_class = serializers.BookSerializer

    urls.py

    from django.conf.urls import url, include
    from rest_framework.routers import DefaultRouter
    from appxx import views
    
    router = DefaultRouter()
    router.register(r"books", views.BookViewSet)
    router.register(r"publishers", views.PublisherViewSet)
    
    urlpatterns = [
        url(r"^login/$", views.LoginView.as_view(), name="login"),
        url(r"", include(router.urls)),
    ]

    局部认证(哪个视图类需要认证就在哪加上)

    如果需要每条URL都加上身份认证,那么是不是views.py中每个对应的类视图都加上authentication_classes呢?那多麻烦,有没有更简便的方法?请看下面如何设置全局的认证。

    全局认证

    在settings.py中设置:

    REST_FRAMEWORK = {
        "DEFAULT_AUTHENTICATION_CLASSES": ["appxx.utils.TokenAuthentication",],
        # "UNAUTHENTICATED_USER": None,   # 匿名,request.user = None
        # "UNAUTHENTICATED_TOKEN": None,  # 匿名,request.auth = None
    }

    可以看到,AuthToken 就是 BookViewSet 用到的 authentication_classes,这样views.py中的每个类视图都不需要加 authentication_classes 了;每条URL都必须经过此认证才能访问。

    class BookViewSet(viewsets.ModelViewSet):
        queryset = models.Book.objects.all()
        serializer_class = serializers.BookSerializer
    
    class PublisherViewSet(viewsets.ModelViewSet):
        queryset = models.Publisher.objects.all()
        serializer_class = serializers.PublisherSerializer

    二、权限组件

    修改模型表,给用户加上用户类型字段:

    class UserProfile(models.Model):
        username = models.CharField(verbose_name="用户名", max_length=16)
        password = models.CharField(verbose_name="密码", max_length=64)
        user_type_choices = ((1, "管理员"), (2, "普通用户"), (3, "VIP"))
        user_type = models.SmallIntegerField(choices=user_type_choices, default=2)
    class UserTypePermission(permissions.BasePermission):
        """权限认证"""
        message = "只有管理员才能访问"
    
        def has_permission(self, request, view):
            user = request.user
            try:
                user_type = models.UserProfile.objects.filter(username=user).first().user_type
            except AttributeError:
                return False
            if user_type == 1:
                return True
            else:
                return False

    局部权限

    class BookViewSet(viewsets.ModelViewSet):
        permission_classes = [utils.UserTypePermission]
        queryset = models.Book.objects.all()
        serializer_class = serializers.BookSerializer

    全局权限

    REST_FRAMEWORK = {"DEFAULT_PERMISSION_CLASSES": ["appxx.utils.UserTypePermission",],
    }

    三、访问频率组件

    import time
    
    visit_record = {}  # 可以放在redis中
    class IpRateThrottle(object):
        """60s内只能访问3次"""
        def __init__(self):
            self.history = None
    
        def allow_request(self, request, view):
            ip = request.META.get("REMOTE_ADDR")  # 获取用户IP
            current_time = time.time()
            if ip not in visit_record:  # 用户第一次访问
                visit_record[ip] = [current_time]
                return True
    
            history = visit_record.get(ip)
            self.history = history
    
            while history and history[-1] < current_time - 60:
                history.pop()
    
            if len(history) < 3:
                history.insert(0, current_time)
                return True
            # return True    # 表示可以继续访问
            # return False   # 表示访问频率太高,被限制
    
        def wait(self):
            """还需要等多久才能访问"""
            current_time = time.time()
            return 60 - (current_time - self.history[-1])

    局部节流

    class BookViewSet(viewsets.ModelViewSet):
        throttle_classes = [IpRateThrottle]
        queryset = models.Book.objects.all()
        serializer_class = serializers.BookSerializer

    全局节流

    REST_FRAMEWORK = {
        "DEFAULT_THROTTLE_CLASSES": ["appxx.utils.IpRateThrottle",],
    }

    PS:

    匿名用户:无法控制,因为用户可以换代理IP
    登录用户:如果有很多账号,也无法限制

  • 相关阅读:
    java 在线网络考试系统源码 springboot mybaits vue.js 前后分离跨域
    springboot 整合flowable 项目源码 mybiats vue.js 前后分离 跨域
    flowable Springboot vue.js 前后分离 跨域 有代码生成器 工作流
    Flowable 工作流 Springboot vue.js 前后分离 跨域 有代码生成器
    java 企业 网站源码 后台 springmvc SSM 前台 静态化 代码生成器
    java 进销存 商户管理 系统 管理 库存管理 销售报表springmvc SSM项目
    基于FPGA的电子计算器设计(中)
    基于FPGA的电子计算器设计(上)
    FPGA零基础学习:SPI 协议驱动设计
    Signal tap 逻辑分析仪使用教程
  • 原文地址:https://www.cnblogs.com/believepd/p/10196971.html
Copyright © 2011-2022 走看看