ReflectUtils.java (反射工具类)
package top.icss.ioc; import java.io.File; import java.io.FileFilter; import java.io.IOException; import java.net.JarURLConnection; import java.net.URL; import java.net.URLDecoder; import java.util.ArrayList; import java.util.Enumeration; import java.util.List; import java.util.jar.JarEntry; import java.util.jar.JarFile; /** * @author cd * @desc 反射工具类 * @create 2020/3/26 11:30 * @since 1.0.0 */ public class ReflectUtils { /** * 是否循环迭代 */ private final static boolean recursive = true; /** * 扫描 包下面所有Class * @param packageName 包名称 * @param <T> * @return */ public static <T> List<Class<T>> getClass(String packageName){ List<Class<T>> list = new ArrayList<>(); String packageNamePath = packageName; packageNamePath = packageNamePath.replace(".", "/"); Enumeration<URL> resources; try { //定义一个枚举的集合 并进行循环来处理这个目录下的things resources = Thread.currentThread().getContextClassLoader().getResources(packageNamePath); //循环迭代 while (resources.hasMoreElements()){ URL url = resources.nextElement(); //得到协议的名称 String protocol = url.getProtocol(); //如果是以文件的形式保存在服务器上 if("file".equals(protocol)){ System.err.println("file类型的扫描"); String filePath = URLDecoder.decode(url.getFile(), "utf-8"); // 获取此包的目录 建立一个File File dir = new File(filePath); list.addAll(getClass(dir, packageName)); } } } catch (IOException e) { e.printStackTrace(); } return list; } /** * 迭代查找文件类 * @param filePath * @param packageName * @param <T> * @return */ private static <T> List<Class<T>> getClass(File filePath, String packageName){ List<Class<T>> classes = new ArrayList<>(); if(!filePath.exists()){ return classes; } // 如果存在 就获取包下的所有文件 包括目录 File[] files = filePath.listFiles(new FileFilter() { //自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件) @Override public boolean accept(File file) { return (recursive && file.isDirectory()) || file.getName().endsWith(".class"); } }); for (File file : files){ // 如果是目录 则继续扫描 if(file.isDirectory()){ classes.addAll(getClass(file, packageName + "." 
+ file.getName())); }else { // 如果是java类文件 去掉后面的.class 只留下类名 String fileName = file.getName(); String className = fileName.substring(0, fileName.length() - 6); className = packageName + "." + className; try { //这里用forName有一些不好,会触发static方法,没有使用classLoader的load干净 //Class<T> cls = (Class<T>) Class.forName(className); Class<T> cls = (Class<T>) Thread.currentThread().getContextClassLoader().loadClass(className); classes.add(cls); } catch (ClassNotFoundException e) { e.printStackTrace(); } } } return classes; } public static void main(String[] args) { getClass("top.icss"); } }
MyService.java MyAutowired.java (注解类)
package top.icss.ioc.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a type as a service bean. The annotation is retained at runtime so it can be
 * discovered reflectively; {@code IocContainer} only collects scanned classes that carry it.
 *
 * @author cd
 * @create 2020/3/26 14:57
 * @since 1.0.0
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface MyService {
    /**
     * Explicit bean name. When left blank, {@code IocContainer} derives the name from the
     * simple class name with its first letter lower-cased.
     */
    String value() default "";
}
package top.icss.ioc.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a field for automatic injection. The annotation is retained at runtime;
 * {@code IocContainer.doAutowired()} assigns a bean from the container to every
 * annotated field of the beans it manages.
 *
 * @author cd
 * @create 2020/3/26 14:59
 * @since 1.0.0
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface MyAutowired {
    /**
     * Explicit name of the bean to inject. It is consulted only when the field type is not
     * an interface (interface-typed fields are looked up by the interface's fully-qualified
     * name); when blank, the container derives a default name from the field type.
     */
    String value() default "";
}
IocContainer.java (IOC容器)
package top.icss.ioc; import top.icss.ioc.annotation.MyAutowired; import top.icss.ioc.annotation.MyService; import java.lang.reflect.AccessibleObject; import java.lang.reflect.Field; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; /** * @author cd * @desc ioc容器 * @create 2020/3/26 14:48 * @since 1.0.0 */ public class IocContainer { private Set<Class<?>> clss = new LinkedHashSet<Class<?>>(); private Map<String, Object> beans = new ConcurrentHashMap<String, Object>(); /** * 获取类上 @MyService的注解类 * * @param packageName */ public void doScanner(String packageName) { List<Class<Object>> list = ReflectUtils.getClass(packageName); for (Class cls : list) { boolean service = cls.isAnnotationPresent(MyService.class); if (service) { clss.add(cls); } } } /** * 将class中的类实例化,经key-value:类名(小写)-类对象放入ioc字段中 */ public void doInstance() { for (Class cls : clss) { if (cls.isAnnotationPresent(MyService.class)) { MyService myService = (MyService) cls.getAnnotation(MyService.class); String beanName = ""; if(cls.isInterface()){ beanName = cls.getName(); }else { beanName = ("".equals(myService.value().trim())) ? 
toLowerFirstWord(cls.getSimpleName()) : myService.value(); } try { Object instance = cls.newInstance(); beans.put(beanName, instance); Class[] interfaces = cls.getInterfaces(); for (Class<?> i:interfaces){ beans.put(i.getName(), instance); } } catch (Exception e) { e.printStackTrace(); } } } } /** * 自动化的依赖注入 */ public void doAutowired(){ if(beans.isEmpty()){ return; } try { Set<Map.Entry<String, Object>> entries = beans.entrySet(); for (Map.Entry<String, Object> entry: entries){ Class<?> cls = entry.getValue().getClass(); Field[] fields = cls.getDeclaredFields(); //强制获取私有字段 AccessibleObject.setAccessible(fields,true); for (Field f: fields){ if(!f.isAnnotationPresent(MyAutowired.class)){ continue; } MyAutowired myAutowired = f.getAnnotation(MyAutowired.class); String beanName = ""; Class icls = f.getType(); if(icls.isInterface()){ beanName = icls.getName(); }else { beanName = ("".equals(myAutowired.value().trim())) ? toLowerFirstWord(icls.getName()) : myAutowired.value(); } //获取当前类实例 Object obj = entry.getValue(); //容器中获取字段实例 Object value = beans.get(beanName); f.set(obj, value); } } } catch (Exception e) { e.printStackTrace(); } } /** * 获取实例 * @param cls * @param <T> * @return */ public <T> T getBean(Class<?> cls){ MyService myService = cls.getAnnotation(MyService.class); String beanName = ""; if(cls.isInterface()){ beanName = cls.getName(); }else { beanName = ("".equals(myService.value().trim())) ? toLowerFirstWord(cls.getSimpleName()) : myService.value(); } return (T) beans.get(beanName); } /** * 将字符串首字母转换为小写 * @param name * @return */ private String toLowerFirstWord(String name) { char[] charArray = name.toCharArray(); charArray[0] += 32; return String.valueOf(charArray); } }
测试
Service1.java Service2.java (接口类)
/**
 * Test service contract with a single {@code print1} operation.
 */
public interface Service1 {
    /** Operation implemented by the test service. */
    public void print1();
}

/**
 * Test service contract with a single {@code print2} operation.
 */
public interface Service2 {
    /** Operation implemented by the test service. */
    public void print2();
}
Service1Impl.java Service2Impl.java (实现类)
package top.icss.ioc.test.impl;

import top.icss.ioc.annotation.MyAutowired;
import top.icss.ioc.annotation.MyService;
import top.icss.ioc.test.Service1;
import top.icss.ioc.test.Service2;

/**
 * Test implementation of {@link Service1} that delegates to an injected {@link Service2}.
 *
 * @author cd
 * @create 2020/3/26 15:45
 * @since 1.0.0
 */
@MyService
public class Service1Impl implements Service1 {

    /** Populated by the container; stays {@code null} until injection has run. */
    @MyAutowired
    private Service2 service2;

    /** Forwards the call to {@link Service2#print2()}. */
    @Override
    public void print1() {
        service2.print2();
    }
}
package top.icss.ioc.test.impl;

import top.icss.ioc.annotation.MyAutowired;
import top.icss.ioc.annotation.MyService;
import top.icss.ioc.test.Service1;
import top.icss.ioc.test.Service2;

/**
 * Test implementation of {@link Service2} that prints a fixed marker line.
 *
 * @author cd
 * @create 2020/3/26 15:45
 * @since 1.0.0
 */
@MyService
public class Service2Impl implements Service2 {

    /** Injection target of type {@link Service1}; not read by {@link #print2()}. */
    @MyAutowired
    private Service1 service1;

    /** Writes {@code print2} to standard output. */
    @Override
    public void print2() {
        System.out.println("print2");
    }
}
IocTest.java (测试类)
/**
 * Manual smoke test for the container: scans the test package, builds and wires the
 * beans, then fetches one bean and calls it.
 */
public class IocTest {

    public static void main(String[] args) throws InterruptedException {
        // Bootstrap: scan -> instantiate -> inject.
        IocContainer container = new IocContainer();
        container.doScanner("top.icss.ioc.test");
        container.doInstance();
        container.doAutowired();

        // Look the bean up by its implementation class and exercise it.
        Service2 service = container.getBean(Service2Impl.class);
        service.print2();
    }
}