简介
以前用过SpringMVC,感觉挺不错。前段时间简单地写了一些代码,实现了类似SpringMVC的简单请求分发功能,主要思路如下:
- 将处理请求的类在系统启动的时候加载起来,相当于SpringMVC中的Controller
- 读取Controller中的配置,建立它与所处理URL之间的对应关系
- 通过一个调度Servlet拦截请求,并找到相应的Controller进行处理
主要代码
首先要标识出哪些类是Controller类。这里我自己定义了ServletHandler注解,通过Annotation的方式进行标识,并用注解配置每个类和方法所处理的URL:
package com.meet58.base.servlet.annotation; public @interface ServletHandler { }
这个注解用于声明某个类是ServletHandler类,即用于处理请求的类,系统启动时会扫描并加载这些类。
package com.meet58.base.servlet.annotation;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

// NOTE(review): RequestMethod and ResponseType are imported but not used by this annotation.
import com.meet58.base.servlet.types.RequestMethod;
import com.meet58.base.servlet.types.ResponseType;

/**
 * Declares the URL path served by a servlet handler class or by one of its methods.
 *
 * <p>On a class the value is the base path of the handler. On a method the value is appended to
 * the base path of the declaring class to form the full request path.
 */
@Documented
@Retention(value = RetentionPolicy.RUNTIME)
@Target(value = { ElementType.TYPE, ElementType.METHOD })
public @interface HandlerMapping {

