一 认证Authentication
认证组件:校验用户 - 游客、合法用户、非法用户
游客:代表校验通过,直接进入下一步校验(权限校验)
合法用户:代表校验通过,将用户存储在request.user中,再进入下一步校验(权限校验)
非法用户:代表校验失败,抛出异常,返回403权限异常结果
认证源码分析
#1 APIVIew----> dispatch方法----> self.initial(request, *args, **kwargs)---->有认证,权限,频率
#2 只读认证源码: self.perform_authentication(request)
#3 self.perform_authentication(request)中就一句话:request.user,(这里的request对象被dispatch方法中的self.initialize_request(request, *args, **kwargs)方法重新封装成了drf自己的request对象)所以需要去drf的Request对象中找user属性(方法)
@property
def user(self):
"""
Returns the user associated with the current request, as authenticated
by the authentication classes provided to the request.
"""
if not hasattr(self, '_user'):
with wrap_attributeerrors():
self._authenticate()
return self._user
def initialize_request(self, request, *args, **kwargs):
"""
Returns the initial request object.
"""
parser_context = self.get_parser_context(request)
return Request(
request,
parsers=self.get_parsers(),
authenticators=self.get_authenticators(), # 获得一个列表,列表中是一个个类的对象
negotiator=self.get_content_negotiator(),
parser_context=parser_context
)
def get_authenticators(self):
"""
Instantiates and returns the list of authenticators that this view can use.
"""
# authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES,自己在视图类中配置的认证类(一个列表,列表中是类名)
return [auth() for auth in self.authentication_classes]
#4 Request类中的user方法,刚开始来,没有_user,走 self._authenticate()
#5 核心,就是Request类的 _authenticate(self):
def _authenticate(self):
# 遍历拿到一个个认证器,进行认证
# self.authenticators配置的一堆认证类产生的认证类对象组成的 list
#self.authenticators 你在视图类中配置的一个个的认证类:authentication_classes=[认证类1,认证类2],对象的列表
for authenticator in self.authenticators: # 这里的authenticators就是在self.initialize_request中封装的authenticators,是一个列表,里面是你自定义的认证类的对象
try:
# 认证器(对象)调用认证方法authenticate(认证类对象self, request请求对象)
# 返回值:登陆的用户与认证的信息组成的 tuple
# 该方法被try包裹,代表该方法会抛异常,抛异常就代表认证失败
user_auth_tuple = authenticator.authenticate(self) # 注意这self是request对象,authenticate就是你自定义的认证类中重写的authenticate方法
except exceptions.APIException:
self._not_authenticated()
raise
# 返回值的处理
"""
!!!!如果你配置了多个认证类,一定要在最后一个认证类中返回值,否则后面的认证类执行不了
"""
if user_auth_tuple is not None:
self._authenticator = authenticator
# 如果有返回值,就将 登陆用户 与 登陆认证 分别保存到 request.user、request.auth
self.user, self.auth = user_auth_tuple
return
# 如果返回值user_auth_tuple为空,代表认证通过,但是没有 登陆用户 与 登陆认证信息,代表游客
self._not_authenticated()
自定义认证方案
# models.py
class User(models.Model):
username=models.CharField(max_length=32)
password=models.CharField(max_length=32)
user_type=models.IntegerField(choices=((1,'超级用户'),(2,'普通用户'),(3,'二笔用户')))
class UserToken(models.Model):
user=models.OneToOneField(to='User')
token=models.CharField(max_length=64)
# auth.py(新建认证类)
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed
from app01.models import UserToken
class TokenAuth():
# token方在url中
def authenticate(self, request):
token = request.GET.get('token')
token_obj = models.UserToken.objects.filter(token=token).first()
if token_obj:
return
else:
raise AuthenticationFailed('认证失败')
# token放在请求头中
# def authenticate(self, request):
# # token = request.META.get('HTTP_TOKEN', None) # 注意token被重新拼接了
# # if not token:
# # raise AuthenticationFailed('非法查看')
# # obj = UserToken.objects.filter(token=token).first()
# # if obj:
# # return obj.user, token
# # raise AuthenticationFailed('认证失败')
# views.py
from rest_framework.views import APIView
from app01 import models
def get_random(name):
import hashlib
import time
md=hashlib.md5()
md.update(bytes(str(time.time()),encoding='utf-8'))
md.update(bytes(name,encoding='utf-8'))
return md.hexdigest()
class Login(APIView):
authentication_classes = []
def post(self,reuquest):
back_msg={'status':1001,'msg':None}
try:
name=reuquest.data.get('name')
pwd=reuquest.data.get('pwd')
user=models.User.objects.filter(username=name,password=pwd).first()
if user:
token=get_random(name)
models.UserToken.objects.update_or_create(user=user,defaults={'token':token})
back_msg['status']='1000'
back_msg['msg']='登录成功'
back_msg['token']=token
else:
back_msg['msg'] = '用户名或密码错误'
except Exception as e:
back_msg['msg']=str(e)
return Response(back_msg)
class Course(APIView):
authentication_classes = [TokenAuth, ] # 可以放多个认证类
def get(self, request):
return HttpResponse('get')
def post(self, request):
return HttpResponse('post')
可以在配置文件中配置全局默认的认证方案
# drf默认配置
REST_FRAMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework.authentication.SessionAuthentication', # session认证
'rest_framework.authentication.BasicAuthentication', # 基本认证
)
}
# 在项目settings.py中配置
REST_FRAMEWORK={
"DEFAULT_AUTHENTICATION_CLASSES":["app01.service.auth.TokenAuth",]
}
也可以在每个视图类中通过设置authentication_classess属性来设置局部使用
from rest_framework.authentication import SessionAuthentication, BasicAuthentication
from rest_framework.views import APIView
class ExampleView(APIView):
# 类属性
authentication_classes = [TokenAuth,]
...
禁用全局认证
# 在视图类中
class MyAuthentication(BaseAuthentication):
authentication_classes = []
def authenticate(self, request):
...
查找顺序:先在视图类中查找----->项目的setting中找----->drf默认配置找
认证失败会有两种可能的返回值:
- 401 Unauthorized 未认证
- 403 Permission Denied 权限被禁止
二 权限Permissions
权限组件:校验用户权限 - 必须登录、所有用户、登录读写游客只读、自定义用户角色
认证通过:可以进入下一步校验(频率认证)
认证失败:抛出异常,返回403权限异常结果
权限控制可以限制用户对于视图的访问和对于具体数据对象的访问。
- 在执行视图的dispatch()方法前,会先进行视图访问权限的判断
- 在通过get_object()获取具体对象时,会进行模型对象访问权限的判断
源码分析
# APIView---->dispatch---->initial--->self.check_permissions(request)(APIView的对象方法)
def check_permissions(self, request):
# 遍历权限对象列表得到一个个权限对象(权限器),进行权限认证
for permission in self.get_permissions():
# 权限类一定有一个has_permission权限方法,用来做权限认证的
# 参数:权限对象self、请求对象request、视图类对象
# 返回值:有权限返回True,无权限返回False
if not permission.has_permission(request, self):
self.permission_denied(
request, message=getattr(permission, 'message', None)
)
使用
# 写一个类,继承BasePermission,重写has_permission,如果权限通过,就返回True,不通过就返回False
from rest_framework.permissions import BasePermission
class UserPermission(BasePermission):
def has_permission(self, request, view):
# 不是超级用户,不能访问
# 由于认证已经过了,request内就有user对象了,当前登录用户
user=request.user # 当前登录用户
# 如果该字段用了choice,通过get_字段名_display()就能取出choice后面的中文
print(user.get_user_type_display())
if user.user_type==1:
return True
else:
return False
可以在配置文件中全局设置默认的权限管理类
REST_FRAMEWORK={
...
"DEFAULT_PERMISSION_CLASSES":["app01.permissions.UserPermission",]
}
也可以在具体的视图中通过permission_classes属性来设置
from rest_framework.views import APIView
class TestView(APIView):
permission_classes = (UserPermission,)
...
局部禁用
class TestView(APIView):
permission_classes = []
...
内置提供的权限
需要和django自带的auth_user表连用
- AllowAny 允许所有用户
- IsAuthenticated 仅通过认证的用户
- IsAdminUser 仅管理员用户
- IsAuthenticatedOrReadOnly 已经登陆认证的用户可以对数据进行增删改操作,没有登陆认证的只能查看数据。
可以在配置文件中全局设置默认的权限管理类
REST_FRAMEWORK = {
....
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.IsAuthenticated',
)
}
如果未指明,则采用如下默认配置
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.AllowAny',
)
也可以在具体的视图中通过permission_classes属性来设置
from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView
class ExampleView(APIView):
permission_classes = (IsAuthenticated,)
...
举例
# 创建超级用户,登陆到admin,创建普通用户(注意设置职员状态,也就是能登陆)
# 全局配置IsAuthenticated
# setting.py
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.IsAuthenticated',
)
# urls.py
path('test/', views.TestView.as_view()),
# views.py
class TestView(APIView):
def get(self,request):
return Response({'msg':'个人中心'})
# 登陆到admin后台后,直接访问可以,如果没登陆,不能访问
##注意:如果全局配置了
rest_framework.permissions.IsAdminUser
# 就只有管理员能访问,普通用户访问不了
自定义权限
如需自定义权限,需继承rest_framework.permissions.BasePermission父类,并实现以下两个任何一个方法或全部
-
.has_permission(self, request, view)
是否可以访问视图, view表示当前视图对象
-
.has_object_permission(self, request, view, obj)
是否可以访问数据对象, view表示当前视图, obj为数据对象
例如:
在当前子应用下,创建一个权限文件permissions.py中声明自定义权限类:
from rest_framework.permissions import BasePermission
class IsXiaoMingPermission(BasePermission):
def has_permission(self, request, view):
if( request.user.username == "xiaoming" ):
return True
from .permissions import IsXiaoMingPermission
class StudentViewSet(ModelViewSet):
queryset = Student.objects.all()
serializer_class = StudentSerializer
permission_classes = [IsXiaoMingPermission]
三 限流Throttling
频率组件:限制视图接口被访问的频率次数 - 限制的条件(IP、id、唯一键)、频率周期时间(s、m、h)、频率的次数(3/s)
没有达到限次:正常访问接口
达到限次:限制时间内不能访问,限制时间达到后,可以重新访问
可以对接口访问的频次进行限制,以减轻服务器压力。
一般用于付费购买次数,投票等场景使用.
使用
可以在配置文件中,使用DEFAULT_THROTTLE_CLASSES
和 DEFAULT_THROTTLE_RATES
进行全局配置
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': (
'rest_framework.throttling.AnonRateThrottle',
'rest_framework.throttling.UserRateThrottle'
),
'DEFAULT_THROTTLE_RATES': {
'anon': '100/day',
'user': '1000/day'
}
}
DEFAULT_THROTTLE_RATES
可以使用 second
, minute
, hour
或day
来指明周期。
也可以在具体视图中通过throttle_classess属性来配置
from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView
class ExampleView(APIView):
throttle_classes = (UserRateThrottle,)
...
实列
# 内置的频率限制(限制未登录用户)
# 全局使用 限制未登录用户1分钟访问5次
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': (
'rest_framework.throttling.AnonRateThrottle',
),
'DEFAULT_THROTTLE_RATES': {
'anon': '5/m',
}
}
## views.py
from rest_framework.permissions import IsAdminUser
from rest_framework.authentication import SessionAuthentication,BasicAuthentication
class TestView1(APIView):
authentication_classes=[]
permission_classes = []
def get(self,request,*args,**kwargs):
return Response('我是未登录用户')
# 局部使用
from rest_framework.permissions import IsAdminUser
from rest_framework.authentication import SessionAuthentication,BasicAuthentication
from rest_framework.throttling import AnonRateThrottle
class TestView2(APIView):
authentication_classes=[]
permission_classes = []
throttle_classes = [AnonRateThrottle]
def get(self,request,*args,**kwargs):
return Response('我是未登录用户')
可选限流类
1) AnonRateThrottle
限制所有匿名未认证用户,使用IP区分用户。
使用DEFAULT_THROTTLE_RATES['anon']
来设置频次
2)UserRateThrottle
限制认证用户,使用User id 来区分。
使用DEFAULT_THROTTLE_RATES['user']
来设置频次
3)ScopedRateThrottle
限制用户对于每个视图的访问频次,使用ip或user id。
例如:
class ContactListView(APIView):
throttle_scope = 'contacts'
...
class ContactDetailView(APIView):
throttle_scope = 'contacts'
...
class UploadView(APIView):
throttle_scope = 'uploads'
...
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': (
'rest_framework.throttling.ScopedRateThrottle',
),
'DEFAULT_THROTTLE_RATES': {
'contacts': '1000/day',
'uploads': '20/day'
}
}
实例
全局配置中设置访问频率
'DEFAULT_THROTTLE_RATES': {
'anon': '3/minute',
'user': '10/minute'
}
from rest_framework.authentication import SessionAuthentication
from rest_framework.permissions import IsAuthenticated
from rest_framework.generics import RetrieveAPIView
from rest_framework.throttling import UserRateThrottle
class StudentAPIView(RetrieveAPIView):
queryset = Student.objects.all()
serializer_class = StudentSerializer
authentication_classes = [SessionAuthentication]
permission_classes = [IsAuthenticated]
throttle_classes = (UserRateThrottle,)
自定义频率类
# 自定义的逻辑
#(1)取出访问者ip
#(2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问,在字典里,继续往下走
#(3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
#(4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
#(5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
class MyThrottles():
VISIT_RECORD = {}
def __init__(self):
self.history=None
def allow_request(self,request, view):
#(1)取出访问者ip
# print(request.META)
ip=request.META.get('REMOTE_ADDR')
import time
ctime=time.time()
# (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问
if ip not in self.VISIT_RECORD:
self.VISIT_RECORD[ip]=[ctime,]
return True
self.history=self.VISIT_RECORD.get(ip)
# (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
while self.history and ctime-self.history[-1]>60:
self.history.pop()
# (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
# (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
if len(self.history)<3:
self.history.insert(0,ctime)
return True
else:
return False
def wait(self):
import time
ctime=time.time()
return 60-(ctime-self.history[-1])
全局使用
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES':['app01.utils.MyThrottles',],
}
局部使用
#在视图类里使用
throttle_classes = [MyThrottles,]
继承SimpleRateThrottle自定义频率类
# 写一个类,继承SimpleRateThrottle,只需要重写get_cache_key
from rest_framework.throttling import ScopedRateThrottle,SimpleRateThrottle
#继承SimpleRateThrottle
class MyThrottle(SimpleRateThrottle):
scope='luffy'
def get_cache_key(self, request, view):
print(request.META.get('REMOTE_ADDR'))
return request.META.get('REMOTE_ADDR') # 返回用户ip(返回什么就根据什么限制用户访问频次)
# 全局使用
from utils.throttlings import MyThrottle
REST_FRAMEWORK={
'DEFAULT_THROTTLE_CLASSES': (
'utils.throttling.MyThrottle',
),
'DEFAULT_THROTTLE_RATES': {
'luffy': '3/m' # key要跟类中的scop对应
},
}
# 局部使用
# 视图类里加上
from utils.throttlings import MyThrottle
class ThView:
throttle_classes = [MyThrottle,]
def get(self, request):
...
SimpleRateThrottle源码分析
class SimpleRateThrottle(BaseThrottle):
cache = default_cache
timer = time.time
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
if not getattr(self, 'rate', None):
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
def get_cache_key(self, request, view):
raise NotImplementedError('.get_cache_key() must be overridden')
def get_rate(self):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
def parse_rate(self, rate):
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)
def allow_request(self, request, view):
if self.rate is None:
return True
# 当前登录用户的唯一标识
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
# 用户初次访问缓存为空,self.history为[],是存放时间的列表
self.history = self.cache.get(self.key, [])
# 获取一下当前时间,存放到 self.now
self.now = self.timer()
# 当前访问与第一次访问时间间隔如果大于60s,第一次记录清除,不再算作一次计数
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
# history的长度与限制次数进行比较
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
def throttle_success(self):
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
return False
def wait(self):
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
if available_requests <= 0:
return None
return remaining_duration / float(available_requests)