zoukankan      html  css  js  c++  java
  • 使用AOP统一验签和校参

    一、需求背景

      对外提供服务的接口需要统一做验签和参数合法性校验。每个接口的加签算法相同,不同的是参数的不为空的要求不同。

      要求,在controller层外做校验,校验不通过直接返回,不进入controller层。

    二、需求实现前代码

     在这之前已经对每个请求做了AOP拦截,对每个请求植入了线程号。以及统计每个接口的执行耗时,打印每个接口的返回结果,捕获接口的未检查异常并打印和封装返回结果。

     如:

    /**
     * 为每一个的HTTP请求添加线程号
     *
     * @author yangyongjie
     * @date 2019/9/2
     * @desc
     */
    @Order(1)
    @Aspect
    @Component
    public class LogAspect {
    
        private static final Logger LOGGER = LoggerFactory.getLogger(LogAspect.class);
    
        @Pointcut(value = "@annotation(org.springframework.web.bind.annotation.RequestMapping)")
        private void webPointcut() {
            // doNothing
        }
    
        /**
         * 为所有的HTTP请求添加线程号
         *
         * @param joinPoint
         * @throws Throwable
         */
        @Around(value = "webPointcut()")
        public Object around(ProceedingJoinPoint joinPoint) {
            // 执行开始的时间
            Long beginTime = System.currentTimeMillis();
            // 方法执行前加上线程号,并将线程号放到线程本地变量中
            MDCUtil.init();
            // 获取切点的方法名
            String methodName = joinPoint.getSignature().getName();
            // 执行拦截的方法
            Object result = null;
            try {
                result = joinPoint.proceed();
            } catch (Throwable throwable) {
                LOGGER.error("{}方法执行异常:" + throwable.getMessage(), methodName, throwable);
                LogUtil.sendErrorLogMail("系统异常", throwable);
                result = new CommonResult(ResponseEnum.ERROR_SYSTEM.getCode(), ResponseEnum.ERROR_SYSTEM.getMsg());
            } finally {
                LOGGER.info("{}方法返回结果:{}", methodName, JacksonJsonUtil.toString(result));
                Long endTime = System.currentTimeMillis();
                LOGGER.info("{}方法耗时{}毫秒", methodName, endTime - beginTime);
                // 方法执行结束移除线程号,并移除线程本地变量,防止内存泄漏
                MDCUtil.remove();
            }
            return result;
        }
    }

    @Order(1) :为多个AOP切面排序,数字越小,先执行谁。

    MDCUtil:

    /**
     * 日志相关工具类
     *
     * @author yangyongjie
     * @date 2019/9/17
     * @desc
     */
    public class MDCUtil {
        private MDCUtil() {
        }
    
        private static final String STR_THREAD_ID = "threadId";
    
        /**
         * 初始化日志参数并保存在线程副本中
         */
        public static void init() {
            String uuid = UUID.randomUUID().toString().replaceAll("-", "");
            MDC.put(STR_THREAD_ID, uuid);
            ThreadContext.currentThreadContext().setThreadId(uuid);
        }
    
        /**
         * 初始化日志参数
         */
        public static void initWithOutContext() {
            String uuid = UUID.randomUUID().toString().replaceAll("-", "");
            MDC.put(STR_THREAD_ID, uuid);
        }
    
        /**
         * 移除线程号和线程副本
         */
        public static void remove() {
            MDC.remove(STR_THREAD_ID);
            ThreadContext.remove();
        }
    
        /**
         * 移除线程号
         */
        public static void removeWithOutContext() {
            MDC.remove(STR_THREAD_ID);
        }
    }

    线程上下文ThreadContext:

    /**
     * 线程上下文,一个线程内所需的上下文变量参数,使用ThreadLocal保存副本
     *
     * @author yangyongjie
     * @date 2019/9/12
     * @desc
     */
    public class ThreadContext {
        /**
         * 每个线程的私有变量,每个线程都有独立的变量副本,所以使用private static final修饰,因为都需要复制进入本地线程
         */
        private static final ThreadLocal<ThreadContext> THREAD_LOCAL = new ThreadLocal<ThreadContext>() {
            @Override
            protected ThreadContext initialValue() {
                return new ThreadContext();
            }
        };
    
        public static ThreadContext currentThreadContext() {
            /*ThreadContext threadContext = THREAD_LOCAL.get();
            if (threadContext == null) {
                THREAD_LOCAL.set(new ThreadContext());
                threadContext = THREAD_LOCAL.get();
            }
            return threadContext;*/
            return THREAD_LOCAL.get();
        }
    
        public static void remove() {
            THREAD_LOCAL.remove();
        }
    
        /**
         * 线程号
         */
        private String threadId;
    
        /**
         * 请求参数
         */
        private Object requestParam;
    
        public String getThreadId() {
            return threadId;
        }
    
        public void setThreadId(String threadId) {
            this.threadId = threadId;
        }
    
        public Object getRequestParam() {
            return requestParam;
        }
    
        public void setRequestParam(Object requestParam) {
            this.requestParam = requestParam;
        }
    
        @Override
        public String toString() {
            return JacksonJsonUtil.toString(this);
        }
    }

     公共返回结果类:

    /**
     * 用于返回给调用方执行结果的公共结果类
     * 自定义返回结果继承此类即可
     *
     * @author yangyongjie
     * @date 2019/9/25
     * @desc
     */
    public class CommonResult {
        /**
         * 返回码,0000表示成功,其余都是失败,9998表示入参不符合要求,9999表示系统异常
         */
        private String code = "0000";
        /**
         * 返回信息
         */
        private String msg = "success";
    
        public CommonResult() {
        }
    
    
        public CommonResult(String code, String msg) {
            this.code = code;
            this.msg = msg;
        }
    
        /**
         * 失败情况
         */
        public void fail(String code, String msg) {
            this.code = code;
            this.msg = msg;
        }
    
        /**
         * 判断是否成功
         */
        @JsonIgnore
        public boolean isSuccess() {
            return StringUtils.equals("0000", code);
        }
    
        public String getCode() {
            return code;
        }
    
        public void setCode(String code) {
            this.code = code;
        }
    
        public String getMsg() {
            return msg;
        }
    
        public void setMsg(String msg) {
            this.msg = msg;
        }
    }

    三、需求具体实现

      1、现在需要再增加一个切面,对需要做验签和参数校验的接口拦截并校验

        1)自定义注解,作用在controller层的方法上,标识此接口需要验签和验参,其有两个属性,一个是方法返回类型,一个是接收参数的实体类。

        方法返回类型用来切面校验不通过封装返回数据,接收参数的实体类对需要验不为空的方法标志了注解,需在切面中进行校验。

    /**
     * 对外请求参数校验注解
     *
     * @author yangyongjie
     * @date 2019/11/5
     * @desc
     */
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface Check {
    
        /**
         * 方法的返回值类型,继承了CommonResult
         */
        Class<? extends CommonResult> value();
    
        /**
         * 校验的目标实体类
         */
        Class<?> paramBean();
    
    }

    如接收参数的实体类定义:

    public class AuthTokenRequest extends BaseRequest {
    
        /**
         * 值为authorization_code
         */
        @ParamVerify(nullable = CheckEnum.NOTNULL)
        private String grant_type;
    }
    
    public class BaseRequest {
        /**
         * 签名
         */
        @ParamVerify(nullable = CheckEnum.NOTNULL)
        private String sign;
    
        /**
         * 分配的接入id
         */
        @ParamVerify(nullable = CheckEnum.NOTNULL)
        private String partnerId;
    }

    属性校验注解:

    /**
     * 字段校验注解,目前只进行非空校验,可扩展
     */
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.FIELD)
    public @interface ParamVerify {
        /**
         * 是否允许为空
         */
        CheckEnum nullable() default CheckEnum.NULL;
    }

    验签验参切面:

    /**
     * 对外同步接口参数校验切面
     *
     * @author yangyongjie
     * @date 2019/11/5
     * @desc
     */
    @Order(2)
    @Aspect
    @Component
    public class CheckAspect {
    
        private static final Logger LOGGER = LoggerFactory.getLogger(CheckAspect.class);
    
        /**
         * 验签公钥
         */
        @Value("${fx.publicKey}")
        private String fxPublicKey;
    
        @Autowired
        private OutgoingPartnerInfoDao outgoingPartnerInfoDao;
    
        @Pointcut("@annotation(com.xiaomi.mitv.outgoing.common.annotation.Check)")
        private void webPointcut() {
            // donothing
        }
    
        @Around(value = "webPointcut()")
        public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
            // 获取被增强的方法的相关信息
            MethodSignature ms = (MethodSignature) joinPoint.getSignature();
            // 获取被增强的方法
            Method pointcutMethod = ms.getMethod();
            String methodName = pointcutMethod.getName();
            // 对于对外接口,统一进行参数校验
            CommonResult commonResult = null;
            // 判断方法上有没有@Check注解
            if (pointcutMethod.isAnnotationPresent(Check.class)) {
                // 获取到拦截方法的HttpServletRequest
                // 获取当前方法执行的上下文的request
                HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
                // 获取body请求参数
                String bodyString = HttpUtil.getRequestBody(request);
    //            Map<String, Object> originMap = JacksonJsonUtil.toObject(bodyString, Map.class);
                Map<String, Object> originMap = HttpUtil.fromJsonToObject(bodyString, Map.class);
                // 将请求参数放到线程本地拷贝中
                ThreadContext.currentThreadContext().setRequestParam(originMap);
    
                // 得到方法上的Check注解
                Check check = pointcutMethod.getAnnotation(Check.class);
                // 获取切点方法的返回类型
                Class<?> returnType = check.value();
                // 创建对象
                commonResult = (CommonResult) returnType.newInstance();
                // 获取参数签名
                String sign = request.getParameter("sign");
                LOGGER.info("{}-sign={}", methodName, sign);
    
                // 参数校验
                Class<?> beanType = check.paramBean();
                originMap.put("sign", sign);
                if(!HttpUtil.paramCheck(originMap, beanType)){
                    commonResult.fail(ResponseEnum.ERROR_PARAM.getCode(), ResponseEnum.ERROR_PARAM.getMsg());
                    return commonResult;
                }
    
                String partnerId = String.valueOf(originMap.get("partnerId"));
                if (!StringUtil.areNotEmpty(partnerId, sign)) {
                    commonResult.fail(ResponseEnum.ERROR_PARAM_NULL.getCode(), ResponseEnum.ERROR_PARAM_NULL.getMsg());
                    return commonResult;
                }
                // 校验partnerId的有效性
                if (!checkPartnerId(partnerId)) {
                    commonResult.fail(ResponseEnum.ERROR_APP_INVALID.getCode(), ResponseEnum.ERROR_APP_INVALID.getMsg());
                    return commonResult;
                }
                // 组装加签串
                String paramBody = HttpUtil.getAssembleParam(originMap);
                // 验签
                boolean pass;
                try {
                    pass = RSAUtil.rsa256CheckContent(paramBody, sign, fxPublicKey);
                } catch (BssException e) {
                    LogUtil.LogAndMail("验签异常", e);
                    commonResult.fail(ResponseEnum.ERROR_SYSTEM.getCode(), ResponseEnum.ERROR_SYSTEM.getMsg());
                    return commonResult;
                }
                if (!pass) {
                    commonResult.fail(ResponseEnum.ERROR_CHECK_SIGN_FAIL.getCode(), ResponseEnum.ERROR_CHECK_SIGN_FAIL.getMsg());
                    return commonResult;
                }
            }
            // 执行增强方法
            Object result = joinPoint.proceed();
            return result;
        }
    
        /**
         * 校验partnerId的有效性,先查缓存,缓存中没有的话再查询数据库,使用互斥锁
         *
         * @param partnerId
         * @return
         */
        private boolean checkPartnerId(String partnerId) {
            // 先查询缓存,值为1表示存在且有效,值为0表示存在但无效,值为null表示不存在
            String val = RedisUtil.get(CommonConstants.PARTNER_ID + partnerId);
            if (StringUtils.isEmpty(val)) {
                // 缓存中不存在,先拿到互斥锁,再查询数据库,并放进缓存中
                // 获取互斥锁
                String mutexKey = CommonConstants.NX_PARTNER_ID + partnerId;
                boolean flag = RedisUtil.setex(mutexKey, CommonConstants.STR_ONE, 60);
                // 拿到锁
                if (flag) {
                    // 查询数据库
                    OutgoingPartnerInfoDto partnerInfoDto = outgoingPartnerInfoDao.getByPartnerId(partnerId);
                    if (partnerInfoDto != null && StringUtils.equals(CommonConstants.STR_ONE, partnerInfoDto.getStatus())) {
                        // partnerId 存在且有效
                        RedisUtil.set(CommonConstants.PARTNER_ID + partnerId, CommonConstants.STR_ONE);
                        // 删除锁
                        RedisUtil.del(mutexKey);
                        return true;
                    } else {
                        // partnerId 不存在或无效
                        RedisUtil.set(CommonConstants.PARTNER_ID + partnerId, CommonConstants.STR_ZERO);
                        return false;
                    }
                } else {
                    //休息50毫秒后重试
                    try {
                        Thread.sleep(50);
                    } catch (InterruptedException e) {
                        LOGGER.error("获取partnerId互斥锁异常" + e.getMessage(), e);
                    }
                    return checkPartnerId(partnerId);
                }
                // val 不为空
            } else {
                return StringUtils.equals(CommonConstants.STR_ONE, val);
            }
        }
    
    }

    HttpUtil工具类:

    public class HttpUtil {
    
        private HttpUtil() {
        }
    
        private static final Logger LOGGER = LoggerFactory.getLogger(HttpUtil.class);
    
        /**
         * 获取request中的body信息 JSON格式
         *
         * @param request
         * @return
         */
        public static String getRequestBody(HttpServletRequest request) {
            BufferedReader br = null;
            StringBuilder bodyDataBuilder = new StringBuilder();
            try {
                br = request.getReader();
                String str;
                while ((str = br.readLine()) != null) {
                    bodyDataBuilder.append(str);
                }
                br.close();
            } catch (IOException e) {
                LOGGER.error(e.getMessage(), e);
            } finally {
                if (null != br) {
                    try {
                        br.close();
                    } catch (IOException e) {
                        LOGGER.error(e.getMessage(), e);
                    }
                }
            }
            String bodyString = bodyDataBuilder.toString();
            LOGGER.info("bodyString={}", bodyString);
            return bodyString;
        }
    
        /**
         * 获取request中的body信息,并组装好按“参数=参数值”的格式
         *
         * @param request
         * @return
         */
        public static String getAssembleRequestBody(HttpServletRequest request) {
            String bodyString = getRequestBody(request);
            Map<String, Object> originMap = JacksonJsonUtil.toObject(bodyString, Map.class);
            Map<String, Object> sortedParams = getSortedMap(originMap);
            String assembleBody = getSignContent(sortedParams);
            return assembleBody;
        }
    
        /**
         * 根据requestBody中的原始map获取解析后并组装的参数字符串,根据&符拼接
         *
         * @param originMap
         * @return
         */
        public static String getAssembleParam(Map<String, Object> originMap) {
            return getSignContent(getSortedMap(originMap));
        }
    
    
        /**
         * 将body转成按key首字母排好序
         *
         * @return
         */
        public static Map<String, Object> getSortedMap(Map<String, Object> originMap) {
            Map<String, Object> sortedParams = new TreeMap<String, Object>();
            if (originMap != null && originMap.size() > 0) {
                sortedParams.putAll(originMap);
            }
            return sortedParams;
        }
    
        /**
         * 将排序好的map的key和value拼接成字符串
         *
         * @param sortedParams
         * @return
         */
        public static String getSignContent(Map<String, Object> sortedParams) {
            StringBuffer content = new StringBuffer();
            List<String> keys = new ArrayList<String>(sortedParams.keySet());
            Collections.sort(keys);
            int index = 0;
            for (int i = 0; i < keys.size(); i++) {
                String key = keys.get(i);
                Object value = sortedParams.get(key);
                if (StringUtils.isNotEmpty(key) && value != null) {
                    content.append((index == 0 ? "" : "&") + key + "=" + value);
                    index++;
                }
            }
            return content.toString();
        }
    
        /**
         * Json转实体对象
         *
         * @param jsonStr
         * @param clazz 目标生成实体对象
         * @return
         */
        public static <T> T fromJsonToObject(String jsonStr, Class clazz) {
            T results = null;
            try {
                results = (T) JacksonJsonUtil.toObject(jsonStr, clazz);
            } catch (Exception e) {
            }
            return results;
        }
    
        /**
         * 对请求参数进行校验,目前只进行非空校验
         *
         * @param srcData body数据
         * @param tarClass 校验规则
         * @return 校验成功返回true
         */
        public static <T> boolean paramCheck(Map<String, Object> srcData, Class<T> tarClass){
            try {
                Field[] fields = tarClass.getDeclaredFields();
                for(Field field : fields){
                    ParamVerify verify = field.getAnnotation(ParamVerify.class);
                    if(verify != null){
                        //非空校验,后续若需增加校验类型,应抽离
                        if(verify.nullable() == CheckEnum.NOTNULL){
                            String fn = field.getName();
                            Object val = srcData.get(fn);
                            if(val == null || "".equals(val.toString())){
                                return false;
                            }
                        }
                    }
                }
            }catch (Exception ex){
                LOGGER.info("Param verify error");
                return false;
            }
            return true;
        }
    
    }

    日志工具类:

     /**
         * 打印日志并发送错误邮件
         *
         * @param msg
         * @param t
         */
        public static void LogAndMail(String msg, Throwable t) {
            // 获取调用此工具类的该方法 的调用方信息
            // 查询当前线程的堆栈信息
            StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
            // 按照规则,此方法的上一级调用类为
            StackTraceElement ste = stackTrace[2];
            String className = ste.getClassName();
            String methodName = ste.getMethodName();
            LOGGER.error("{}#{},{}," + t.getMessage(), className, methodName, msg, t);
            // 异步发送邮件
            String ms = "[" + ThreadContext.currentThreadContext().getThreadId() + "]" + msg;
            executor.execute(() -> SendMailUtil.sendErrorMail(ms, t, 3));
        }
    
    
        /**
         * 只发送错误邮件不打印日志
         *
         * @param msg
         */
        public static void sendErrorLogMail(String msg, Throwable t) {
            // 异步发送邮件
            String ms = "[" + ThreadContext.currentThreadContext().getThreadId() + "]" + msg + assembleStackTrace(t);
            executor.execute(() -> SendMailUtil.sendErrorMail(ms, t, 3));
        }
    
        /**
         * 组装异常堆栈
         *
         * @param t
         * @return
         */
        public static String assembleStackTrace(Throwable t) {
            StringWriter sw = new StringWriter();
            PrintWriter ps = new PrintWriter(sw);
            t.printStackTrace(ps);
            return sw.toString();
        }

    有关两个切面的执行顺序问题,请参考:https://www.cnblogs.com/yangyongjie/p/11800862.html

    END

  • 相关阅读:
    谈谈node(1)
    怎么调用html5的摄像头,录音,视频?
    es6-块级作用域let 和 var的区别
    输入手机号自动分隔
    How do I know which version of Javascript I'm using?
    PHP的类中的常量,静态变量的问题。
    【转】马拉松式学习与技术人员的成长性
    JavaScript Prototype in Plain Language
    Promise编程规范
    XMLHttpRequest对象解读
  • 原文地址:https://www.cnblogs.com/yangyongjie/p/12535938.html
Copyright © 2011-2022 走看看