zoukankan      html  css  js  c++  java
  • flask_wtf flask 的 CSRF 源代码初研究

    因为要搞一个基于flask的前后端分离的个人网站,所以需要研究下flask的csrf防护原理.

    用的扩展是flask_wtf,也算是比较官方的扩展库了.

    先上相关源代码:

      1 def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
      2     """Check if the given data is a valid CSRF token. This compares the given
      3     signed token to the one stored in the session.
      4 
      5     :param data: The signed CSRF token to be checked.
      6     :param secret_key: Used to securely sign the token. Default is
      7         ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
      8     :param time_limit: Number of seconds that the token is valid. Default is
      9         ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
     10     :param token_key: Key where token is stored in session for comparision.
     11         Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
     12 
     13     :raises ValidationError: Contains the reason that validation failed.
     14 
     15     .. versionchanged:: 0.14
     16         Raises ``ValidationError`` with a specific error message rather than
     17         returning ``True`` or ``False``.
     18     """
     19 
     20     secret_key = _get_config(
     21         secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
     22         message='A secret key is required to use CSRF.'
     23     )
     24     field_name = _get_config(
     25         token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
     26         message='A field name is required to use CSRF.'
     27     )
     28     time_limit = _get_config(
     29         time_limit, 'WTF_CSRF_TIME_LIMIT', 3600, required=False
     30     )
     31 
     32     if not data:
     33         raise ValidationError('The CSRF token is missing.')
     34 
     35     if field_name not in session:
     36         raise ValidationError('The CSRF session token is missing.')
     37 
     38     s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')
     39 
     40     try:
     41         token = s.loads(data, max_age=time_limit)
     42     except SignatureExpired:
     43         raise ValidationError('The CSRF token has expired.')
     44     except BadData:
     45         raise ValidationError('The CSRF token is invalid.')
     46 
     47     if not safe_str_cmp(session[field_name], token):
     48         raise ValidationError('The CSRF tokens do not match.')
     49 
     50 
     51 class CSRFProtect(object):
     52     """Enable CSRF protection globally for a Flask app.
     53 
     54     ::
     55 
     56         app = Flask(__name__)
     57         csrf = CsrfProtect(app)
     58 
     59     Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
     60     header sent with JavaScript requests. Render the token in templates using
     61     ``{{ csrf_token() }}``.
     62 
     63     See the :ref:`csrf` documentation.
     64     """
     65 
     66     def __init__(self, app=None):
     67         self._exempt_views = set()
     68         self._exempt_blueprints = set()
     69 
     70         if app:
     71             self.init_app(app)
     72 
     73     def init_app(self, app):
     74         app.extensions['csrf'] = self
     75 
     76         app.config.setdefault('WTF_CSRF_ENABLED', True)
     77         app.config.setdefault('WTF_CSRF_CHECK_DEFAULT', True)
     78         app.config['WTF_CSRF_METHODS'] = set(app.config.get(
     79             'WTF_CSRF_METHODS', ['POST', 'PUT', 'PATCH', 'DELETE']
     80         ))
     81         app.config.setdefault('WTF_CSRF_FIELD_NAME', 'csrf_token')
     82         app.config.setdefault(
     83             'WTF_CSRF_HEADERS', ['X-CSRFToken', 'X-CSRF-Token']
     84         )
     85         app.config.setdefault('WTF_CSRF_TIME_LIMIT', 3600)
     86         app.config.setdefault('WTF_CSRF_SSL_STRICT', True)
     87 
     88         app.jinja_env.globals['csrf_token'] = generate_csrf        <><><><><><><><><><><><><><><><><><><>
     89         app.context_processor(lambda: {'csrf_token': generate_csrf})
     90 
     91         @app.before_request
     92         def csrf_protect():
     93             if not app.config['WTF_CSRF_ENABLED']:
     94                 return
     95 
     96             if not app.config['WTF_CSRF_CHECK_DEFAULT']:
     97                 return
     98 
     99             if request.method not in app.config['WTF_CSRF_METHODS']:
    100                 return
    101 
    102             if not request.endpoint:
    103                 return
    104 
    105             view = app.view_functions.get(request.endpoint)
    106 
    107             if not view:
    108                 return
    109 
    110             if request.blueprint in self._exempt_blueprints:
    111                 return
    112 
    113             dest = '%s.%s' % (view.__module__, view.__name__)
    114 
    115             if dest in self._exempt_views:
    116                 return
    117 
    118             self.protect()
    119 
    120     def _get_csrf_token(self):
    121         # find the ``csrf_token`` field in the subitted form
    122         # if the form had a prefix, the name will be
    123         # ``{prefix}-csrf_token``
    124         field_name = current_app.config['WTF_CSRF_FIELD_NAME']
    125 
    126         for key in request.form:
    127             if key.endswith(field_name):
    128                 csrf_token = request.form[key]
    129 
    130                 if csrf_token:
    131                     return csrf_token
    132 
    133         for header_name in current_app.config['WTF_CSRF_HEADERS']:
    134             csrf_token = request.headers.get(header_name)
    135 
    136             if csrf_token:
    137                 return csrf_token
    138 
    139         return None
    140 
    141     def protect(self):
    142         if request.method not in current_app.config['WTF_CSRF_METHODS']:
    143             return
    144 
    145         try:
    146             validate_csrf(self._get_csrf_token())
    147         except ValidationError as e:
    148             logger.info(e.args[0])
    149             self._error_response(e.args[0])
    150 
    151         if request.is_secure and current_app.config['WTF_CSRF_SSL_STRICT']:
    152             if not request.referrer:
    153                 self._error_response('The referrer header is missing.')
    154 
    155             good_referrer = 'https://{0}/'.format(request.host)
    156 
    157             if not same_origin(request.referrer, good_referrer):
    158                 self._error_response('The referrer does not match the host.')
    159 
    160         g.csrf_valid = True  # mark this request as CSRF valid

     先说明下csrftoken的普通机制,上面代码中有一行代码后面被我加了一串<>符号,这行代码表明,默认的jinja2渲染的方式就是通过generate_csrf 方法生成csrftoken字符串,所以前后端分离的话,可以直接通过这个方法获取csrftoken,效果是一样的.

    进入generate_csrf函数内部,会发现它做了这么几件事:如果session里还没有token,就生成一个原始token放进session;然后用秘钥对它签名(带时间戳),把签名后的token返回.也就是说,不同的会话在session里保存的csrftoken值各不相同,而同一个会话内保存的原始token不会因为再次调用而改变,变化的只是返回的带时间戳的签名串.所以,你可以这么做:获取一次之后,在有效期(默认3600秒,即一个小时)内重复使用,但是不建议这么做.另外,如果不是form表单提交的话,该csrf系统不会从json中获取token,而会从请求头获取,所以需要在请求头内添加关键字段:X-CSRFToken,将这个值赋值为获取的token即可.

    首先获取csrftoken的方式: _get_csrf_token

    会先从表单中查找以关键字段名结尾的字段(表单带前缀时字段名形如 {prefix}-csrf_token),如果取到了非空的值就直接返回;取不到的话,再依次从配置的请求头中获取.方式和django的基本一致,毕竟也就这两种规范方式.

     91         @app.before_request
     92         def csrf_protect():

    这两行代码表明了wtf是如何实现校验的:通过flask的before_request钩子函数,在每次请求开始时进行校验.这个钩子函数在初始化wtf,即调用init_app(app)的时候就已经注册好了.

    在django里面,一旦中间件的process_request返回了响应(非None的值),中间件即开始执行响应回调,视图不再执行.而上面那两行代码下面的函数体里却接连return了好多次,到底是什么意思呢?只好再找源码看看.相关源码在下面:

        @setupmethod
        def before_request(self, f):
            """Registers a function to run before each request.
    
            For example, this can be used to open a database connection, or to load
            the logged in user from the session.
    
            The function will be called without any arguments. If it returns a
            non-None value, the value is handled as if it was the return value from
            the view, and further request handling is stopped.

            :param f: The callable to register.
            :return: ``f`` itself, unchanged.
            """
            # Registration only: ``f`` is appended to the list stored under the
            # ``None`` key of ``before_request_funcs`` (the list is created on
            # first use). Nothing is called here.
            # NOTE(review): the ``None`` key presumably means "application-wide,
            # not tied to a blueprint" -- confirm against the code that reads it.
            self.before_request_funcs.setdefault(None, []).append(f)
            # Returning ``f`` lets this method be used as a decorator without
            # replacing the decorated function.
            return f

    可以看到添加钩子函数的装饰器执行了什么操作,他只是把钩子函数放进了一个函数列表里,然后我们看看这个函数列表是什么方式处理的.源码如下:

        def preprocess_request(self):
            """Called before the request is dispatched. Calls
            :attr:`url_value_preprocessors` registered with the app and the
            current blueprint (if any). Then calls :attr:`before_request_funcs`
            registered with the app and the blueprint.
    
            If any :meth:`before_request` handler returns a non-None value, the
            value is handled as if it was the return value from the view, and
            further request handling is stopped.

            :return: The first non-None value returned by a before-request
                function, or ``None`` if every function returned ``None``.
            """
    
            # Blueprint of the request at the top of the request context stack;
            # the checks below treat ``None`` as "no blueprint".
            bp = _request_ctx_stack.top.request.blueprint
    
            # Phase 1: URL value preprocessors. Those stored under the ``None``
            # key run first, followed by the current blueprint's own (if any).
            # NOTE(review): ``chain`` is presumably ``itertools.chain`` -- the
            # import is not visible here.
            funcs = self.url_value_preprocessors.get(None, ())
            if bp is not None and bp in self.url_value_preprocessors:
                funcs = chain(funcs, self.url_value_preprocessors[bp])
            for func in funcs:
                # Return values of these preprocessors are ignored.
                func(request.endpoint, request.view_args)
    
            # Phase 2: before-request functions, in the same order (``None`` key
            # first, then the blueprint's).
            funcs = self.before_request_funcs.get(None, ())
            if bp is not None and bp in self.before_request_funcs:
                funcs = chain(funcs, self.before_request_funcs[bp])
            for func in funcs:
                rv = func()
                # A bare ``return`` in a handler yields ``None`` and lets the
                # loop continue; only a non-None value stops the loop and is
                # handed back to the caller.
                if rv is not None:
                    return rv

    该方法的注释说明了:如果钩子函数返回任意不为None的值,这个值就会被当作视图的响应,后续的请求处理也随之停止.而不带返回值的return等同于返回None,只是提前结束钩子函数本身,并不会中断请求处理,视图仍然会被正常执行.

     现在可以解释def csrf_protect():函数的内容了,即:未开启防护(WTF_CSRF_ENABLED 或 WTF_CSRF_CHECK_DEFAULT 为假)时跳过校验,请求方式不在保护范围内时跳过校验,找不到endpoint或对应视图时跳过校验,视图或其所在蓝图被豁免时也跳过校验;其余情况才会执行self.protect().

    csrf_protect 中会执行 protect ,protect 会执行 validate_csrf(),validate_csrf()是校验的关键,源代码如下:

    def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
        """Check if the given data is a valid CSRF token. This compares the given
        signed token to the one stored in the session.
    
        :param data: The signed CSRF token to be checked.
        :param secret_key: Used to securely sign the token. Default is
            ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
        :param time_limit: Number of seconds that the token is valid. Default is
            ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
        :param token_key: Key where token is stored in session for comparison.
            Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
    
        :raises ValidationError: Contains the reason that validation failed.
    
        .. versionchanged:: 0.14
            Raises ``ValidationError`` with a specific error message rather than
            returning ``True`` or ``False``.
        """
    
        # Resolve the three settings through ``_get_config``.
        # NOTE(review): ``_get_config`` is not visible here; presumably it
        # prefers the explicit argument, then the named app config key, then the
        # given default, and fails with ``message`` when a required value is
        # missing -- confirm against its definition.
        secret_key = _get_config(
            secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
            message='A secret key is required to use CSRF.'
        )
        field_name = _get_config(
            token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
            message='A field name is required to use CSRF.'
        )
        time_limit = _get_config(
            time_limit, 'WTF_CSRF_TIME_LIMIT', 3600, required=False
        )
    
        # Cheap rejections first: nothing was submitted by the client ...
        if not data:
            raise ValidationError('The CSRF token is missing.')
    
        # ... or the session holds no token to compare against.
        if field_name not in session:
            raise ValidationError('The CSRF session token is missing.')
    
        # Serializer used to decode the submitted token, keyed by the secret
        # and a fixed salt.
        s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')
    
        try:
            # Decode the submitted token, passing ``time_limit`` as its maximum
            # age.
            token = s.loads(data, max_age=time_limit)
        # NOTE(review): ``SignatureExpired`` is handled before ``BadData``; if it
        # is a subclass of ``BadData`` (as in itsdangerous) this order is what
        # keeps the "expired" message reachable -- do not swap the clauses.
        except SignatureExpired:
            raise ValidationError('The CSRF token has expired.')
        except BadData:
            raise ValidationError('The CSRF token is invalid.')
    
        # Final check: the decoded payload must equal the value stored in the
        # session under ``field_name``.
        # NOTE(review): ``safe_str_cmp`` is presumably a constant-time string
        # comparison -- confirm.
        if not safe_str_cmp(session[field_name], token):
            raise ValidationError('The CSRF tokens do not match.')

    该方法前面部分就是在获取相关秘钥和关键字,如果没有自定义配置的话,这一块通常不会出问题.后面可以看到,方法会先确认session中存在csrftoken字段,再用秘钥对提交上来的token验签并检查是否过期,最后一步才进行比对.所以,wtf是通过比对session中保存的csrftoken和请求(表单或请求头)中提交的csrftoken解签后的值是否一致来完成校验的.

     所以前后端分离方式开发的话,需要将csrftoken通过接口或者cookie的方式传给前端,前端将该部分数据取出保存,提交表单的时候带上.

    至于关键字,最上面那段代码写得很清楚:默认情况下,表单字段是csrf_token,请求头是X-CSRFToken(也接受X-CSRF-Token).

    
    
  • 相关阅读:
    Redis之七种武器
    Redis与Memcached的区别
    java优化占用内存的方法(一)
    Java内存区域与内存溢出异常(二)
    深入理解java垃圾回收机制
    从JAVA多线程理解到集群分布式和网络设计的浅析
    大型网站系统架构系列:分布式消息队列(一)
    大型网站系统架构系列:分布式消息队列(二)
    大型分布式网站架构技术总结
    40个Java多线程问题总结
  • 原文地址:https://www.cnblogs.com/haiton/p/11044219.html
Copyright © 2011-2022 走看看