目前,我们的API对谁可以编辑或删除代码段没有任何限制。我们希望有更高级的行为,以确保:
- 代码片段始终与创建者相关联。
- 只有通过身份验证的用户可以创建片段。
- 只有代码片段的创建者可以更新或删除它。
- 未经身份验证的请求应具有完全只读访问权限。
01-认证
REST framework 提供了一些开箱即用的身份验证方案,并且还允许你实现自定义方案。
1.1-自定义Token认证
定义一个用户表和一个保存用户Token的表:
class UserInfo(models.Model): username = models.CharField(max_length=16) password = models.CharField(max_length=32) type = models.SmallIntegerField( choices=((1, '普通用户'), (2, 'VIP用户')), default=1 ) class Token(models.Model): user = models.OneToOneField(to='UserInfo') token_code = models.CharField(max_length=128)
1.2-定义一个登录视图
def get_random_token(username): """ 根据用户名和时间戳生成随机token :param username: :return: """ import hashlib, time timestamp = str(time.time()) m = hashlib.md5(bytes(username, encoding="utf8")) m.update(bytes(timestamp, encoding="utf8")) return m.hexdigest() class LoginView(APIView): """ 校验用户名密码是否正确从而生成token的视图 """ def post(self, request): res = {"code": 0} print(request.data) username = request.data.get("username") password = request.data.get("password") user = models.UserInfo.objects.filter(username=username, password=password).first() if user: # 如果用户名密码正确 token = get_random_token(username) models.Token.objects.update_or_create(defaults={"token_code": token}, user=user) res["token"] = token else: res["code"] = 1 res["error"] = "用户名或密码错误" return Response(res)
1.3-定义一个认证类
我们自己写的认证类都要继承内置认证类 "BaseAuthentication"

class BaseAuthentication(object): """ All authentication classes should extend BaseAuthentication. """ def authenticate(self, request): """ Authenticate the request and return a two-tuple of (user, token). """ #内置的认证类,authenticate方法,如果不自己写,默认则抛出异常 raise NotImplementedError(".authenticate() must be overridden.") def authenticate_header(self, request): """ Return a string to be used as the value of the `WWW-Authenticate` header in a `401 Unauthenticated` response, or `None` if the authentication scheme should return `403 Permission Denied` responses. """ #authenticate_header方法,作用是当认证失败的时候,返回的响应头 pass
from rest_framework.authentication import BaseAuthentication from rest_framework.exceptions import AuthenticationFailed class MyAuth(BaseAuthentication): def authenticate(self, request): if request.method in ["POST", "PUT", "DELETE"]: request_token = request.data.get("token", None) if not request_token: raise AuthenticationFailed('缺少token') token_obj = models.Token.objects.filter(token_code=request_token).first() if not token_obj: raise AuthenticationFailed('无效的token') return token_obj.user.username, None else: return None, None def authenticate_header(self, request): pass
1.4-视图级别认证
class CommentViewSet(ModelViewSet): queryset = models.Comment.objects.all() serializer_class = app01_serializers.CommentSerializer authentication_classes = [MyAuth, ]
1.5-全局级别认证
# 在settings.py中配置 REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ] }
1.6 在settings里面设置的全局认证,所有业务都需要经过认证,如果想让某个不需要认证,只需要在其中添加下面的代码:
authentication_classes = [] #里面为空,代表不需要认证
02-权限
只有VIP用户才能看的内容。premission.py
2.1-内置权限验证类
django rest framework 提供了内置的权限验证类,其本质都是定义has_permission()方法对权限进行验证:
2.2-自定义一个权限类
# 自定义权限 class MyPermission(BasePermission): message = 'VIP用户才能访问' def has_permission(self, request, view): """ 自定义权限只有VIP用户才能访问 """ # 因为在进行权限判断之前已经做了认证判断,所以这里可以直接拿到request.user if request.user and request.user.type == 2: # 如果是VIP用户 return True else: return False
2.3-视图级别配置
class CommentViewSet(ModelViewSet): queryset = models.Comment.objects.all() serializer_class = app01_serializers.CommentSerializer authentication_classes = [MyAuth, ] permission_classes = [MyPermission, ]
2.4-全局级别配置
# 在settings.py中设置rest framework相关配置项 REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ], "DEFAULT_PERMISSION_CLASSES": ["app01.utils.MyPermission", ] }
在全局配置权限后,因部分视图不想使用全局的权限限制,所以 需要在部分视图里配置:
permission_classes = [OrdinaryPremission,] #不用全局的权限配置的话,这里就要写自己的局部权限
03-限制
DRF内置了基本的限制类,首先我们自己动手写一个限制类,熟悉下限制组件的执行过程。
3.1-自定义限制类
VISIT_RECORD = {} # 自定义限制 class MyThrottle(object): def __init__(self): self.history = None def allow_request(self, request, view): """ 自定义频率限制60秒内只能访问三次 """ # 获取用户IP ip = request.META.get("REMOTE_ADDR") timestamp = time.time() if ip not in VISIT_RECORD: VISIT_RECORD[ip] = [timestamp, ] return True history = VISIT_RECORD[ip] self.history = history history.insert(0, timestamp) while history and history[-1] < timestamp - 60: history.pop() if len(history) > 3: return False else: return True def wait(self): """ 限制时间还剩多少 """ timestamp = time.time() return 60 - (timestamp - self.history[-1])
3.2-视图级别配置
class CommentViewSet(ModelViewSet): queryset = models.Comment.objects.all() serializer_class = app01_serializers.CommentSerializer throttle_classes = [MyThrottle, ]
3.3-全局级别配置
# 在settings.py中设置rest framework相关配置项 REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ], "DEFAULT_PERMISSION_CLASSES": ["app01.utils.MyPermission", ] "DEFAULT_THROTTLE_CLASSES": ["app01.utils.MyThrottle", ] }
04-使用内置限制类
我们可以通过继承SimpleRateThrottle类,来实现节流,会更加的简单,因为SimpleRateThrottle里面都帮我们写好了。

