一、需求背景
对外提供服务的接口需要统一做验签和参数合法性校验。每个接口的加签算法相同,不同的是参数的不为空的要求不同。
要求,在controller层外做校验,校验不通过直接返回,不进入controller层。
二、需求实现前代码
在这之前已经对每个请求做了AOP拦截,对每个请求植入了线程号。以及统计每个接口的执行耗时,打印每个接口的返回结果,捕获接口的未检查异常并打印和封装返回结果。
如:
/** * 为每一个的HTTP请求添加线程号 * * @author yangyongjie * @date 2019/9/2 * @desc */ @Order(1) @Aspect @Component public class LogAspect { private static final Logger LOGGER = LoggerFactory.getLogger(LogAspect.class); @Pointcut(value = "@annotation(org.springframework.web.bind.annotation.RequestMapping)") private void webPointcut() { // doNothing } /** * 为所有的HTTP请求添加线程号 * * @param joinPoint * @throws Throwable */ @Around(value = "webPointcut()") public Object around(ProceedingJoinPoint joinPoint) { // 执行开始的时间 Long beginTime = System.currentTimeMillis(); // 方法执行前加上线程号,并将线程号放到线程本地变量中 MDCUtil.init(); // 获取切点的方法名 String methodName = joinPoint.getSignature().getName(); // 执行拦截的方法 Object result = null; try { result = joinPoint.proceed(); } catch (Throwable throwable) { LOGGER.error("{}方法执行异常:" + throwable.getMessage(), methodName, throwable); LogUtil.sendErrorLogMail("系统异常", throwable); result = new CommonResult(ResponseEnum.ERROR_SYSTEM.getCode(), ResponseEnum.ERROR_SYSTEM.getMsg()); } finally { LOGGER.info("{}方法返回结果:{}", methodName, JacksonJsonUtil.toString(result)); Long endTime = System.currentTimeMillis(); LOGGER.info("{}方法耗时{}毫秒", methodName, endTime - beginTime); // 方法执行结束移除线程号,并移除线程本地变量,防止内存泄漏 MDCUtil.remove(); } return result; } }
@Order(1) :为多个AOP切面排序,数字越小,先执行谁。
MDCUtil:
/** * 日志相关工具类 * * @author yangyongjie * @date 2019/9/17 * @desc */ public class MDCUtil { private MDCUtil() { } private static final String STR_THREAD_ID = "threadId"; /** * 初始化日志参数并保存在线程副本中 */ public static void init() { String uuid = UUID.randomUUID().toString().replaceAll("-", ""); MDC.put(STR_THREAD_ID, uuid); ThreadContext.currentThreadContext().setThreadId(uuid); } /** * 初始化日志参数 */ public static void initWithOutContext() { String uuid = UUID.randomUUID().toString().replaceAll("-", ""); MDC.put(STR_THREAD_ID, uuid); } /** * 移除线程号和线程副本 */ public static void remove() { MDC.remove(STR_THREAD_ID); ThreadContext.remove(); } /** * 移除线程号 */ public static void removeWithOutContext() { MDC.remove(STR_THREAD_ID); } }
线程上下文ThreadContext:
/** * 线程上下文,一个线程内所需的上下文变量参数,使用ThreadLocal保存副本 * * @author yangyongjie * @date 2019/9/12 * @desc */ public class ThreadContext { /** * 每个线程的私有变量,每个线程都有独立的变量副本,所以使用private static final修饰,因为都需要复制进入本地线程 */ private static final ThreadLocal<ThreadContext> THREAD_LOCAL = new ThreadLocal<ThreadContext>() { @Override protected ThreadContext initialValue() { return new ThreadContext(); } }; public static ThreadContext currentThreadContext() { /*ThreadContext threadContext = THREAD_LOCAL.get(); if (threadContext == null) { THREAD_LOCAL.set(new ThreadContext()); threadContext = THREAD_LOCAL.get(); } return threadContext;*/ return THREAD_LOCAL.get(); } public static void remove() { THREAD_LOCAL.remove(); } /** * 线程号 */ private String threadId; /** * 请求参数 */ private Object requestParam; public String getThreadId() { return threadId; } public void setThreadId(String threadId) { this.threadId = threadId; } public Object getRequestParam() { return requestParam; } public void setRequestParam(Object requestParam) { this.requestParam = requestParam; } @Override public String toString() { return JacksonJsonUtil.toString(this); } }
公共返回结果类:
/** * 用于返回给调用方执行结果的公共结果类 * 自定义返回结果继承此类即可 * * @author yangyongjie * @date 2019/9/25 * @desc */ public class CommonResult { /** * 返回码,0000表示成功,其余都是失败,9998表示入参不符合要求,9999表示系统异常 */ private String code = "0000"; /** * 返回信息 */ private String msg = "success"; public CommonResult() { } public CommonResult(String code, String msg) { this.code = code; this.msg = msg; } /** * 失败情况 */ public void fail(String code, String msg) { this.code = code; this.msg = msg; } /** * 判断是否成功 */ @JsonIgnore public boolean isSuccess() { return StringUtils.equals("0000", code); } public String getCode() { return code; } public void setCode(String code) { this.code = code; } public String getMsg() { return msg; } public void setMsg(String msg) { this.msg = msg; } }
三、需求具体实现
1、现在需要再增加一个切面,对需要做验签和参数校验的接口拦截并校验
1)自定义注解,作用在controller层的方法上,标识此接口需要验签和验参,其有两个属性,一个是方法返回类型,一个是接收参数的实体类。
方法返回类型用来切面校验不通过封装返回数据,接收参数的实体类对需要验不为空的方法标志了注解,需在切面中进行校验。
/** * 对外请求参数校验注解 * * @author yangyongjie * @date 2019/11/5 * @desc */ @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) @Documented public @interface Check { /** * 方法的返回值类型,继承了CommonResult */ Class<? extends CommonResult> value(); /** * 校验的目标实体类 */ Class<?> paramBean(); }
如接收参数的实体类定义:
public class AuthTokenRequest extends BaseRequest { /** * 值为authorization_code */ @ParamVerify(nullable = CheckEnum.NOTNULL) private String grant_type; } public class BaseRequest { /** * 签名 */ @ParamVerify(nullable = CheckEnum.NOTNULL) private String sign; /** * 分配的接入id */ @ParamVerify(nullable = CheckEnum.NOTNULL) private String partnerId; }
属性校验注解:
/** * 字段校验注解,目前只进行非空校验,可扩展 */ @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.FIELD) public @interface ParamVerify { /** * 是否允许为空 */ CheckEnum nullable() default CheckEnum.NULL; }
验签验参切面:
/** * 对外同步接口参数校验切面 * * @author yangyongjie * @date 2019/11/5 * @desc */ @Order(2) @Aspect @Component public class CheckAspect { private static final Logger LOGGER = LoggerFactory.getLogger(CheckAspect.class); /** * 验签公钥 */ @Value("${fx.publicKey}") private String fxPublicKey; @Autowired private OutgoingPartnerInfoDao outgoingPartnerInfoDao; @Pointcut("@annotation(com.xiaomi.mitv.outgoing.common.annotation.Check)") private void webPointcut() { // donothing } @Around(value = "webPointcut()") public Object around(ProceedingJoinPoint joinPoint) throws Throwable { // 获取被增强的方法的相关信息 MethodSignature ms = (MethodSignature) joinPoint.getSignature(); // 获取被增强的方法 Method pointcutMethod = ms.getMethod(); String methodName = pointcutMethod.getName(); // 对于对外接口,统一进行参数校验 CommonResult commonResult = null; // 判断方法上有没有@Check注解 if (pointcutMethod.isAnnotationPresent(Check.class)) { // 获取到拦截方法的HttpServletRequest // 获取当前方法执行的上下文的request HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest(); // 获取body请求参数 String bodyString = HttpUtil.getRequestBody(request); // Map<String, Object> originMap = JacksonJsonUtil.toObject(bodyString, Map.class); Map<String, Object> originMap = HttpUtil.fromJsonToObject(bodyString, Map.class); // 将请求参数放到线程本地拷贝中 ThreadContext.currentThreadContext().setRequestParam(originMap); // 得到方法上的Check注解 Check check = pointcutMethod.getAnnotation(Check.class); // 获取切点方法的返回类型 Class<?> returnType = check.value(); // 创建对象 commonResult = (CommonResult) returnType.newInstance(); // 获取参数签名 String sign = request.getParameter("sign"); LOGGER.info("{}-sign={}", methodName, sign); // 参数校验 Class<?> beanType = check.paramBean(); originMap.put("sign", sign); if(!HttpUtil.paramCheck(originMap, beanType)){ commonResult.fail(ResponseEnum.ERROR_PARAM.getCode(), ResponseEnum.ERROR_PARAM.getMsg()); return commonResult; } String partnerId = String.valueOf(originMap.get("partnerId")); if (!StringUtil.areNotEmpty(partnerId, sign)) { commonResult.fail(ResponseEnum.ERROR_PARAM_NULL.getCode(), ResponseEnum.ERROR_PARAM_NULL.getMsg()); return commonResult; } // 校验partnerId的有效性 if (!checkPartnerId(partnerId)) { commonResult.fail(ResponseEnum.ERROR_APP_INVALID.getCode(), ResponseEnum.ERROR_APP_INVALID.getMsg()); return commonResult; } // 组装加签串 String paramBody = HttpUtil.getAssembleParam(originMap); // 验签 boolean pass; try { pass = RSAUtil.rsa256CheckContent(paramBody, sign, fxPublicKey); } catch (BssException e) { LogUtil.LogAndMail("验签异常", e); commonResult.fail(ResponseEnum.ERROR_SYSTEM.getCode(), ResponseEnum.ERROR_SYSTEM.getMsg()); return commonResult; } if (!pass) { commonResult.fail(ResponseEnum.ERROR_CHECK_SIGN_FAIL.getCode(), ResponseEnum.ERROR_CHECK_SIGN_FAIL.getMsg()); return commonResult; } } // 执行增强方法 Object result = joinPoint.proceed(); return result; } /** * 校验partnerId的有效性,先查缓存,缓存中没有的话再查询数据库,使用互斥锁 * * @param partnerId * @return */ private boolean checkPartnerId(String partnerId) { // 先查询缓存,值为1表示存在且有效,值为0表示存在但无效,值为null表示不存在 String val = RedisUtil.get(CommonConstants.PARTNER_ID + partnerId); if (StringUtils.isEmpty(val)) { // 缓存中不存在,先拿到互斥锁,再查询数据库,并放进缓存中 // 获取互斥锁 String mutexKey = CommonConstants.NX_PARTNER_ID + partnerId; boolean flag = RedisUtil.setex(mutexKey, CommonConstants.STR_ONE, 60); // 拿到锁 if (flag) { // 查询数据库 OutgoingPartnerInfoDto partnerInfoDto = outgoingPartnerInfoDao.getByPartnerId(partnerId); if (partnerInfoDto != null && StringUtils.equals(CommonConstants.STR_ONE, partnerInfoDto.getStatus())) { // partnerId 存在且有效 RedisUtil.set(CommonConstants.PARTNER_ID + partnerId, CommonConstants.STR_ONE); // 删除锁 RedisUtil.del(mutexKey); return true; } else { // partnerId 不存在或无效 RedisUtil.set(CommonConstants.PARTNER_ID + partnerId, CommonConstants.STR_ZERO); return false; } } else { //休息50毫秒后重试 try { Thread.sleep(50); } catch (InterruptedException e) { LOGGER.error("获取partnerId互斥锁异常" + e.getMessage(), e); } return checkPartnerId(partnerId); } // val 不为空 } else { return StringUtils.equals(CommonConstants.STR_ONE, val); } } }
HttpUtil工具类:
public class HttpUtil { private HttpUtil() { } private static final Logger LOGGER = LoggerFactory.getLogger(HttpUtil.class); /** * 获取request中的body信息 JSON格式 * * @param request * @return */ public static String getRequestBody(HttpServletRequest request) { BufferedReader br = null; StringBuilder bodyDataBuilder = new StringBuilder(); try { br = request.getReader(); String str; while ((str = br.readLine()) != null) { bodyDataBuilder.append(str); } br.close(); } catch (IOException e) { LOGGER.error(e.getMessage(), e); } finally { if (null != br) { try { br.close(); } catch (IOException e) { LOGGER.error(e.getMessage(), e); } } } String bodyString = bodyDataBuilder.toString(); LOGGER.info("bodyString={}", bodyString); return bodyString; } /** * 获取request中的body信息,并组装好按“参数=参数值”的格式 * * @param request * @return */ public static String getAssembleRequestBody(HttpServletRequest request) { String bodyString = getRequestBody(request); Map<String, Object> originMap = JacksonJsonUtil.toObject(bodyString, Map.class); Map<String, Object> sortedParams = getSortedMap(originMap); String assembleBody = getSignContent(sortedParams); return assembleBody; } /** * 根据requestBody中的原始map获取解析后并组装的参数字符串,根据&符拼接 * * @param originMap * @return */ public static String getAssembleParam(Map<String, Object> originMap) { return getSignContent(getSortedMap(originMap)); } /** * 将body转成按key首字母排好序 * * @return */ public static Map<String, Object> getSortedMap(Map<String, Object> originMap) { Map<String, Object> sortedParams = new TreeMap<String, Object>(); if (originMap != null && originMap.size() > 0) { sortedParams.putAll(originMap); } return sortedParams; } /** * 将排序好的map的key和value拼接成字符串 * * @param sortedParams * @return */ public static String getSignContent(Map<String, Object> sortedParams) { StringBuffer content = new StringBuffer(); List<String> keys = new ArrayList<String>(sortedParams.keySet()); Collections.sort(keys); int index = 0; for (int i = 0; i < keys.size(); i++) { String key = keys.get(i); Object value = sortedParams.get(key); if (StringUtils.isNotEmpty(key) && value != null) { content.append((index == 0 ? "" : "&") + key + "=" + value); index++; } } return content.toString(); } /** * Json转实体对象 * * @param jsonStr * @param clazz 目标生成实体对象 * @return */ public static <T> T fromJsonToObject(String jsonStr, Class clazz) { T results = null; try { results = (T) JacksonJsonUtil.toObject(jsonStr, clazz); } catch (Exception e) { } return results; } /** * 对请求参数进行校验,目前只进行非空校验 * * @param srcData body数据 * @param tarClass 校验规则 * @return 校验成功返回true */ public static <T> boolean paramCheck(Map<String, Object> srcData, Class<T> tarClass){ try { Field[] fields = tarClass.getDeclaredFields(); for(Field field : fields){ ParamVerify verify = field.getAnnotation(ParamVerify.class); if(verify != null){ //非空校验,后续若需增加校验类型,应抽离 if(verify.nullable() == CheckEnum.NOTNULL){ String fn = field.getName(); Object val = srcData.get(fn); if(val == null || "".equals(val.toString())){ return false; } } } } }catch (Exception ex){ LOGGER.info("Param verify error"); return false; } return true; } }
日志工具类:
/** * 打印日志并发送错误邮件 * * @param msg * @param t */ public static void LogAndMail(String msg, Throwable t) { // 获取调用此工具类的该方法 的调用方信息 // 查询当前线程的堆栈信息 StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); // 按照规则,此方法的上一级调用类为 StackTraceElement ste = stackTrace[2]; String className = ste.getClassName(); String methodName = ste.getMethodName(); LOGGER.error("{}#{},{}," + t.getMessage(), className, methodName, msg, t); // 异步发送邮件 String ms = "[" + ThreadContext.currentThreadContext().getThreadId() + "]" + msg; executor.execute(() -> SendMailUtil.sendErrorMail(ms, t, 3)); } /** * 只发送错误邮件不打印日志 * * @param msg */ public static void sendErrorLogMail(String msg, Throwable t) { // 异步发送邮件 String ms = "[" + ThreadContext.currentThreadContext().getThreadId() + "]" + msg + assembleStackTrace(t); executor.execute(() -> SendMailUtil.sendErrorMail(ms, t, 3)); } /** * 组装异常堆栈 * * @param t * @return */ public static String assembleStackTrace(Throwable t) { StringWriter sw = new StringWriter(); PrintWriter ps = new PrintWriter(sw); t.printStackTrace(ps); return sw.toString(); }
有关两个切面的执行顺序问题,请参考:https://www.cnblogs.com/yangyongjie/p/11800862.html
END