1. 定义注解
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Repeatable;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Declares a Redis-backed rate limit on a method.
 *
 * <p>The annotation is repeatable: several limits can be stacked on one method, in which case the
 * compiler stores them inside the {@link RateLimits} container annotation.
 */
@Documented
@Repeatable(RateLimits.class)
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface RedisRateLimitAttribute {

    /**
     * Intended as an alias for {@link #key()}; whether it is honored depends on the aspect that
     * processes this annotation.
     *
     * @return the alias value of {@link #key()}
     */
    String value() default "";

    /**
     * Name of the Redis counter for this limit, written as a SpEL expression.
     *
     * @return the rate-limit key expression
     */
    String key() default "";

    /**
     * Priority of this limiter relative to other limiters on the same method.
     *
     * @return the limiter's priority
     */
    int order() default 0;

    /**
     * SpEL condition deciding whether a call is counted; the counter is only incremented when the
     * expression evaluates to {@code true}. The default counts every call.
     *
     * @return the increment condition expression
     */
    String incrCondition() default "true";

    /**
     * Maximum count allowed within one window; may reference configuration properties.
     *
     * @return the limit value
     */
    String limit() default "1";

    /**
     * Length of the counting window; may reference configuration properties.
     * NOTE(review): the name states milliseconds — make sure the enforcing aspect applies it in that unit.
     *
     * @return the window length
     */
    String intervalInMilliseconds() default "1000";

    /**
     * Name of the fallback method used when the limit is exceeded. Its parameters must match the
     * annotated method's, optionally followed by one extra parameter of the method's return type.
     *
     * @return the fallback method name
     */
    String fallbackMethod() default "";
}
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Container annotation that holds repeated {@link RedisRateLimitAttribute} declarations on one method.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimits {

    /**
     * The stacked rate-limit declarations.
     *
     * @return the contained annotations
     */
    RedisRateLimitAttribute[] value() default {};
}
2. 切面方法
import com.google.common.base.Strings; import org.aspectj.lang.JoinPoint; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Pointcut; import org.aspectj.lang.reflect.MethodSignature; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.EnableAspectJAutoProxy; import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.ParameterNameDiscoverer; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.Order; import org.springframework.core.env.Environment; import org.springframework.expression.EvaluationContext; import org.springframework.expression.ExpressionParser; import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.expression.spel.support.StandardEvaluationContext; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; import redis.clients.jedis.JedisCluster; //开启AspectJ 自动代理模式,如果不填proxyTargetClass=true,默认为false, @EnableAspectJAutoProxy(proxyTargetClass = true) @Component @Order(-1) @Aspect public class RedisRateLimitAspect { /** * 日志 */ private static Logger logger = LoggerFactory.getLogger(RedisRateLimitAspect.class); /** * SPEL表达式解析器 */ private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser(); /** * 获取方法参数名称发现器 */ private static final ParameterNameDiscoverer PARAMETER_NAME_DISCOVERER = new DefaultParameterNameDiscoverer(); /** * Redis集群 */ @Autowired private JedisCluster 
jedisCluster; /** * springboot自动加载配置信息 */ @Autowired private Environment environment; /** * 切面切入点 */ @Pointcut("@annotation(com.g2.order.server.annotation.RedisRateLimitAttribute)") public void rateLimit() { } /** * 环绕切面 */ @Around("rateLimit()") public Object handleControllerMethod(ProceedingJoinPoint proceedingJoinPoint) throws Throwable { //获取切入点对应的方法. MethodSignature methodSignature = (MethodSignature) proceedingJoinPoint.getSignature(); Method method = methodSignature.getMethod(); //获取注解列表 List<RedisRateLimitAttribute> redisRateLimitAttributes = AnnotatedElementUtils.findMergedRepeatableAnnotations(method, RedisRateLimitAttribute.class) .stream() .sorted(Comparator.comparing(RedisRateLimitAttribute::order)) .collect(Collectors.collectingAndThen(Collectors.toList(), Collections::unmodifiableList)); if (CollectionUtils.isEmpty(redisRateLimitAttributes)) { return proceedingJoinPoint.proceed(); } // 切入点所在的实例,调用fallback方法时需要 Object target = proceedingJoinPoint.getTarget(); // 方法入参集合,调用fallback方法时需要 Object[] args = proceedingJoinPoint.getArgs(); if (args == null) { args = new Object[0]; } // 前置检查 for (RedisRateLimitAttribute rateLimit : redisRateLimitAttributes) { // 获取限流设置的key(可能有配置占位符和spel表达式) String key = computeExpress(formatKey(rateLimit.key()), proceedingJoinPoint, String.class); // 获取限流配置的阀值 long limitV = Long.parseLong(formatKey(rateLimit.limit())); // 获取当前key已记录的值 String currentValue = jedisCluster.get(key); long currentV = Strings.isNullOrEmpty(currentValue) ? 
0 : Long.parseLong(jedisCluster.get(key)); // 当前值如果小于等于阀值,则合法;否则不合法 boolean validated = currentV <= limitV; // 如果不合法则进入fallback流程 if (!validated) { // 获取当前限流配置的fallback Method fallbackMethod = getFallbackMethod(proceedingJoinPoint, rateLimit.fallbackMethod()); // 如果fallback参数数量与切入点参数数量不一样,则压入空的返回值 if (fallbackMethod.getParameterCount() != method.getParameterCount()) { Object[] args2 = Arrays.copyOf(args, args.length + 1); args2[args2.length - 1] = null; return invokeFallbackMethod(fallbackMethod, target, args2); } return invokeFallbackMethod(fallbackMethod, target, args); } } // 前置检查通过后,执行方法体 Object result = proceedingJoinPoint.proceed(); // 后置检查 for (RedisRateLimitAttribute rateLimit : redisRateLimitAttributes) { // 获取限流设置的key(可能有配置占位符和spel表达式) String key = computeExpress(formatKey(rateLimit.key()), proceedingJoinPoint, String.class, result); // 获取限流配置的阀值 long limitV = Long.parseLong(formatKey(rateLimit.limit())); // 获取限流配置的限流区间 long interval = Long.parseLong(formatKey(rateLimit.intervalInMilliseconds())); boolean validated = true; // 计算当前一次执行后是否满足限流条件 boolean incrMatch = match(proceedingJoinPoint, rateLimit, result); if (incrMatch) { // 如果不存在key,则设置该key,并且超时时间为限流区间值 // 获取当前key已记录的值 String currentValue = jedisCluster.get(key); // TODO 这里最好修改成 lua脚本来实现原子性 long currentV = Strings.isNullOrEmpty(currentValue) ? 
0 : Long.parseLong(jedisCluster.get(key)); if (currentV == 0) { jedisCluster.set(key, "1", "nx", "ex", interval); } else { jedisCluster.incrBy(key, 1); } validated = currentV +1 <= limitV; } if (!validated) { // 获取fallback方法 // TODO 这里可以修改为已获取的话Map里,下次不需要再调用getFallbackMethod方法了 Method fallbackMethod = getFallbackMethod(proceedingJoinPoint, rateLimit.fallbackMethod()); Object[] args2 = Arrays.copyOf(args, args.length + 1); args2[args2.length - 1] = result; return invokeFallbackMethod(fallbackMethod, target, args2); } } return result; } /** * 计算spel表达式 * * @param expression 表达式 * @param context 上下文 * @return String的缓存key */ private <T> T computeExpress(String expression, JoinPoint context, Class<T> tClass) { // 计算表达式(根据参数上下文) return computeExpress(expression, context, tClass, null); } /** * 计算spel表达式 * * @param expression 表达式 * @param context 上下文 * @return String的缓存key */ private <T> T computeExpress(String expression, JoinPoint context, Class<T> tClass, Object returnValue) { // 将参数名与参数值放入参数上下文 EvaluationContext evaluationContext = buildEvaluationContext(returnValue, context); // 计算表达式(根据参数上下文) return EXPRESSION_PARSER.parseExpression(expression).getValue(evaluationContext, tClass); } /** * 计算是否匹配限流策略 * @param context * @param rateLimit * @param returnValue * @return */ private boolean match(JoinPoint context, RedisRateLimitAttribute rateLimit, Object returnValue) { return computeExpress(rateLimit.incrCondition(), context, Boolean.class, returnValue); } /** * 格式化key * @param v * @return */ private String formatKey(String v) { String result = v; if (Strings.isNullOrEmpty(result)) { throw new IllegalStateException("key配置不能为空"); } return environment.resolvePlaceholders(result); } /** * 放入参数值到StandardEvaluationContext */ private static void addParameterVariable(StandardEvaluationContext evaluationContext, JoinPoint context) { MethodSignature methodSignature = (MethodSignature) context.getSignature(); Method method = methodSignature.getMethod(); String[] parameterNames = 
PARAMETER_NAME_DISCOVERER.getParameterNames(method); if (parameterNames != null && parameterNames.length > 0) { Object[] args = context.getArgs(); for (int i = 0; i < parameterNames.length; i++) { evaluationContext.setVariable(parameterNames[i], args[i]); } } } /** * 放入返回值到StandardEvaluationContext */ private static void addReturnValue(StandardEvaluationContext evaluationContext, Object returnValue) { evaluationContext.setVariable("returnValue", returnValue); evaluationContext.setVariable("response", returnValue); } /** * 构建StandardEvaluationContext */ private static EvaluationContext buildEvaluationContext(Object returnValue, JoinPoint context) { StandardEvaluationContext evaluationContext = new StandardEvaluationContext(); addParameterVariable(evaluationContext, context); addReturnValue(evaluationContext, returnValue); return evaluationContext; } /** * 获取降级方法 * * @param context 过滤器上下文 * @param fallbackMethod 失败要执行的函数 * @return 降级方法 */ private static Method getFallbackMethod(JoinPoint context, String fallbackMethod) { MethodSignature methodSignature = (MethodSignature) context.getSignature(); Class[] parameterTypes = Optional.ofNullable(methodSignature.getParameterTypes()).orElse(new Class[0]); try { Method method = context.getTarget().getClass().getDeclaredMethod(fallbackMethod, parameterTypes); method.setAccessible(true); return method; } catch (NoSuchMethodException e) { } try { Class[] parameterTypes2 = Arrays.copyOf(parameterTypes, parameterTypes.length + 1); parameterTypes2[parameterTypes2.length - 1] = methodSignature.getReturnType(); Method method = context.getTarget().getClass().getDeclaredMethod(fallbackMethod, parameterTypes2); method.setAccessible(true); return method; } catch (NoSuchMethodException e) { } String message = String.format("获取fallbackMethod失败, context: %s, fallbackMethod: %s", context, fallbackMethod); throw new RuntimeException(message); } /** * 执行降级fallback方法 * @param fallbackMethod * @param fallbackTarget * @param fallbackArgs * 
@return * @throws Throwable */ private static Object invokeFallbackMethod(Method fallbackMethod, Object fallbackTarget, Object[] fallbackArgs) throws Throwable { try { return fallbackMethod.invoke(fallbackTarget, fallbackArgs); } catch (InvocationTargetException e) { if (e.getCause() != null) { throw e.getCause(); } throw e; } } }
3. 调用示例
/**
 * Sample controller applying {@code @RedisRateLimitAttribute} to a login endpoint: failed logins are
 * counted per user id and, past {@code login.maxFailedTimes}, the call is routed to {@link #loginFallback}.
 */
