  • Hand-writing a simple version of SpringMVC

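    I  Preface

    This is a small demo, written by hand, that reproduces the basic behavior of SpringMVC. The target behavior is:

    Pass a name parameter in the browser URL and the page prints "my name is " + name. SpringMVC itself is not used; instead I define a few annotations and implement the core of DispatcherServlet myself. Working through the demo deepens my understanding of the SpringMVC source code.

    First, the result:

    (with the name parameter passed)

    (without the name parameter)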

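    II  DispatcherServlet flow

    1. Load the configuration file
    2. Scan all relevant classes
    3. Instantiate all relevant classes
    4. Autowire dependencies
    5. Initialize the HandlerMapping
    6. Wait for requests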
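Steps 3 and 4 carry most of the container logic, so here is a minimal, self-contained sketch of them. It is not the demo's code: the `Inject` annotation, `Greeter`, `GreeterImpl`, and `Client` are hypothetical names made up for this illustration, and the bean naming is simplified to fully qualified class and interface names.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;

class MiniIocSketch {
    // Hypothetical marker annotation standing in for the post's @Autowired.
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.FIELD)
    @interface Inject {}

    // Hypothetical service interface and implementation.
    interface Greeter { String greet(String name); }

    static class GreeterImpl implements Greeter {
        public String greet(String name) { return "my name is " + name; }
    }

    // Hypothetical bean with a dependency that the container must fill in.
    static class Client {
        @Inject Greeter greeter;
    }

    // Step 3: create one instance per class and register it under the class
    // name and under the name of every interface it implements.
    static Map<String, Object> instantiate(Class<?>... classes) throws Exception {
        Map<String, Object> ioc = new HashMap<>();
        for (Class<?> clazz : classes) {
            Object instance = clazz.getDeclaredConstructor().newInstance();
            ioc.put(clazz.getName(), instance);
            for (Class<?> iface : clazz.getInterfaces()) {
                ioc.put(iface.getName(), instance);
            }
        }
        return ioc;
    }

    // Step 4: for every annotated field, look up a bean by the field's type name.
    static void inject(Map<String, Object> ioc) throws IllegalAccessException {
        for (Object bean : ioc.values()) {
            for (Field field : bean.getClass().getDeclaredFields()) {
                if (!field.isAnnotationPresent(Inject.class)) {
                    continue;
                }
                field.setAccessible(true);
                field.set(bean, ioc.get(field.getType().getName()));
            }
        }
    }

    public static void main(String[] args) throws Exception {
        Map<String, Object> ioc = instantiate(GreeterImpl.class, Client.class);
        inject(ioc);
        Client client = (Client) ioc.get(Client.class.getName());
        System.out.println(client.greeter.greet("Tom")); // prints "my name is Tom"
    }
}
```

Registering the implementation under its interface name is what lets a field declared with the interface type be resolved without any extra configuration.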

    III  Code walkthrough

    1. First, the dependencies in the POM file:

    <dependencies>
      <dependency>
        <groupId>javax.servlet</groupId>
        <artifactId>servlet-api</artifactId>
        <version>2.5</version>
      </dependency>
      <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-lang3</artifactId>
        <version>3.10</version>
      </dependency>
      <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <version>1.18.12</version>
      </dependency>
      <dependency>
        <groupId>ch.qos.logback</groupId>
        <artifactId>logback-core</artifactId>
        <version>1.2.3</version>
      </dependency>
      <dependency>
        <groupId>ch.qos.logback</groupId>
        <artifactId>logback-classic</artifactId>
        <version>1.2.3</version>
      </dependency>
    </dependencies>

    There are only a few dependencies and none of them is Spring; the main one is the servlet API.


    2. Configuration files:

    2.1. The application.properties file:

    scanPackage=com.qunar.framework.demo

    This specifies the package to scan.
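As a rough sketch of how this property is consumed, the snippet below loads the same key with `java.util.Properties` and turns the package name into a classpath directory. The helper name `toResourcePath` is made up for this illustration, and the properties text is read from a string rather than from the classpath so that the example runs on its own.

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

class ScanPackageSketch {
    // Parse properties text and convert the scanPackage value into a
    // directory-style path (dots become slashes).
    static String toResourcePath(String propertiesText) throws IOException {
        Properties config = new Properties();
        config.load(new StringReader(propertiesText));
        String scanPackage = config.getProperty("scanPackage");
        return scanPackage.replace('.', '/');
    }

    public static void main(String[] args) throws IOException {
        // prints "com/qunar/framework/demo"
        System.out.println(toResourcePath("scanPackage=com.qunar.framework.demo"));
    }
}
```

Using `String.replace(char, char)` avoids regular expressions entirely, so there is no dot to escape.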

    2.2. The web.xml file:

    <!DOCTYPE web-app PUBLIC
     "-//Sun Microsystems, Inc.//DTD Web Application 2.3//EN"
     "http://java.sun.com/dtd/web-app_2_3.dtd" >
     
    <web-app>
      <display-name>MySpringMVC</display-name>
      <servlet>
        <servlet-name>mvc</servlet-name>
        <servlet-class>com.qunar.framework.webmvc.DispatcherServlet</servlet-class>
        <init-param>
          <param-name>contextConfigLocation</param-name>
          <param-value>/application.properties</param-value>
        </init-param>
        <load-on-startup>1</load-on-startup>
      </servlet>
      <servlet-mapping>
        <servlet-name>mvc</servlet-name>
        <url-pattern>/*</url-pattern>
      </servlet-mapping>
    </web-app>

    3. The directory structure of the whole project:

    4. Custom annotations:

    @Controller:

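    //The other annotations below need the same java.lang.annotation imports
    import java.lang.annotation.Documented;
    import java.lang.annotation.ElementType;
    import java.lang.annotation.Retention;
    import java.lang.annotation.RetentionPolicy;
    import java.lang.annotation.Target;

    @Target(ElementType.TYPE)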
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface Controller {
        String value() default "";
    }

    @Service:

    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface Service {
        String value() default "";
    }

    @Autowired:

    @Target(ElementType.FIELD)
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface Autowired {
        String value() default "";
    }

    @RequestMapping:

    @Target({ElementType.TYPE, ElementType.METHOD})
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface RequestMapping {
        String value() default "";
    }

    @RequestParam:

    @Target(ElementType.PARAMETER)
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface RequestParam {
        String value() default "";
    }

    5. The custom Handler wrapper:

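    import java.lang.annotation.Annotation;
    import java.lang.reflect.Method;
    import java.util.HashMap;
    import java.util.Map;
    import java.util.regex.Pattern;
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpServletResponse;
    import org.apache.commons.lang3.StringUtils;

    public class Handler {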
        protected Object controller;
        protected Method method;
        protected Pattern pattern;
        protected Map<String,Integer> paramIndexMap;
     
        public Handler(Object controller, Method method, Pattern pattern) {
            this.controller = controller;
            this.method = method;
            this.pattern = pattern;
            this.paramIndexMap = new HashMap<>();
            putParamIndexMapping(method);
        }
     
        private void putParamIndexMapping(Method method) {
            //Record the index of each parameter annotated with @RequestParam
            Annotation[][] annotations = method.getParameterAnnotations();
            for (int i = 0; i < annotations.length; i++){
                for (Annotation annotation : annotations[i]){
                    if (annotation instanceof RequestParam){
                        String paramName = ((RequestParam) annotation).value();
                        if (!StringUtils.isBlank(paramName)){
                            paramIndexMap.put(paramName,i);
                        }
                    }
                }
            }
            //Record the index of the request and response parameters
            Class<?>[] paramTypes = method.getParameterTypes();
            for (int i = 0; i < paramTypes.length; i++){
                Class<?> paramType = paramTypes[i];
                if (paramType == HttpServletRequest.class || paramType == HttpServletResponse.class){
                    paramIndexMap.put(paramType.getName(),i);
                }
            }
        }
    }
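To make the parameter-index idea concrete, here is a small standalone sketch of the same technique. The `Param` annotation and `SampleController` are hypothetical stand-ins for `@RequestParam` and a real controller, and `Object` replaces the servlet request and response types so that no servlet API is needed to run it.

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

class ParamIndexSketch {
    // Hypothetical stand-in for the post's @RequestParam.
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.PARAMETER)
    @interface Param { String value() default ""; }

    // Hypothetical controller; Object stands in for request and response.
    static class SampleController {
        public void get(Object req, Object resp,
                        @Param("name") String name, @Param("age") Integer age) {
        }
    }

    // Map each annotated parameter name to its position in the argument list.
    static Map<String, Integer> indexParams(Method method) {
        Map<String, Integer> index = new HashMap<>();
        Annotation[][] annotations = method.getParameterAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            for (Annotation annotation : annotations[i]) {
                if (annotation instanceof Param) {
                    String name = ((Param) annotation).value();
                    if (!name.isEmpty()) {
                        index.put(name, i);
                    }
                }
            }
        }
        return index;
    }

    public static void main(String[] args) throws Exception {
        Method method = SampleController.class.getMethod(
                "get", Object.class, Object.class, String.class, Integer.class);
        Map<String, Integer> index = indexParams(method);
        System.out.println(index.get("name")); // prints 2
        System.out.println(index.get("age"));  // prints 3
    }
}
```

The index is computed once per handler at startup, so dispatching a request only needs a map lookup to know which slot of the argument array a request parameter belongs in.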

    6. The custom DispatcherServlet:

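    import java.io.File;
    import java.io.IOException;
    import java.io.InputStream;
    import java.lang.reflect.Field;
    import java.lang.reflect.Method;
    import java.net.URL;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.Properties;
    import java.util.regex.Matcher;
    import java.util.regex.Pattern;
    import javax.servlet.ServletConfig;
    import javax.servlet.http.HttpServlet;
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpServletResponse;
    import lombok.extern.slf4j.Slf4j;
    import org.apache.commons.lang3.StringUtils;

    @Slf4j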
    public class DispatcherServlet extends HttpServlet {
        private static final long serialVersionUID = 1L;
        private Properties contextConfig = new Properties();
        private List<String> classNames = new ArrayList<>();
        private Map<String, Object> iocMap = new HashMap<>();
        private List<Handler> handlerMapping = new ArrayList<>();
     
        @Override
        protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
            this.doPost(req, resp);
        }
     
        @Override
        protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
            //Handle the incoming request
            try {
                doDispatch(req, resp);
            } catch (Exception exception) {
                resp.getWriter().write("500 Exception");
                log.error("500 Exception. Cause: {}", exception.getMessage());
                exception.printStackTrace();
            }
        }
     
        private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception {
            Handler handler = getHandler(req);
            if (handler == null) {
                //No handler matched: 404
                log.info("404 Not Found");
                resp.getWriter().write("404 Not Found");
                return;
            }
            //Get the handler method's parameter types
            Class<?>[] parameterTypes = handler.method.getParameterTypes();
            //Holds the argument values that will be passed to the method
            Object[] parameterValues = new Object[parameterTypes.length];
     
            Map<String, String[]> parameterMap = req.getParameterMap();
            for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
                String value = Arrays.toString(entry.getValue()).replaceAll("\\[|\\]", "").replaceAll("/+", "/");
                log.info(value);
                //Fill in the value only if the handler declares a matching parameter
                if (!handler.paramIndexMap.containsKey(entry.getKey())) {
                    continue;
                }
                Integer index = handler.paramIndexMap.get(entry.getKey());
                parameterValues[index] = convert(parameterTypes[index], value);
            }
            //Pass the request and response objects if the method declares them
            Integer reqIndex = handler.paramIndexMap.get(HttpServletRequest.class.getName());
            Integer respIndex = handler.paramIndexMap.get(HttpServletResponse.class.getName());
            if (reqIndex != null) {
                parameterValues[reqIndex] = req;
            }
            if (respIndex != null) {
                parameterValues[respIndex] = resp;
            }
            handler.method.invoke(handler.controller, parameterValues);
        }
     
        private Object convert(Class<?> parameterType, String value) {
            if (parameterType == Integer.class) {
                return Integer.valueOf(value);
            }
            return value;
        }
     
        private Handler getHandler(HttpServletRequest req) {
            if (handlerMapping.isEmpty()) {
                return null;
            }
            String requestURI = req.getRequestURI();
            String contextPath = req.getContextPath();
            requestURI = requestURI.replace(contextPath, "").replaceAll("/+", "/");
            for (Handler handler : handlerMapping) {
                Matcher matcher = handler.pattern.matcher(requestURI);
                if (!matcher.matches()) {
                    continue;
                }
                return handler;
            }
            return null;
        }
     
        @Override
        public void init(ServletConfig config) {
            //Startup begins here:
            //Load the configuration file
            loadConfig(config.getInitParameter("contextConfigLocation"));
            //Scan the relevant classes
            doScanner(contextConfig.getProperty("scanPackage"));
            //Instantiate the relevant classes
            try {
                doInstance();
            } catch (Exception exception) {
                log.error("Execute doInstance method fail.");
                exception.printStackTrace();
            }
            //Autowire dependencies
            doAutowired();
            //Initialize the HandlerMapping
            initHandlerMapping();
        }
     
        private void initHandlerMapping() {
            if (iocMap.isEmpty()) {
                return;
            }
            for (Map.Entry<String, Object> entry : iocMap.entrySet()) {
                Class<?> clazz = entry.getValue().getClass();
                if (!clazz.isAnnotationPresent(Controller.class)) {
                    continue;
                }
                String baseUrl = "";
                if (clazz.isAnnotationPresent(RequestMapping.class)) {
                    RequestMapping requestMapping = clazz.getAnnotation(RequestMapping.class);
                    baseUrl = requestMapping.value();
                }
                //Scan all public methods
                for (Method method : clazz.getMethods()) {
                    if (!method.isAnnotationPresent(RequestMapping.class)) {
                        continue;
                    }
                    RequestMapping requestMapping = method.getAnnotation(RequestMapping.class);
                    String regex = ("/" + baseUrl + requestMapping.value()).replaceAll("/+", "/");
                    Pattern pattern = Pattern.compile(regex);
                    handlerMapping.add(new Handler(entry.getValue(), method, pattern));
                    log.info("Mapping: {}.{}", regex, method);
                }
            }
        }
     
        private void doAutowired() {
            if (iocMap.isEmpty()) {
                return;
            }
            //Loop over all beans and assign the fields that need autowiring
            for (Map.Entry<String, Object> entry : iocMap.entrySet()) {
                Field[] fields = entry.getValue().getClass().getDeclaredFields();
                for (Field field : fields) {
                    if (!field.isAnnotationPresent(Autowired.class)) {
                        continue;
                    }
                    Autowired autowired = field.getAnnotation(Autowired.class);
                    String beanName = autowired.value();
                    if (beanName != null) {
                        beanName = beanName.trim();
                    }
                    if (StringUtils.isBlank(beanName)) {
                        beanName = field.getType().getName();
                    }
                    field.setAccessible(true);
                    try {
                        field.set(entry.getValue(), iocMap.get(beanName));
                    } catch (IllegalAccessException e) {
                        log.error("AutoWired fail,beanName: {}", beanName);
                        e.printStackTrace();
                        continue;
                    }
                }
            }
        }
     
        private void doInstance() throws Exception {
            if (classNames.isEmpty()) {
                return;
            }
            for (String className : classNames) {
                Class<?> clazz = Class.forName(className);
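                //Use the custom name if one is given; otherwise default to the fully qualified class name in lowercase (not the usual lowercase-first-letter convention)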
                if (clazz.isAnnotationPresent(Controller.class)) {
                    Controller controller = clazz.getAnnotation(Controller.class);
                    String beanName = controller.value();
                    if (StringUtils.isBlank(beanName)) {
                        beanName = clazz.getName().toLowerCase();
                    }
                    Object instance = clazz.newInstance();
                    iocMap.put(beanName, instance);
                } else if (clazz.isAnnotationPresent(Service.class)) {
                    Service service = clazz.getAnnotation(Service.class);
                    String beanName = service.value();
                    if (StringUtils.isBlank(beanName)) {
                        beanName = clazz.getName().toLowerCase();
                    }
                    Object instance = clazz.newInstance();
                    iocMap.put(beanName, instance);
                    //Also register the instance under each interface it implements
                    for (Class<?> clazzInterface : clazz.getInterfaces()) {
                        iocMap.put(clazzInterface.getName(), instance);
                    }
                } else {
                    continue;
                }
            }
        }
     
        private void doScanner(String scanPackage) {
            URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
            File classDir = new File(url.getFile());
            for (File file : classDir.listFiles()) {
                if (file.isDirectory()) {
                    doScanner(scanPackage + "." + file.getName());
                } else {
                    String className = scanPackage + "." + file.getName().replace(".class", "");
                    classNames.add(className);
                }
            }
        }
     
        private void loadConfig(String location) {
            InputStream inputStream = this.getClass().getResourceAsStream(location);
            try {
                contextConfig.load(inputStream);
            } catch (IOException e) {
                log.error("Load fail, location: {}", location);
                e.printStackTrace();
            } finally {
                if (inputStream != null) {
                    try {
                        inputStream.close();
                    } catch (IOException e) {
                        log.error("Close fail, inputStream: {}", inputStream);
                        e.printStackTrace();
                    }
                }
            }
        }
    }

    This is the core class: it does the job that SpringMVC normally does.
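The URL matching in `initHandlerMapping` and `getHandler` can be tried in isolation. The sketch below mirrors the two string transformations used there; the helper names `buildPattern` and `normalize` and the `/app` context path are assumptions made for this illustration.

```java
import java.util.regex.Pattern;

class UrlMappingSketch {
    // Join the class-level and method-level mappings, collapse repeated
    // slashes, and compile the result as a regular expression.
    static Pattern buildPattern(String baseUrl, String methodUrl) {
        String regex = ("/" + baseUrl + methodUrl).replaceAll("/+", "/");
        return Pattern.compile(regex);
    }

    // Strip the context path from the request URI and collapse repeated slashes.
    static String normalize(String requestUri, String contextPath) {
        return requestUri.replace(contextPath, "").replaceAll("/+", "/");
    }

    public static void main(String[] args) {
        Pattern pattern = buildPattern("/demo", "/get");
        System.out.println(pattern.pattern()); // prints "/demo/get"
        // prints "true": the doubled slash is collapsed before matching
        System.out.println(pattern.matcher(normalize("/app//demo/get", "/app")).matches());
        // prints "false": matches() requires the whole URI to match
        System.out.println(pattern.matcher(normalize("/app/demo/other", "/app")).matches());
    }
}
```

Because the mapping value is compiled as a regular expression, a value such as `/demo/.*` would match every path under `/demo/`, which gives a crude form of wildcard mapping.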

    7. Now it is time to verify that the hand-written SpringMVC works, using a service and a controller I wrote:

    7.1 service:

    @Service
    public class DemoServiceImpl implements IDemoService {
        @Override
        public String get(String name) {
            return "my name is " + name;
        }
    }

    7.2 controller:

    @Controller
    @RequestMapping("/demo")
    @Slf4j
    public class DemoController {
        @Autowired
        IDemoService service;
     
        @RequestMapping("/get")
        public void get(HttpServletRequest req, HttpServletResponse resp, @RequestParam("name") String name) {
            String res = service.get(name);
            try {
                resp.setContentType("text/html;charset=UTF-8");
                resp.getWriter().println(res);
            } catch (IOException e) {
                log.info(e.getMessage());
                e.printStackTrace();
            }
        }
    }

    Together with the screenshots at the beginning, this confirms that the hand-written SpringMVC works.
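The two cases (parameter present and parameter absent) follow from how `doDispatch` flattens request parameter values. The sketch below reproduces that transformation on its own; the helper name `flatten` is made up for this illustration. When no name parameter is sent, the argument slot is never filled and stays `null`, and Java string concatenation then renders it as the text "null".

```java
import java.util.Arrays;

class ParamValueSketch {
    // Same transformation as in doDispatch: render the value array as text,
    // strip the surrounding brackets, and collapse repeated slashes.
    static String flatten(String[] values) {
        return Arrays.toString(values).replaceAll("\\[|\\]", "").replaceAll("/+", "/");
    }

    public static void main(String[] args) {
        System.out.println(flatten(new String[] {"Tom"}));    // prints "Tom"
        System.out.println(flatten(new String[] {"a", "b"})); // prints "a, b"
        System.out.println(flatten(new String[] {"x//y"}));   // prints "x/y"

        // No name parameter: the handler argument stays null.
        String missing = null;
        System.out.println("my name is " + missing); // prints "my name is null"
    }
}
```

Note that the bracket stripping also removes any literal `[` or `]` inside a value, and the slash collapsing alters values that legitimately contain `//`, so this approach is only suitable for a demo.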

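    IV  Conclusion

    This demo implements only the simplest SpringMVC functionality. It is just an exercise to deepen my understanding of SpringMVC's request-mapping flow; the real SpringMVC is of course far more complex.

    The demo is on GitHub: https://github.com/Happy-Ape/Spring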

  • Original article: https://www.cnblogs.com/ericz2j/p/13553719.html