zoukankan      html  css  js  c++  java
  • Redis 限流器的设计

    1. 定义注解

    import java.lang.annotation.Documented;
    import java.lang.annotation.ElementType;
    import java.lang.annotation.Repeatable;
    import java.lang.annotation.Retention;
    import java.lang.annotation.RetentionPolicy;
    import java.lang.annotation.Target;
    
    /**
     * Declares a Redis-backed rate limit on a method.
     *
     * <p>The annotation is repeatable; when it appears more than once on the same method the
     * occurrences are stored in the {@link RateLimits} container annotation.
     */
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @Target({ElementType.METHOD})
    @Repeatable(RateLimits.class)
    public @interface RedisRateLimitAttribute {

        /**
         * Alias for {@link #key()}.
         *
         * @return the alias value of the key
         */
        String value() default "";

        /**
         * Rate-limit key; SpEL expressions are supported.
         *
         * @return the key the counter is stored under
         */
        String key() default "";

        /**
         * Priority of this limiter relative to other limiters on the same method.
         *
         * @return the limiter priority
         */
        int order() default 0;

        /**
         * SpEL condition deciding whether a call is counted; the counter is incremented only when
         * the expression evaluates to true.
         *
         * @return the increment condition expression
         */
        String incrCondition() default "true";

        /**
         * Maximum allowed count; configuration placeholders are supported.
         *
         * @return the limit value
         */
        String limit() default "1";

        /**
         * Length of the rate-limit window; configuration placeholders are supported.
         * NOTE(review): the element name says milliseconds — confirm the consuming aspect applies
         * the value in that unit.
         *
         * @return the window length
         */
        String intervalInMilliseconds() default "1000";

        /**
         * Name of the fallback method. Its parameters are either identical to the annotated
         * method's, or the same list followed by one extra parameter of the annotated method's
         * return type.
         *
         * @return the fallback method name
         */
        String fallbackMethod() default "";
    }
    import java.lang.annotation.Documented;
    import java.lang.annotation.ElementType;
    import java.lang.annotation.Retention;
    import java.lang.annotation.RetentionPolicy;
    import java.lang.annotation.Target;
    
    /**
     * Container annotation that holds repeated {@link RedisRateLimitAttribute} occurrences.
     *
     * <p>Java stores two or more uses of a repeatable annotation inside its container, so this
     * type is what actually appears on a method that declares several rate limits.
     */
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    public @interface RateLimits {

        /**
         * The contained rate-limit declarations.
         *
         * @return the list of annotations
         */
        RedisRateLimitAttribute[] value() default {};
    }

    2. 切面方法

    import com.google.common.base.Strings;
    import org.aspectj.lang.JoinPoint;
    import org.aspectj.lang.ProceedingJoinPoint;
    import org.aspectj.lang.annotation.Around;
    import org.aspectj.lang.annotation.Aspect;
    import org.aspectj.lang.annotation.Pointcut;
    import org.aspectj.lang.reflect.MethodSignature;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.context.annotation.EnableAspectJAutoProxy;
    import org.springframework.core.DefaultParameterNameDiscoverer;
    import org.springframework.core.ParameterNameDiscoverer;
    import org.springframework.core.annotation.AnnotatedElementUtils;
    import org.springframework.core.annotation.Order;
    import org.springframework.core.env.Environment;
    import org.springframework.expression.EvaluationContext;
    import org.springframework.expression.ExpressionParser;
    import org.springframework.expression.spel.standard.SpelExpressionParser;
    import org.springframework.expression.spel.support.StandardEvaluationContext;
    import org.springframework.stereotype.Component;
    import org.springframework.util.CollectionUtils;
    import java.lang.reflect.InvocationTargetException;
    import java.lang.reflect.Method;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.Comparator;
    import java.util.List;
    import java.util.Optional;
    import java.util.stream.Collectors;
    import redis.clients.jedis.JedisCluster;
    
    // Enables AspectJ auto-proxying. proxyTargetClass = true requests class-based proxies;
    // when it is not specified it defaults to false.
    @EnableAspectJAutoProxy(proxyTargetClass = true)
    @Component
    @Order(-1)
    @Aspect
    public class RedisRateLimitAspect {
        /**
         * Logger.
         */
        private static final Logger logger = LoggerFactory.getLogger(RedisRateLimitAspect.class);

        /**
         * SpEL expression parser shared by all evaluations.
         */
        private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

        /**
         * Discovers method parameter names so they can be exposed as SpEL variables.
         */
        private static final ParameterNameDiscoverer PARAMETER_NAME_DISCOVERER = new DefaultParameterNameDiscoverer();

        /**
         * Lua script: increments the counter at KEYS[1] and, when this increment created the key
         * (result 1), sets its TTL to ARGV[1] milliseconds. Executing both commands in one script
         * makes increment-and-expire atomic. The former GET / SET NX EX / INCRBY sequence lost
         * concurrent first hits, and a key expiring between GET and INCRBY was recreated by INCRBY
         * without any TTL, locking the key out permanently.
         */
        private static final String INCR_WITH_TTL_SCRIPT =
                "local current = redis.call('incr', KEYS[1])\n"
                        + "if current == 1 then\n"
                        + "    redis.call('pexpire', KEYS[1], ARGV[1])\n"
                        + "end\n"
                        + "return current";

        /**
         * Redis cluster client.
         */
        @Autowired
        private JedisCluster jedisCluster;

        /**
         * Spring environment, used to resolve ${...} placeholders in annotation attributes.
         */
        @Autowired
        private Environment environment;

        /**
         * Pointcut: methods carrying {@code @RedisRateLimitAttribute} once, or several times.
         * When the repeatable annotation is used more than once, the compiler stores the
         * occurrences inside the {@code @RateLimits} container (same package), so the container
         * must be matched as well or multi-limit methods would never be intercepted.
         */
        @Pointcut("@annotation(com.g2.order.server.annotation.RedisRateLimitAttribute)"
                + " || @annotation(com.g2.order.server.annotation.RateLimits)")
        public void rateLimit() {

        }

        /**
         * Around advice applying every rate limit declared on the intercepted method.
         *
         * <p>Pre-check: for each limit (ascending {@code order()}), if the stored count already
         * exceeds the limit, the fallback is invoked (with {@code null} as the trailing
         * return-value argument when the fallback declares one) and the method is not executed.
         * A count equal to the limit still passes, so the call that pushes the count past the
         * limit executes and then receives the fallback in the post-check.
         *
         * <p>Post-check: after the method runs, each limit whose {@code incrCondition} is true
         * increments its counter; if the new count exceeds the limit, the fallback is invoked
         * with the method's result as the trailing argument (when the fallback declares one).
         *
         * @param proceedingJoinPoint the intercepted invocation
         * @return the method result, or the fallback's result when a limit is exceeded
         * @throws Throwable whatever the method or the fallback throws
         */
        @Around("rateLimit()")
        public Object handleControllerMethod(ProceedingJoinPoint proceedingJoinPoint) throws Throwable {
            MethodSignature methodSignature = (MethodSignature) proceedingJoinPoint.getSignature();
            Method method = methodSignature.getMethod();

            // Single and repeated annotations alike, evaluated in ascending order().
            List<RedisRateLimitAttribute> redisRateLimitAttributes =
                    AnnotatedElementUtils.findMergedRepeatableAnnotations(method, RedisRateLimitAttribute.class)
                            .stream()
                            .sorted(Comparator.comparing(RedisRateLimitAttribute::order))
                            .collect(Collectors.collectingAndThen(Collectors.toList(), Collections::unmodifiableList));

            if (CollectionUtils.isEmpty(redisRateLimitAttributes)) {
                return proceedingJoinPoint.proceed();
            }

            // The target instance and the arguments are needed to invoke a fallback method.
            Object target = proceedingJoinPoint.getTarget();
            Object[] args = Optional.ofNullable(proceedingJoinPoint.getArgs()).orElse(new Object[0]);

            // Pre-check: reject before executing when a counter is already over its limit.
            for (RedisRateLimitAttribute rateLimit : redisRateLimitAttributes) {
                String key = resolveKey(rateLimit, proceedingJoinPoint, null);
                long limitV = Long.parseLong(formatKey(rateLimit.limit()));
                long currentV = readCount(key);
                if (currentV > limitV) {
                    logger.info("rate limit exceeded before invocation, key: {}, current: {}, limit: {}",
                            key, currentV, limitV);
                    Method fallbackMethod = getFallbackMethod(proceedingJoinPoint, rateLimit.fallbackMethod());
                    return invokeFallbackMethod(fallbackMethod, target,
                            buildFallbackArgs(fallbackMethod, args, null));
                }
            }

            Object result = proceedingJoinPoint.proceed();

            // Post-check: count this call where the condition holds and test the new count.
            for (RedisRateLimitAttribute rateLimit : redisRateLimitAttributes) {
                String key = resolveKey(rateLimit, proceedingJoinPoint, result);
                long limitV = Long.parseLong(formatKey(rateLimit.limit()));
                long interval = Long.parseLong(formatKey(rateLimit.intervalInMilliseconds()));
                if (!match(proceedingJoinPoint, rateLimit, result)) {
                    continue;
                }
                long countAfterIncr = incrementCount(key, interval);
                if (countAfterIncr > limitV) {
                    logger.info("rate limit exceeded after invocation, key: {}, current: {}, limit: {}",
                            key, countAfterIncr, limitV);
                    Method fallbackMethod = getFallbackMethod(proceedingJoinPoint, rateLimit.fallbackMethod());
                    // Only append the result when the fallback actually declares the extra
                    // parameter; otherwise Method.invoke fails with a wrong argument count.
                    return invokeFallbackMethod(fallbackMethod, target,
                            buildFallbackArgs(fallbackMethod, args, result));
                }
            }

            return result;
        }

        /**
         * Resolves the Redis key for a limit. It uses {@code key()}, or {@code value()} when
         * {@code key()} is empty (the documented alias), then resolves placeholders and evaluates
         * SpEL.
         *
         * @param rateLimit   the limit declaration
         * @param context     the join point supplying argument variables
         * @param returnValue the method result, or null before invocation
         * @return the evaluated key
         * @throws IllegalStateException if the key is not configured or evaluates to null
         */
        private String resolveKey(RedisRateLimitAttribute rateLimit, JoinPoint context, Object returnValue) {
            String rawKey = Strings.isNullOrEmpty(rateLimit.key()) ? rateLimit.value() : rateLimit.key();
            String key = computeExpress(formatKey(rawKey), context, String.class, returnValue);
            if (key == null) {
                throw new IllegalStateException("rate limit key evaluated to null, expression: " + rawKey);
            }
            return key;
        }

        /**
         * Reads the current counter value; an absent key counts as 0.
         *
         * @param key the Redis key
         * @return the stored count
         */
        private long readCount(String key) {
            String currentValue = jedisCluster.get(key);
            return Strings.isNullOrEmpty(currentValue) ? 0L : Long.parseLong(currentValue);
        }

        /**
         * Atomically increments the counter and sets its TTL (milliseconds) when the key is new.
         *
         * @param key                    the Redis key
         * @param intervalInMilliseconds window length applied when the key is created
         * @return the count after the increment
         */
        private long incrementCount(String key, long intervalInMilliseconds) {
            Object count = jedisCluster.eval(INCR_WITH_TTL_SCRIPT, 1, key, String.valueOf(intervalInMilliseconds));
            return ((Number) count).longValue();
        }

        /**
         * Builds the fallback argument list: the original arguments, plus {@code returnValue} when
         * the fallback declares one more parameter than there are arguments.
         */
        private static Object[] buildFallbackArgs(Method fallbackMethod, Object[] args, Object returnValue) {
            if (fallbackMethod.getParameterCount() == args.length) {
                return args;
            }
            Object[] extended = Arrays.copyOf(args, args.length + 1);
            extended[args.length] = returnValue;
            return extended;
        }

        /**
         * Evaluates a SpEL expression. Method parameters are visible as {@code #name}, and the
         * return value as {@code #returnValue} / {@code #response}.
         *
         * @param expression  the expression
         * @param context     the join point supplying argument variables
         * @param tClass      the desired result type
         * @param returnValue the method result, or null
         * @return the evaluated value
         */
        private <T> T computeExpress(String expression, JoinPoint context, Class<T> tClass, Object returnValue) {
            EvaluationContext evaluationContext = buildEvaluationContext(returnValue, context);
            return EXPRESSION_PARSER.parseExpression(expression).getValue(evaluationContext, tClass);
        }

        /**
         * Evaluates the limit's {@code incrCondition}; a null result counts as false.
         *
         * @param context     the join point
         * @param rateLimit   the limit declaration
         * @param returnValue the method result
         * @return true if this call should be counted
         */
        private boolean match(JoinPoint context, RedisRateLimitAttribute rateLimit, Object returnValue) {
            return Boolean.TRUE.equals(computeExpress(rateLimit.incrCondition(), context, Boolean.class, returnValue));
        }

        /**
         * Resolves ${...} placeholders in an annotation attribute.
         *
         * @param v the raw attribute value
         * @return the resolved value
         * @throws IllegalStateException if the value is null or empty
         */
        private String formatKey(String v) {
            if (Strings.isNullOrEmpty(v)) {
                throw new IllegalStateException("key配置不能为空");
            }
            return environment.resolvePlaceholders(v);
        }

        /**
         * Registers each method parameter as a SpEL variable named after the parameter.
         */
        private static void addParameterVariable(StandardEvaluationContext evaluationContext, JoinPoint context) {
            MethodSignature methodSignature = (MethodSignature) context.getSignature();
            Method method = methodSignature.getMethod();
            String[] parameterNames = PARAMETER_NAME_DISCOVERER.getParameterNames(method);
            if (parameterNames != null && parameterNames.length > 0) {
                Object[] args = context.getArgs();
                for (int i = 0; i < parameterNames.length; i++) {
                    evaluationContext.setVariable(parameterNames[i], args[i]);
                }
            }
        }

        /**
         * Registers the return value under both {@code returnValue} and {@code response}.
         */
        private static void addReturnValue(StandardEvaluationContext evaluationContext, Object returnValue) {
            evaluationContext.setVariable("returnValue", returnValue);
            evaluationContext.setVariable("response", returnValue);
        }

        /**
         * Builds a StandardEvaluationContext holding the parameter and return-value variables.
         */
        private static EvaluationContext buildEvaluationContext(Object returnValue, JoinPoint context) {
            StandardEvaluationContext evaluationContext = new StandardEvaluationContext();
            addParameterVariable(evaluationContext, context);
            addReturnValue(evaluationContext, returnValue);
            return evaluationContext;
        }

        /**
         * Looks up the fallback method declared on the target's class. First it tries the
         * intercepted method's parameter list, then that list plus a trailing parameter of the
         * intercepted method's return type.
         *
         * @param context        the join point
         * @param fallbackMethod the fallback method name
         * @return the accessible fallback method
         * @throws RuntimeException if neither signature is declared
         */
        private static Method getFallbackMethod(JoinPoint context, String fallbackMethod) {
            MethodSignature methodSignature = (MethodSignature) context.getSignature();
            Class<?>[] parameterTypes = methodSignature.getParameterTypes();
            if (parameterTypes == null) {
                parameterTypes = new Class<?>[0];
            }
            Class<?> targetClass = context.getTarget().getClass();

            try {
                Method method = targetClass.getDeclaredMethod(fallbackMethod, parameterTypes);
                method.setAccessible(true);
                return method;
            } catch (NoSuchMethodException ignored) {
                // Not declared with the same parameters; try the variant with a return-value parameter.
            }

            try {
                Class<?>[] parameterTypes2 = Arrays.copyOf(parameterTypes, parameterTypes.length + 1);
                parameterTypes2[parameterTypes.length] = methodSignature.getReturnType();
                Method method = targetClass.getDeclaredMethod(fallbackMethod, parameterTypes2);
                method.setAccessible(true);
                return method;
            } catch (NoSuchMethodException ignored) {
                // Neither signature exists; report below.
            }

            String message = String.format("获取fallbackMethod失败, context: %s, fallbackMethod: %s",
                    context, fallbackMethod);
            throw new RuntimeException(message);
        }

        /**
         * Invokes the fallback method. Exceptions thrown inside it are unwrapped and rethrown.
         *
         * @param fallbackMethod the fallback method
         * @param fallbackTarget the instance to invoke it on
         * @param fallbackArgs   the arguments
         * @return the fallback's return value
         * @throws Throwable the fallback's own exception, or the reflection failure
         */
        private static Object invokeFallbackMethod(Method fallbackMethod, Object fallbackTarget, Object[] fallbackArgs)
                throws Throwable {
            try {
                return fallbackMethod.invoke(fallbackTarget, fallbackArgs);
            } catch (InvocationTargetException e) {
                if (e.getCause() != null) {
                    throw e.getCause();
                }
                throw e;
            }
        }
    }

    3. 调用示例

    @Slf4j
    @Api(value = "HomeController", description = "用户登录登出接口")
    @RestController
    @RequestMapping("/home")
    public class HomeController {

        /**
         * Login endpoint, rate limited per user id.
         *
         * <p>The limit is read from {@code login.maxFailedTimes}, and {@code UserModel} documents
         * that a verification code is required once the error count reaches the threshold. The
         * increment condition therefore counts unsuccessful responses. The previous condition,
         * {@code #response.success == true}, counted successful logins instead.
         *
         * @param req the login request
         * @return the login response
         */
        @ApiOperation(value = "用户登录", notes = "用户登录接口")
        @RequestMapping(value = "/login",
                method = RequestMethod.POST,
                consumes = MediaType.APPLICATION_JSON_VALUE,
                produces = MediaType.APPLICATION_JSON_VALUE)
        @ResponseBody
        @RedisRateLimitAttribute(key = "'login'+#req.userId"
                , limit = "${login.maxFailedTimes:3}"
                , incrCondition = "#response.success == false"
                , intervalInMilliseconds = "${login.limit.millseconds:3600}"
                , fallbackMethod = "loginFallback"
        )
        public UserLoginResp login(@RequestBody UserLoginReq req) {
            // Uses the logger generated by @Slf4j instead of a duplicate hand-declared one.
            log.info("进入登陆业务");

            UserModel userModel = new UserModel();
            userModel.setRoleId(123);
            userModel.setUserId(req.getUserId());
            userModel.setMustValidateCode(false);

            return new UserLoginResp(userModel);
        }

        /**
         * Fallback for {@link #login}. It receives the original request plus the login response,
         * which is null when the limit was already exceeded and login was not executed.
         *
         * @param req  the login request
         * @param resp the login response, or null
         * @return an empty response when {@code resp} is null; otherwise {@code resp} with
         *         {@code mustValidateCode} set when a payload is present
         */
        private UserLoginResp loginFallback(UserLoginReq req, UserLoginResp resp) {
            if (resp == null) {
                return new UserLoginResp();
            }
            // Guard against a response built without a payload.
            if (resp.getPayload() != null) {
                resp.getPayload().setMustValidateCode(true);
            }
            return resp;
        }
    }
    @Data
    public class UserModel {
        /**
         * User id.
         */
        private String userId;
    
        /**
         * Role name.
         */
        private String roleName;
    
        /**
         * Role id.
         */
        private Integer roleId;
    
        /**
         * Whether login must present a verification code.
         * Once the error count reaches the threshold, a captcha is required to make submissions harder.
         */
        private Boolean mustValidateCode;
    }
    import lombok.Data;
    
    /**
     * Generic API response envelope: a success flag, an error message and a typed payload.
     *
     * @param <T> payload type
     */
    @Data
    public class Response<T> {
        private Boolean success;
        private String errorMessage;
        private T payload;

        /** Successful response with an empty message and no payload. */
        public Response() {
            this(true, "", null);
        }

        /** Response with the given success flag, an empty message and no payload. */
        public Response(boolean succ) {
            this(succ, "", null);
        }

        /** Response with the given success flag and message, and no payload. */
        public Response(boolean succ, String msg) {
            this(succ, msg, null);
        }

        /** Successful response carrying {@code data} with an empty message. */
        public Response(T data) {
            this(true, "", data);
        }

        /** Canonical constructor; all other constructors delegate here. */
        public Response(boolean succ, String msg, T data) {
            this.success = succ;
            this.errorMessage = msg;
            this.payload = data;
        }
    }
    /**
     * Login response whose payload is a {@link UserModel}.
     */
    public class UserLoginResp extends Response<UserModel> {

        /** Successful response without a payload. */
        public UserLoginResp() {
            super();
        }

        /** Successful response carrying the given user. */
        public UserLoginResp(UserModel userModel) {
            super(userModel);
        }

        /** Delegates to the inherited representation. */
        @Override
        public String toString() {
            return super.toString();
        }
    }
  • 相关阅读:
    对《软件工程》这门课的总结
    结对编程项目---四则运算
    PSP记录个人项目耗时情况
    代码复审
    是否需要有代码规范
    四则运算的实现(C++)重做
    四则运算器的实现
    学习进度总结
    通过阅读教材,所得的不懂的问题
    自我介绍
  • 原文地址:https://www.cnblogs.com/zhshlimi/p/11835401.html
Copyright © 2011-2022 走看看