@Api(value = "HomeController", description = "用户登录登出接口")
@RestController
@RequestMapping("/home")
public class HomeController {

    /** Logger for this controller. */
    private static final Logger logger = LoggerFactory.getLogger(HomeController.class);

    /**
     * Logs a user in (demo implementation that always builds a fixed user model).
     *
     * <p>Rate limit: the counter {@code 'login' + userId} is incremented for each response with
     * {@code success == false}, i.e. each failed login, matching the {@code maxFailedTimes} limit.
     * NOTE(review): the window default 3600 is 3.6 s if read as milliseconds; use 3600000 for one hour —
     * confirm the intended value.
     *
     * @param req the login request
     * @return the login response
     */
    @ApiOperation(value = "用户登录", notes = "用户登录接口")
    @RequestMapping(value = "/login", method = RequestMethod.POST,
            consumes = MediaType.APPLICATION_JSON_VALUE, produces = MediaType.APPLICATION_JSON_VALUE)
    @ResponseBody
    @RedisRateLimitAttribute(key = "'login'+#req.userId",
            limit = "${login.maxFailedTimes:3}",
            incrCondition = "#response.success == false",
            intervalInMilliseconds = "${login.limit.millseconds:3600}",
            fallbackMethod = "loginFallback")
    public UserLoginResp login(@RequestBody UserLoginReq req) {
        logger.info("进入登陆业务");
        UserModel userModel = new UserModel();
        userModel.setRoleId(123);
        userModel.setUserId(req.getUserId());
        userModel.setMustValidateCode(false);
        return new UserLoginResp(userModel);
    }