class SimpleRateThrottle(BaseThrottle): """ A simple cache implementation, that only requires `.get_cache_key()` to be overridden. The rate (requests / seconds) is set by a `rate` attribute on the View class. The attribute is a string of the form 'number_of_requests/period'. Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') Previous request information used for throttling is stored in the cache. """ cache = default_cache timer = time.time cache_format = 'throttle_%(scope)s_%(ident)s' scope = None #这个值自定义,写什么都可以 THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES def __init__(self): if not getattr(self, 'rate', None): self.rate = self.get_rate() self.num_requests, self.duration = self.parse_rate(self.rate) def get_cache_key(self, request, view): """ Should return a unique cache-key which can be used for throttling. Must be overridden. May return `None` if the request should not be throttled. """ raise NotImplementedError('.get_cache_key() must be overridden') def get_rate(self): """ Determine the string representation of the allowed request rate. """ if not getattr(self, 'scope', None): msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % self.__class__.__name__) raise ImproperlyConfigured(msg) try: return self.THROTTLE_RATES[self.scope] except KeyError: msg = "No default throttle rate set for '%s' scope" % self.scope raise ImproperlyConfigured(msg) def parse_rate(self, rate): """ Given the request rate string, return a two tuple of: <allowed number of requests>, <period of time in seconds> """ if rate is None: return (None, None) num, period = rate.split('/') num_requests = int(num) duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] return (num_requests, duration) def allow_request(self, request, view): """ Implement the check to see if the request should be throttled. On success calls `throttle_success`. On failure calls `throttle_failure`. """ if self.rate is None: return True self.key = self.get_cache_key(request, view) if self.key is None: return True self.history = self.cache.get(self.key, []) self.now = self.timer() # Drop any requests from the history which have now passed the # throttle duration while self.history and self.history[-1] <= self.now - self.duration: self.history.pop() if len(self.history) >= self.num_requests: return self.throttle_failure() return self.throttle_success() def throttle_success(self): """ Inserts the current request's timestamp along with the key into the cache. """ self.history.insert(0, self.now) self.cache.set(self.key, self.history, self.duration) return True def throttle_failure(self): """ Called when a request to the API has failed due to throttling. """ return False def wait(self): """ Returns the recommended next request time in seconds. """ if self.history: remaining_duration = self.duration - (self.now - self.history[-1]) else: remaining_duration = self.duration available_requests = self.num_requests - len(self.history) + 1 if available_requests <= 0: return None return remaining_duration / float(available_requests)
from rest_framework.throttling import SimpleRateThrottle class VisitThrottle(SimpleRateThrottle): '''匿名用户60s只能访问三次(根据ip)''' scope = 'NBA' #这里面的值,自己随便定义,settings里面根据这个值配置Rate def get_cache_key(self, request, view): #通过ip限制节流 return self.get_ident(request) class UserThrottle(SimpleRateThrottle): '''登录用户60s可以访问10次''' scope = 'NBAUser' #这里面的值,自己随便定义,settings里面根据这个值配置Rate def get_cache_key(self, request, view): return request.user.username
4.1-全局配置
# 在settings.py中设置rest framework相关配置项 REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ], # "DEFAULT_PERMISSION_CLASSES": ["app01.utils.MyPermission", ] "DEFAULT_THROTTLE_CLASSES": ["app01.utils.VisitThrottle", ], # 设置访问频率 "DEFAULT_THROTTLE_RATES": { 'NBA':'3/m', #没登录用户3/m,NBA就是scope定义的值 'NBAUser':'10/m', #登录用户10/m,NBAUser就是scope定义的值 } }
4.2-局部配置
class AuthView(APIView): . . . # 默认的节流是登录用户(10/m),AuthView不需要登录,这里用匿名用户的节流(3/m) throttle_classes = [VisitThrottle,] . .
4.3-BaseThrottle
1. 自己要写allow_request和wait方法
2. get_ident就是获取ip
源码如下:
class BaseThrottle(object): """ Rate throttling of requests. """ def allow_request(self, request, view): """ Return `True` if the request should be allowed, `False` otherwise. """ raise NotImplementedError('.allow_request() must be overridden') def get_ident(self, request): """ Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR if present and number of proxies is > 0. If not use all of HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. """ xff = request.META.get('HTTP_X_FORWARDED_FOR') remote_addr = request.META.get('REMOTE_ADDR') num_proxies = api_settings.NUM_PROXIES if num_proxies is not None: if num_proxies == 0 or xff is None: return remote_addr addrs = xff.split(',') client_addr = addrs[-min(num_proxies, len(addrs))] return client_addr.strip() return ''.join(xff.split()) if xff else remote_addr def wait(self): """ Optionally, return a recommended number of seconds to wait before the next request. """ return None