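Java 自身也提供了 SPI 实现(`java.util.ServiceLoader`),但它有不少小毛病,比如:无法按名称获取某一个扩展实现,只能遍历加载并实例化所有扩展实现;不支持表达一些复杂的元数据;此外,ServiceLoader 实例不是线程安全的,官方文档明确说明它不能被多个线程并发使用(另外还听说多个类加载器同时加载会有并发问题,这一点没有考证过)。所以很多框架都提供了自己的 SPI 机制供使用者扩展,例如 Dubbo,它的 SPI 还可以实现按需加载扩展点。之前看过 Dubbo 的 SPI 实现,由于它的整个核心功能都是围绕 SPI 构建的,所以显得很复杂。接下来看一个轻量级的 SPI 实现——来源于公司的一个生产级框架。不多说,直接上代码: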
两个注解的定义:
/**
 * Marks an interface as an extension point.
 *
 * <p>The extension loader reads {@link #scope()} to decide whether a new extension instance is
 * created on every lookup ({@code PROTOTYPE}) or a single cached instance is reused per extension
 * name ({@code SINGLETON}).
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@Documented
public @interface Spi {

    /** Instantiation scope of the extensions of this extension point; {@code PROTOTYPE} by default. */
    Scope scope() default Scope.PROTOTYPE;
}
/**
 * Metadata attached to an extension implementation class.
 *
 * <p>{@link #name()} is the name the implementation is registered under. When it is left empty,
 * the extension loader falls back to the class's simple name.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@Documented
public @interface SpiMeta {

    /** Extension name; an empty string means "use the simple class name". */
    String name() default "";
}
核心的扩展点加载器:
public class ExtensionLoader<T> { private static final Logger logger = LoggerFactory.getLogger(ExtensionLoader.class); private ConcurrentMap<String, T> singletonInstances = null; //存放所有扩展点类的集合,key是扩展名,每一个扩展文件中可以定义多个扩展类 private ConcurrentMap<String, Class<T>> extensionClasses = null; //加载该路径下所有扩展类 private static final String SPI_LOCATION = "META-INF/services/"; private ClassLoader classLoader; private Class<T> type; private volatile boolean init = false; //存储扩展点对象和其对应的扩展加载器 private static final Map<Class<?>, ExtensionLoader<?>> extensionLoaders = new ConcurrentHashMap<>(); //通过构造方法设置类加载器,使用Thread.currentThread().getContextClassLoader(),如果没有setContextClassLoader()则为系统类加载器AppClassLoader private ExtensionLoader(Class<T> type) { this(type, Thread.currentThread().getContextClassLoader()); } private ExtensionLoader(Class<T> type, ClassLoader classLoader) { this.type = type; this.classLoader = classLoader; } @SuppressWarnings("unchecked") public static <T> ExtensionLoader<T> getExtensionLoader(Class<T> type) { //根据扩展点类型获取对应的扩展加载器 ExtensionLoader<T> loader = (ExtensionLoader<T>) extensionLoaders.get(type); if (loader == null) { //如果没有该扩展加载器,则进行初始化 loader = initExtensionLoader(type); } return loader; } //初始化方法是静态的且加了锁,相当于在类对象上加锁,可以保证同时只有一个扩展器的初始化操作 @SuppressWarnings("unchecked") private static synchronized <T> ExtensionLoader<T> initExtensionLoader(Class<T> type) { ExtensionLoader<T> loader = (ExtensionLoader<T>) extensionLoaders.get(type); if (loader == null) { loader = new ExtensionLoader<>(type); extensionLoaders.putIfAbsent(type, loader); loader = (ExtensionLoader<T>) extensionLoaders.get(type); } return loader; } @SuppressWarnings("unchecked") public List<T> getExtensions() { checkAndInit(); List<T> extensions = new ArrayList<>(extensionClasses.size()); for (Map.Entry<String, Class<T>> entry : extensionClasses.entrySet()) { extensions.add(getExtension(entry.getKey())); } extensions.sort(new ExtensionOrderComparator<T>()); return extensions; } //根据扩展名获取扩展点对象 public T 
getExtension(String name) { checkAndInit(); if (name == null) { return null; } try { //注意,@Spi是使用在扩展点接口上的,@SpiMeta是使用在实现类上的 Spi spi = type.getAnnotation(Spi.class); //单例类型 if (spi.scope() == Scope.SINGLETON) { return getSingletonInstance(name); } else { //原型类型 Class<T> clz = extensionClasses.get(name); if (clz == null) { return null; } return clz.newInstance(); } } catch (Exception e) { new RuntimeException(type.getName() + ":Error when getExtension " + name, e); } return null; } @SuppressWarnings("unchecked") private T getSingletonInstance(String name) throws InstantiationException, IllegalAccessException { T obj = singletonInstances.get(name); if (obj != null) { return obj; } Class<T> clz = extensionClasses.get(name); if (clz == null) { return null; } //加锁对象为集合对象,确保只有一个线程能创建扩展点对象 synchronized (singletonInstances) { obj = singletonInstances.get(name); if (obj != null) { return obj; } obj = clz.newInstance(); singletonInstances.put(name, obj); } return obj; } private void checkAndInit() { //init被volatile修饰,确保只有一个线程进行初始化 if (!init) { loadExtensionClasses(); } } //这里在方法级别加锁,锁对象是扩展点对应的扩展类加载器对象 private synchronized void loadExtensionClasses() { if (init) { return; } //将META-INF/services/目录下的扩展点加载进集合中 extensionClasses = loadExtensionClasses(SPI_LOCATION); singletonInstances = new ConcurrentHashMap<>(); init = true; } private ConcurrentMap<String, Class<T>> loadExtensionClasses(String prefix) { //根据前缀和类的全限定名来读取文件,所以这里注意扩展点的文件名称必须是类的全限定名 String fullName = prefix + type.getName(); List<String> classNames = new ArrayList<String>(); try { Enumeration<URL> urls; if (classLoader == null) { urls = ClassLoader.getSystemResources(fullName); } else { urls = classLoader.getResources(fullName); } if (urls == null || !urls.hasMoreElements()) { return new ConcurrentHashMap<>(); } while (urls.hasMoreElements()) { URL url = urls.nextElement(); //解析类 parseUrl(type, url, classNames); } } catch (Exception e) { throw new RuntimeException("ExtensionLoader loadExtensionClasses error, prefix: 
" + prefix + " type: " + type.getClass(), e); } //将类加载进内存,并放入集合中 return loadClass(classNames); } @SuppressWarnings("unchecked") private ConcurrentMap<String, Class<T>> loadClass(List<String> classNames) { ConcurrentMap<String, Class<T>> map = new ConcurrentHashMap<String, Class<T>>(); for (String className : classNames) { try { Class<T> clz; if (classLoader == null) { //classLoader为空,使用加载当前类的类加载器进行加载 clz = (Class<T>) Class.forName(className); } else { clz = (Class<T>) Class.forName(className, true, classLoader); } checkExtensionType(clz); String spiName = getSpiName(clz); if (map.containsKey(spiName)) { new RuntimeException(clz.getName() + ":Error spiName already exist " + spiName); } else { map.put(spiName, clz); } } catch (Exception e) { logger.error(type.getName() + ":" + "Error load spi class", e); } } return map; } private void checkExtensionType(Class<T> clz) { checkClassPublic(clz); checkConstructorPublic(clz); checkClassInherit(clz); } private void checkClassPublic(Class<T> clz) { if (!Modifier.isPublic(clz.getModifiers())) { new RuntimeException(clz.getName() + ":Error is not a public class"); } } private void checkClassInherit(Class<T> clz) { if (!type.isAssignableFrom(clz)) { new RuntimeException(clz.getName() + ":Error is not instanceof " + type.getName()); } } private void checkConstructorPublic(Class<T> clz) { Constructor<?>[] constructors = clz.getConstructors(); if (constructors == null || constructors.length == 0) { new RuntimeException(clz.getName() + ":Error has no public no-args constructor"); } for (Constructor<?> constructor : constructors) { if (Modifier.isPublic(constructor.getModifiers()) && constructor.getParameterTypes().length == 0) { return; } } new RuntimeException(clz.getName() + ":Error has no public no-args constructor"); } public String getSpiName(Class<?> clz) { SpiMeta spiMeta = clz.getAnnotation(SpiMeta.class); //如果SpiMeta中没有定义name属性,则使用类型,如@SpiMeta(name = "coreSamplePrinter") return (spiMeta != null && 
!"".equals(spiMeta.name())) ? spiMeta.name() : clz.getSimpleName(); } private void parseUrl(Class<T> type, URL url, List<String> classNames) throws ServiceConfigurationError { InputStream inputStream = null; BufferedReader reader = null; try { inputStream = url.openStream(); reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)); String line; int indexNumber = 0; while ((line = reader.readLine()) != null) { indexNumber++; parseLine(type, url, line, indexNumber, classNames); } } catch (Exception x) { logger.error(type.getName() + ":" + "Error reading spi configuration file", x); } finally { try { if (reader != null) { reader.close(); } if (inputStream != null) { inputStream.close(); } } catch (IOException y) { logger.error(type.getName() + ":" + "Error closing spi configuration file", y); } } } private void parseLine(Class<T> type, URL url, String line, int lineNumber, List<String> names) throws IOException, ServiceConfigurationError { int ci = line.indexOf('#'); //可以使用#在扩展文件后写一些说明 if (ci >= 0) { line = line.substring(0, ci); } line = line.trim(); if (line.length() <= 0) { return; } if ((line.indexOf(' ') >= 0) || (line.indexOf(' ') >= 0)) { throw new RuntimeException(type.getName() + ": " + "Illegal spi configuration-file syntax"); } int cp = line.codePointAt(0); if (!Character.isJavaIdentifierStart(cp)) { throw new RuntimeException(type.getName() + ": " + url + ": " + line + ": " + "Illegal spi provider-class name: " + line); } for (int i = Character.charCount(cp); i < line.length(); i += Character.charCount(cp)) { cp = line.codePointAt(i); if (!Character.isJavaIdentifierPart(cp) && (cp != '.')) { throw new RuntimeException(type.getName() + ": " + url + ": " + line + ": " + "Illegal spi provider-class name: " + line); } } if (!names.contains(line)) { names.add(line); } } }
主要的功能点和细节都有注释,就不多说了。接下来看一些应用吧。
注解的使用:
/**
 * Extension point for reporting health statistics.
 *
 * <p>Declared with {@code Scope.SINGLETON}, so the extension loader reuses one cached instance per
 * extension name.
 */
@Spi(scope = Scope.SINGLETON)
public interface HealthPrinter {

    /**
     * Outputs the given health statistics; where and how is up to each implementation.
     *
     * @param healthStats the statistics to report
     * @param timestamp   timestamp associated with this report (format not visible here; TODO confirm)
     */
    void print(Set<HealthStats> healthStats, String timestamp);
}
@SpiMeta(name = "jedisClusterHealthPrinter") public class JedisClusterHealthPrinter extends AbstractHealthPrinter {
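扩展文件的使用:扩展文件放在 classpath 的 `META-INF/services/` 目录下,文件名必须是扩展点接口的全限定名。文件中每行写一个实现类的全限定名,`#` 之后的内容会被当作注释忽略。例如,在 `META-INF/services/<HealthPrinter 的全限定名>` 文件中写入一行 `<JedisClusterHealthPrinter 的全限定名>`,即可通过 `ExtensionLoader.getExtensionLoader(HealthPrinter.class).getExtension("jedisClusterHealthPrinter")` 获取该扩展。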
总结:SPI只是一种思想,可以根据实际需要定制化实现。