zoukankan      html  css  js  c++  java
  • Spring AOP + Redis 实现针对用户的接口访问频率限制

    根据请求参数或请求头中的字段识别调用方，并在固定时间窗口内限制其对接口的访问次数

    限制类型:

    package com.seliote.fr.config.api;
    
    /**
     * Tells the frequency-limit aspect where the values identifying a caller are taken from.
     *
     * @author seliote
     */
    public enum ApiFrequencyType {
        /** The request argument; the annotated method must then declare exactly one parameter. */
        ARG,
        /** One or more request headers. */
        HEADER,
    }
    

    实际使用的限流注解：

    package com.seliote.fr.annotation;
    
    import com.seliote.fr.config.api.ApiFrequencyType;
    import com.seliote.fr.config.auth.TokenFilter;
    
    import java.lang.annotation.*;
    import java.time.temporal.ChronoUnit;
    
    /**
     * Marks a controller method for API call frequency limiting; the check itself is performed by
     * the {@code com.seliote.fr.config.api.ApiFrequency} aspect.
     * <p>
     * With {@link ApiFrequencyType#ARG} the method must take exactly one argument, whose properties
     * named by {@link #key()} identify the caller; with {@link ApiFrequencyType#HEADER} the named
     * request headers identify the caller. Defaults: the token header, at most 5 calls per minute.
     *
     * @author seliote
     */
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    public @interface ApiFrequency {

        /** Source of the identifying values: the request argument or request headers. */
        ApiFrequencyType type() default ApiFrequencyType.HEADER;

        /** Name(s) of the identifying properties/headers; join several names with {@code &&}. */
        String key() default TokenFilter.TOKEN_HEADER;

        /** Maximum number of calls allowed within one time window. */
        int frequency() default 5;

        /** Length of the time window, expressed in {@link #unit()}. */
        long time() default 1;

        /** Unit in which {@link #time()} is expressed. */
        ChronoUnit unit() default ChronoUnit.MINUTES;
    }
    

    切面代码:

    package com.seliote.fr.config.api;
    
    import com.seliote.fr.annotation.stereotype.ApiComponent;
    import com.seliote.fr.exception.FrequencyException;
    import com.seliote.fr.service.RedisService;
    import com.seliote.fr.util.CommonUtils;
    import com.seliote.fr.util.TextUtils;
    import lombok.extern.log4j.Log4j2;
    import org.aspectj.lang.JoinPoint;
    import org.aspectj.lang.annotation.Aspect;
    import org.aspectj.lang.annotation.Before;
    import org.aspectj.lang.reflect.MethodSignature;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.core.annotation.Order;
    import org.springframework.web.context.request.RequestContextHolder;
    import org.springframework.web.context.request.ServletRequestAttributes;
    
    import java.beans.IntrospectionException;
    import java.beans.PropertyDescriptor;
    import java.lang.reflect.InvocationTargetException;
    import java.lang.reflect.Method;
    import java.time.Instant;
    import java.util.Locale;
    import java.util.Optional;
    
    import static com.seliote.fr.util.ReflectUtils.getClassName;
    
    /**
     * API 调用频率限制 AOP
     *
     * @author seliote
     */
    @Log4j2
    @Order(1)
    @ApiComponent
    @Aspect
    public class ApiFrequency {
    
        private static final String redisNameSpace = "frequency";
    
        private final RedisService redisService;
    
        @Autowired
        public ApiFrequency(RedisService redisService) {
            this.redisService = redisService;
            log.debug("Construct {}", getClassName(this));
        }
    
        /**
         * API 调用频率限制
         *
         * @param joinPoint AOP JoinPoint 对象
         */
        @Before("com.seliote.fr.config.api.ApiAop.api() && @annotation(com.seliote.fr.annotation.ApiFrequency)")
        public void apiFrequency(JoinPoint joinPoint) {
            Optional<String> uri = CommonUtils.getUri();
            if (uri.isEmpty()) {
                log.error("Frequency check error, uri is null");
                throw new FrequencyException("URI is empty");
            }
            var method = ((MethodSignature) joinPoint.getSignature()).getMethod();
            var annotation = method.getAnnotation(com.seliote.fr.annotation.ApiFrequency.class);
            var identifier = (annotation.type() == ApiFrequencyType.ARG ?
                    getArg(uri.get(), joinPoint, annotation) : getHeader(uri.get(), annotation));
            if (identifier.isEmpty()) {
                log.error("Frequency check error, identifier is empty for: {}", uri.get());
                // 过滤器在切面前执行,如果获取到为空说明代码有问题
                throw new FrequencyException("Identifier is empty");
            }
            // Token 或者参数可能会很长,所以 SHA-1 一下
            var sha1 = TextUtils.sha1(identifier.get());
            // 单位时长
            var unitSeconds = CommonUtils.time2Seconds(annotation.time(), annotation.unit());
            var redisKey = getRedisKey(uri.get(), sha1, unitSeconds);
            var current = frequency(redisKey, unitSeconds);
            var frequency = annotation.frequency();
            if (current <= frequency) {
                log.debug("Frequency pass for: {}, current: {}, identifier: {}, sha1: {}",
                        uri.get(), current, identifier.get(), sha1);
            } else {
                log.warn("Frequency too high: {}, current: {}, identifier: {}, sha1: {}",
                        uri.get(), current, identifier.get(), sha1);
                throw new FrequencyException("Frequency too high");
            }
        }
    
        /**
         * 获取请求参数中的标识符
         *
         * @param joinPoint  JoinPoint 对象注解
         * @param annotation @ApiFrequency 对象
         * @return 请求参数中的标识符
         */
        private Optional<String> getArg(String uri, JoinPoint joinPoint, com.seliote.fr.annotation.ApiFrequency annotation) {
            var args = joinPoint.getArgs();
            if (args == null || args.length != 1) {
                log.error("Args length incorrect: {}, {}", uri, args);
                return Optional.empty();
            }
            var arg = args[0];
            var keys = annotation.key().split(getKeySeparator());
            var identifiers = new String[keys.length];
            for (var i = 0; i < keys.length; ++i) {
                try {
                    final var pd = new PropertyDescriptor(keys[i], arg.getClass());
                    var result = pd.getReadMethod().invoke(arg);
                    if (result == null) {
                        log.error("Argument getter return null: {}, argument: {}, getter: {}", uri, arg, keys[i]);
                        throw new FrequencyException("Argument getter return null");
                    }
                    identifiers[i] = result.toString();
                } catch (IntrospectionException | IllegalAccessException | InvocationTargetException exception) {
                    log.error("Get identifier args error: {}, {}, exception: {}, message: {}, exception at: {}",
                            uri, arg, getClassName(exception), exception.getMessage(), keys[i]);
                    return Optional.empty();
                }
            }
            return Optional.of(String.join(getIdentifierSeparator(), identifiers));
        }
    
        /**
         * 获取请求头中的标识符
         *
         * @param annotation @ApiFrequency 对象
         * @return 请求头中的标识符
         */
        private Optional<String> getHeader(String uri, com.seliote.fr.annotation.ApiFrequency annotation) {
            var keys = annotation.key().split(getKeySeparator());
            var identifiers = new String[keys.length];
            var servletAttr = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
            if (servletAttr != null) {
                var httpAttr = servletAttr.getRequest();
                for (var i = 0; i < keys.length; ++i) {
                    var header = httpAttr.getHeader(keys[i]);
                    if (header == null || header.length() <= 0) {
                        log.error("Header is null for: {}, header: {}", uri, keys[i]);
                        throw new FrequencyException("Get header return null");
                    }
                    identifiers[i] = header;
                }
            }
            return Optional.of(String.join(getIdentifierSeparator(), identifiers));
        }
    
        /**
         * 获取 Redis 中存储的 Key
         *
         * @param uri        请求 URI
         * @param identifier 本次请求的标识符
         * @param seconds    单位时间秒数
         * @return Redis 中存储的 Key
         */
        private String getRedisKey(String uri, String identifier, long seconds) {
            var now = Instant.now().getEpochSecond();
            // 计算单位起始时间,单位时间内访问次数限制
            var start = (now - (now % seconds)) + "";
            return redisService.formatKey(redisNameSpace, uri, identifier, start);
        }
    
        /**
         * 增加本次的频率并获取单位时间内的访问次数
         *
         * @param key     Redis Key
         * @param seconds 单位时间秒数
         * @return 访问次数
         */
        private long frequency(String key, long seconds) {
            if (!redisService.exists(key)) {
                redisService.setex(key, (int) seconds, "0");
            }
            return redisService.incr(key);
        }
    
        /**
         * 获取 Redis Key 的分隔符
         *
         * @return 分隔符
         */
        private String getKeySeparator() {
            return "&&";
        }
    
        /**
         * 获取标识符间的分隔符
         *
         * @return 分隔符
         */
        private String getIdentifierSeparator() {
            return ".";
        }
    }
    

    实际使用的例子:

    package com.seliote.fr.controller;
    
    import com.seliote.fr.annotation.ApiFrequency;
    import com.seliote.fr.annotation.stereotype.ApiController;
    import com.seliote.fr.config.api.ApiFrequencyType;
    import com.seliote.fr.exception.ApiException;
    import com.seliote.fr.model.ci.user.LoginCi;
    import com.seliote.fr.model.co.Co;
    import com.seliote.fr.model.co.user.LoginCo;
    import com.seliote.fr.model.si.user.LoginSi;
    import com.seliote.fr.service.UserService;
    import com.seliote.fr.util.ReflectUtils;
    import lombok.extern.log4j.Log4j2;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.web.bind.annotation.RequestBody;
    import org.springframework.web.bind.annotation.RequestMapping;
    import org.springframework.web.bind.annotation.RequestMethod;
    import org.springframework.web.bind.annotation.ResponseBody;
    
    import javax.validation.Valid;
    
    import static com.seliote.fr.util.ReflectUtils.getClassName;
    
    /**
     * 用户帐户 Controller
     *
     * @author seliote
     */
    @Log4j2
    @ApiController
    @RequestMapping(value = "user", method = {RequestMethod.POST})
    public class UserController {
    
        private final UserService userService;
    
        @Autowired
        public UserController(UserService userService) {
            this.userService = userService;
            log.debug("Construct {}", getClassName(this));
        }
    
        /**
         * 登录用户帐户,未注册的账户将会自动注册
         *
         * @param ci CI
         * @return CO
         */
        @ApiFrequency(type = ApiFrequencyType.ARG, key = "countryCode&&telNo")
        @RequestMapping("login")
        @ResponseBody
        public Co<LoginCo> login(@Valid @RequestBody LoginCi ci) {
            var so = userService.login(ReflectUtils.copy(ci, LoginSi.class));
            if (so.getLoginResult() == 0 || so.getLoginResult() == 1) {
                return Co.cco(ReflectUtils.copy(so, LoginCo.class));
            } else {
                log.error("login for: {}, service return: {}", ci, so);
                throw new ApiException("service return value error");
            }
        }
    
        @ApiFrequency()
        @RequestMapping("info")
        @ResponseBody
        public Co<Void> info() {
            return Co.cco(null);
        }
    }
    
  • 相关阅读:
    为什么WinCE中LoadBitmap加载位图后无法在其上DrawText?
    WinCE中加载位图的方法
    wince5+2440如何支持SDHC?
    WinCE中文字库占了这么多空间?
    【转】用MFC构造DIRECTX应用框架
    全局导出
    模板绑定
    筛选DOM元素
    获取当前所有的属性
    Canvas绘图(二)
  • 原文地址:https://www.cnblogs.com/seliote/p/14458006.html
Copyright © 2011-2022 走看看