    /**
     * Fallback for {@link #login} once the rate limit is exceeded.
     *
     * @param req  the original login request
     * @param resp the response produced by {@link #login}, or null when the call was rejected before
     *             {@link #login} ran
     * @return an empty response when {@code resp} is null; otherwise {@code resp} with
     *         {@code mustValidateCode} set to true
     */
    private UserLoginResp loginFallback(UserLoginReq req, UserLoginResp resp) {
        if (resp == null) {
            // NOTE(review): this is a default (success = true) response with no payload; consider an explicit failure.
            return new UserLoginResp();
        }
        UserModel payload = resp.getPayload();
        if (payload == null) {
            // Failure responses can carry no payload; create one so the flag still reaches the client.
            payload = new UserModel();
            resp.setPayload(payload);
        }
        payload.setMustValidateCode(true);
        return resp;
    }
}
/**
 * User data carried as the payload of a login response.
 */
@Data
public class UserModel {
    /**
     * User id.
     */
    private String userId;
    /**
     * Role name.
     */
    private String roleName;
    /**
     * Role number.
     */
    private Integer roleId;
    /**
     * Whether logging in requires a verification code.
     * Once the error count reaches the threshold, a verification code is required to make further
     * submissions harder.
     */
    private Boolean mustValidateCode;
}
import lombok.Data;

