认证组件
发生位置
APIview 类种的 dispatch 方法执行到 initial 方法 进行 认证组件认证
源码位置
rest_framework.authentication
源码内部需要了解的
# 用户用户自定义的重写继承类 class BaseAuthentication(object): ... # 自定义重写的认证方法 def authenticate(self, request):... # 以下4种为自带 认证 # 基于用户名和密码的认证 class BasicAuthentication(BaseAuthentication):... # 基于 session 的认证 class SessionAuthentication(BaseAuthentication):... # 基于 token 的认证 class TokenAuthentication(BaseAuthentication):... # 基于远端服务的认证 class RemoteUserAuthentication(BaseAuthentication):...
自定义认证函数
from rest_framework.authentication import BaseAuthentication from rest_framework.exceptions import AuthenticationFailed from api.models import * class YTAuth(BaseAuthentication): def authenticate(self, request): token = request.query_params.get('token') obj = UserAuthToken.objects.filter(token=token).first() if not obj: return AuthenticationFailed({'code': 1001, 'erroe': '认证失败'}) return (obj.user.username, obj) # 返回的必须是元组 然后元组的里面含有两个值 并且对应的取值是rquest.user(user对象),和reques.auth(token对象)
视图级别认证
class CommentViewSet(ModelViewSet): queryset = models.Comment.objects.all() serializer_class = app01_serializers.CommentSerializer authentication_classes = [YTAuth, ]
全局认证
# 在settings.py中配置 REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.YTAuth", ] }
权限组件
发生位置
APIview 类种的 dispatch 方法执行到 initial 方法 进行 认证组件执行后,进行权限组件认证
源码位置
rest_framework.permissions
权限组件内部需要了解的
# 自定义重写的类 class BasePermission(object): ... # 自定义重写的方法 def has_permission(self, request, view): ... # AllowAny 允许所有用户 class AllowAny(BasePermission):... # IsAuthenticated 仅通过认证的用户 class IsAuthenticated(BasePermission):... # IsAdminUser 仅管理员用户 class IsAdminUser(BasePermission):... # IsAuthenticatedOrReadOnly 认证的用户可以完全操作,否则只能get读取 class IsAuthenticatedOrReadOnly(BasePermission):...
自定义权限组件
from rest_framework.permissions import BasePermission class MyPermission(BasePermission): message = 'VIP用户才能访问' def has_permission(self, request, view): # 认证判断已经提供了request.user if request.user and request.user.type == 2: return True else: return False
视图级别使用自定义权限组件
class CommentViewSet(ModelViewSet):
queryset = models.Comment.objects.all() serializer_class = app01_serializers.CommentSerializer authentication_classes = [YTAuth, ] permission_classes = [YTPermission, ]
全局级别使用自定义权限组件
# 在settings.py中设置rest framework相关配置项
REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.YTAuth", ],
"DEFAULT_PERMISSION_CLASSES": ["app01.utils.YTPermission", ]
}
频率限制
发生位置
APIview 类种的 dispatch 方法执行到 initial 方法 进行 认证组件执行,权限组件认证后 ,进行频率组件的认证
源码位置
rest_framework.throttling
权限组件内部需要了解的
# 需要自定义重写的类 class BaseThrottle(object): ... # 自定义频率的逻辑实现方法 def allow_request(self, request, view): ... # 自定义 限制后逻辑实现方法 def wait(self): ... # 内置的频率控制组件 常用的是这个 class SimpleRateThrottle(BaseThrottle): ... # 其他都不怎么用 class AnonRateThrottle(SimpleRateThrottle): ... # 其他都不怎么用 class UserRateThrottle(SimpleRateThrottle): # 其他都不怎么用 class ScopedRateThrottle(SimpleRateThrottle):
自定义频率组件
import time VISIT_RECORD = {} class YTThrottle(object): # 直接继承 object 就可以了 def __init__(self): self.history = None def allow_request(self, request, view): """ 自定义频率限制60秒内只能访问三次 """ # 获取用户IP ip = request.META.get("REMOTE_ADDR") timestamp = time.time() if ip not in VISIT_RECORD: VISIT_RECORD[ip] = [timestamp, ] return True history = VISIT_RECORD[ip] self.history = history history.insert(0, timestamp) while history and history[-1] < timestamp - 60: history.pop() if len(history) > 3: return False else: return True def wait(self): """ 限制时间还剩多少 """ timestamp = time.time() return 60 - (timestamp - self.history[-1])
视图级别使用自定义频率组件
class CommentViewSet(ModelViewSet): queryset = models.Comment.objects.all() serializer_class = app01_serializers.CommentSerializer throttle_classes = [YTThrottle, ]
全局级别使用自定义频率组件
# 在settings.py中设置rest framework相关配置项 REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.YTAuth", ], "DEFAULT_PERMISSION_CLASSES": ["app01.utils.YTPermission", ] "DEFAULT_THROTTLE_CLASSES": ["app01.utils.YTThrottle", ] }
ps:
使用内置 SimpleRateThrottle 频率控制组件
from rest_framework.throttling import SimpleRateThrottle class VisitThrottle(SimpleRateThrottle): scope = "xxx" def get_cache_key(self, request, view): return self.get_ident(request)
全局使用
# 在settings.py中设置rest framework相关配置项 REST_FRAMEWORK = { ...
"DEFAULT_THROTTLE_RATES": { "xxx": "5/m", # 每分钟5次最多 } }