    /**
     * The URL path this class or method handles.
     *
     * @return the mapped path
     */
    String value();
}
这个注解用于配置请求映射,可以标注在类或方法上,定义了要处理的URL路径。
定义好注解之后,就要在系统启动时扫描并加载这些类,下面是进行扫描的代码:
package com.meet58.base.servlet.mapping; import java.io.IOException; import org.springframework.core.io.Resource; import org.springframework.core.io.support.PathMatchingResourcePatternResolver; import org.springframework.core.io.support.ResourcePatternResolver; import org.springframework.core.io.support.ResourcePatternUtils; import org.springframework.core.type.classreading.CachingMetadataReaderFactory; import org.springframework.core.type.classreading.MetadataReader; import org.springframework.core.type.classreading.MetadataReaderFactory; import org.springframework.core.type.filter.AnnotationTypeFilter; import org.springframework.core.type.filter.TypeFilter; import org.springframework.util.ClassUtils; import com.meet58.base.servlet.annotation.ServletHandler; public class ServletHandlerMappingResolver { private static final String RESOURCE_PATTERN = "/**/*.class"; private String[] packagesToScan; private ResourcePatternResolver resourcePatternResolver; private static final TypeFilter[] ENTITY_TYPE_FILTERS = new TypeFilter[] { new AnnotationTypeFilter(ServletHandler.class, false)}; public ServletHandlerMappingResolver(){ this.resourcePatternResolver = ResourcePatternUtils.getResourcePatternResolver(new PathMatchingResourcePatternResolver()); } public ServletHandlerMappingResolver scanPackages(String[] packagesToScan){ try { for (String pkg : packagesToScan) { String pattern = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX + ClassUtils.convertClassNameToResourcePath(pkg) + RESOURCE_PATTERN; Resource[] resources; resources = this.resourcePatternResolver.getResources(pattern); MetadataReaderFactory readerFactory = new CachingMetadataReaderFactory(this.resourcePatternResolver); for (Resource resource : resources) { if (resource.isReadable()) { MetadataReader reader = readerFactory.getMetadataReader(resource); String className = reader.getClassMetadata().getClassName(); if (matchesFilter(reader, readerFactory)) { 
ServletHandlerMappingFactory.addClassMapping(this.resourcePatternResolver.getClassLoader().loadClass(className)); } } } } } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (ClassNotFoundException e) { // TODO Auto-generated catch block e.printStackTrace(); } return this; } public String[] getPackagesToScan() { return packagesToScan; } public void setPackagesToScan(String[] packagesToScan) { this.packagesToScan = packagesToScan; this.scanPackages(packagesToScan); } private boolean matchesFilter(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException { for (TypeFilter filter : ENTITY_TYPE_FILTERS) { if (filter.match(reader, readerFactory)) { return true; } } return false; } }
这段代码借鉴了Spring中扫描Hibernate持久化实体类的代码。接下来要做的,是把要处理的URL与相应的ServletHandler进行匹配:
package com.meet58.base.servlet.mapping;

import java.lang.reflect.Method;
import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.log4j.Logger;

import com.meet58.base.servlet.annotation.HandlerMapping;
import com.meet58.base.servlet.context.ServletHandlerFactory;

/**
 * Registry that maps request URLs to the handler methods serving them.
 *
 * <p>A handler class's base URL comes from its class-level {@link HandlerMapping}. Without one,
 * the base URL is derived from the simple class name, e.g. {@code UserServlet} becomes
 * {@code /user}. Each method annotated with {@link HandlerMapping} is registered under the base
 * URL plus the method's path.
 */
public class ServletHandlerMappingFactory {

    /** Lower-cased marker cut from class names when deriving a default base URL. */
    private static final String SERVLET_MARKER = "servlet";

    private static final Logger logger = Logger.getLogger(ServletHandlerMappingFactory.class);

    // Filled while handler classes are scanned and read by request-dispatching threads. A
    // concurrent map is used so that putIfAbsent performs the duplicate check atomically.
    private static final ConcurrentMap<String, Method> servletHandlerMapping =
            new ConcurrentHashMap<String, Method>();

    /**
     * Instantiates a handler class and registers all of its annotated handler methods.
     *
     * @param clazz the handler class to register
     * @throws IllegalArgumentException if one of the resulting URLs is already mapped
     */
    public static void addClassMapping(Class<?> clazz) {
        String url = normalizeBaseUrl(resolveBaseUrl(clazz));
        ServletHandlerFactory.put(clazz);
        logger.info(" Load servlet handler class:" + clazz.getName() + " url:" + url);
        scanHandlerMethod(clazz, url);
    }

    /** Returns the class-level mapping value, or "/" plus the class name up to "servlet". */
    private static String resolveBaseUrl(Class<?> clazz) {
        HandlerMapping handlerMapping = clazz.getAnnotation(HandlerMapping.class);
        if (handlerMapping != null) {
            return handlerMapping.value();
        }
        // Locale.ROOT keeps the result independent of the JVM default locale (e.g. Turkish 'I').
        String simpleName = clazz.getSimpleName().toLowerCase(Locale.ROOT);
        int markerIndex = simpleName.indexOf(SERVLET_MARKER);
        // indexOf returns -1 when the name lacks "servlet"; substring(0, -1) would throw, so the
        // whole name is used instead.
        String base = markerIndex >= 0 ? simpleName.substring(0, markerIndex) : simpleName;
        return "/" + base;
    }

    /**
     * Ensures a leading "/" and removes one trailing "/". As a result "user", "/user" and "/user/"
     * all become "/user", and "/" becomes "".
     */
    private static String normalizeBaseUrl(String url) {
        String normalized = url.startsWith("/") ? url : "/" + url;
        if (normalized.endsWith("/")) {
            // Drop only the final character. Method paths always start with "/", so keeping it
            // would produce "//" in the combined URL.
            normalized = normalized.substring(0, normalized.length() - 1);
        }
        return normalized;
    }

    /**
     * Registers every method of {@code clazz} that carries {@link HandlerMapping}. The URL is
     * {@code classMapping} followed by the method path, with a leading "/" added if missing.
     *
     * @param clazz the handler class whose declared methods are inspected
     * @param classMapping the base URL of the class
     * @throws IllegalArgumentException if one of the resulting URLs is already mapped
     */
    public static void scanHandlerMethod(Class<?> clazz, String classMapping) {
        for (Method method : clazz.getDeclaredMethods()) {
            HandlerMapping handlerMapping = method.getAnnotation(HandlerMapping.class);
            if (handlerMapping == null) {
                continue;
            }
            String mapping = handlerMapping.value();
            if (!mapping.startsWith("/")) {
                mapping = "/" + mapping;
            }
            addMethodMapping(classMapping + mapping, method);
        }
    }

    /**
     * Maps {@code url} to {@code method}.
     *
     * @param url the full request path, relative to the context path
     * @param method the handler method that serves {@code url}
     * @throws IllegalArgumentException if {@code url} is already mapped to a method
     */
    public static void addMethodMapping(String url, Method method) {
        logger.info(" Load servlet handler mapping, method:" + method.getName() + " for url:" + url);
        Method existing = servletHandlerMapping.putIfAbsent(url, method);
        if (existing != null) {
            throw new IllegalArgumentException(" url :" + url + " is already mapped by :" + existing);
        }
    }

    /**
     * Looks up the handler method for a request path.
     *
     * @param url the request path, relative to the context path
     * @return the mapped method, or {@code null} if none is mapped or {@code url} is null
     */
    public static Method getMethodMapping(String url) {
        // ConcurrentHashMap rejects null keys; keep the "null in, null out" lookup contract.
        return url == null ? null : servletHandlerMapping.get(url);
    }
}
这个类扫描了每个ServletHandler类中的方法,并记录它们要处理的URL。接下来就是通过一个简单的容器实例化这些ServletHandler类并保存起来:
package com.meet58.base.servlet.context;

import java.lang.reflect.InvocationTargetException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.log4j.Logger;

/**
 * Holds one instance of each servlet handler class, keyed by fully-qualified class name.
 *
 * <p>Every request dispatched to a handler uses this single shared instance. Handler classes
 * should therefore not keep per-request state in fields.
 */
public class ServletHandlerFactory {

    private static final Logger logger = Logger.getLogger(ServletHandlerFactory.class);

    // Filled while handlers are registered and read by request-dispatching threads.
    private static final ConcurrentMap<String, Object> classes = new ConcurrentHashMap<String, Object>();

    /**
     * Instantiates {@code clazz} through its no-argument constructor and stores the instance
     * under the class name.
     *
     * <p>Failures are logged with their stack trace, and the class is then left unregistered. A
     * constructor that throws is reported with its own exception rather than the reflective
     * wrapper.
     *
     * @param clazz the handler class to instantiate
     */
    public static void put(Class<?> clazz) {
        logger.info("初始化ServletHandler类:" + clazz.getName());
        try {
            // Class#newInstance is deprecated: it rethrows checked constructor exceptions
            // unwrapped. Constructor#newInstance wraps them in InvocationTargetException instead.
            Object servlet = clazz.getDeclaredConstructor().newInstance();
            classes.put(clazz.getName(), servlet);
        } catch (InvocationTargetException e) {
            logFailure(clazz, e.getTargetException());
        } catch (NoSuchMethodException e) {
            logFailure(clazz, e);
        } catch (InstantiationException e) {
            logFailure(clazz, e);
        } catch (IllegalAccessException e) {
            logFailure(clazz, e);
        }
    }

    /** Logs an instantiation failure for {@code clazz}, including the stack trace of {@code cause}. */
    private static void logFailure(Class<?> clazz, Throwable cause) {
        logger.error("初始化Servlet类:" + clazz.getName() + "失败:" + cause.getMessage(), cause);
    }

    /**
     * Returns the stored handler instance for a class name.
     *
     * @param className the fully-qualified class name
     * @param <T> the expected handler type; the cast is unchecked
     * @return the instance, or {@code null} if none is registered or {@code className} is null
     */
    @SuppressWarnings("unchecked")
    public static <T> T get(String className) {
        // ConcurrentHashMap rejects null keys; keep the "null in, null out" lookup contract.
        return className == null ? null : (T) classes.get(className);
    }
}
在ServletHandler类加载完成、并且知道它们分别处理哪些URL之后,就可以通过一个调度器(Servlet)按URL对请求进行分发了:
package com.meet58.base.servlet; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; import javax.servlet.ServletException; import javax.servlet.annotation.WebServlet; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.log4j.Logger; import com.meet58.base.context.WebHttpRequestContext; import com.meet58.base.servlet.context.ServletHandlerFactory; import com.meet58.base.servlet.mapping.ServletHandlerMappingFactory; import com.meet58.util.WebUtils; @WebServlet(urlPatterns = { "*.do" }) public class WebHttpDispatchServlet extends HttpServlet { private static final long serialVersionUID = 1L; private Logger logger = Logger.getLogger(this.getClass()); private List<String> excludeUrls = new ArrayList<String>(); @Override public void init() throws ServletException { // 屏蔽websocket地址 excludeUrls.add("/meet.do"); super.init(); } public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { this.doPost(request, response); } public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { try { String url = request.getRequestURI().replace( request.getContextPath(), ""); if (excludeUrls.contains(url)) { return; } Method handlerMethod = ServletHandlerMappingFactory.getMethodMapping(url); if (handlerMethod == null) { response.sendError(404, "No handler found for " + url); logger.error("No handler found for " + url); return; } Object servlet = ServletHandlerFactory.get(handlerMethod .getDeclaringClass().getName()); if (servlet == null) { response.sendError(404, "No handler class found for " + url); logger.error("No handler class found for " + url); return; } Object result = invokeHandlerMethod(servlet, handlerMethod); handleInvokeResult(result); // this.doService(); } 
catch (Throwable e) { handlerException(e); } } public void handleInvokeResult(Object result) { String location = ""; if (result instanceof String) { if (((String) result).startsWith("redirect:")) { location = ((String) result).substring("redirect:".length(), ((String) result).length()); WebUtils.redirect(location); } else if (((String) result).startsWith("forward:")) { location = ((String) result).substring("forward:".length(), ((String) result).length()); WebUtils.forward(location); } } } public Object invokeHandlerMethod(Object object, Method method) throws Throwable { Object result = null; if (method != null) { try { result = method.invoke(object); } catch (InvocationTargetException e) { throw e.getTargetException(); } } return result; } public void handlerException(Throwable e) { String message = e.getMessage() != null ? e.getMessage() : e.toString(); e.printStackTrace(); if (WebHttpRequestContext.isAsyncRequest()) { WebUtils.writeFailure(message); } else { try { WebHttpRequestContext.getResponse().sendError(500, message); } catch (IOException e1) { e1.printStackTrace(); } } } public String getMappingClass(String url) { return null; } }
这段代码通过URL找到对应的处理方法并调用它,同时统一捕获和处理异常。
Struts也用到了这种方法。不过这只是出于兴趣的简单研究,并没有在实际项目中运用,可能会存在线程安全问题(例如每个ServletHandler类只有一个实例,被所有请求共享)。