/**
 * Generic response envelope: a success flag, an error message and an optional payload.
 *
 * @param <T> payload type
 */
@Data
public class Response<T> {

    /** Whether the call succeeded. */
    private Boolean success;

    /** Error description; the convenience constructors default it to an empty string. */
    private String errorMessage;

    /** Response data; null unless supplied. */
    private T payload;

    /** Successful response with an empty message and no payload. */
    public Response() {
        this(true, "", null);
    }

    /**
     * Response with the given outcome, an empty message and no payload.
     *
     * @param succ whether the call succeeded
     */
    public Response(boolean succ) {
        this(succ, "", null);
    }

    /**
     * Response with the given outcome and message, and no payload.
     *
     * @param succ whether the call succeeded
     * @param msg  the error message
     */
    public Response(boolean succ, String msg) {
        this(succ, msg, null);
    }

    /**
     * Successful response carrying the given payload.
     *
     * @param data the payload
     */
    public Response(T data) {
        this(true, "", data);
    }

    /**
     * Canonical constructor; every other constructor delegates here.
     *
     * @param succ whether the call succeeded
     * @param msg  the error message
     * @param data the payload
     */
    public Response(boolean succ, String msg, T data) {
        this.success = succ;
        this.errorMessage = msg;
        this.payload = data;
    }
}
/**
 * Login response whose payload is the logged-in {@link UserModel}.
 */
public class UserLoginResp extends Response<UserModel> {

    /** Creates a response with the parent's defaults: successful, empty message, no payload. */
    public UserLoginResp() {
    }

    /**
     * Creates a successful response carrying the given user.
     *
     * @param userModel the logged-in user
     */
    public UserLoginResp(UserModel userModel) {
        super(userModel);
    }
}