zoukankan      html  css  js  c++  java
  • 手写一个RPC框架

    一、前言

    前段时间看到一篇不错的文章《看了这篇你就会手写RPC框架了》,于是便来了兴趣对着实现了一遍,后面觉得还有很多优化的地方便对其进行了改进。

    主要改动点如下:

    1. 除了Java序列化协议,增加了protobuf和kryo序列化协议,配置即用。
    2. 增加多种负载均衡算法(随机、轮询、加权轮询、平滑加权轮询),配置即用。
    3. 客户端增加本地服务列表缓存,提高性能。
    4. 修复高并发情况下,netty导致的内存泄漏问题
    5. 由原来的每个请求建立一次连接,改为建立TCP长连接,并多次复用。
    6. 服务端增加线程池提高消息处理能力

    二、介绍

    RPC,即 Remote Procedure Call(远程过程调用),调用远程计算机上的服务,就像调用本地服务一样。RPC可以很好的解耦系统,如WebService就是一种基于Http协议的RPC。

    调用示意图
    调用示意图

    总的来说,就如下几个步骤:

    1. 客户端(ServerA)执行远程方法时就调用client stub传递类名、方法名和参数等信息。
    2. client stub会将参数等信息序列化为二进制流的形式,然后通过Sockect发送给服务端(ServerB)
    3. 服务端收到数据包后,server stub 需要进行解析反序列化为类名、方法名和参数等信息。
    4. server stub调用对应的本地方法,并把执行结果返回给客户端

    所以一个RPC框架有如下角色:

    服务消费者

    远程方法的调用方,即客户端。一个服务既可以是消费者也可以是提供者。

    服务提供者

    远程服务的提供方,即服务端。一个服务既可以是消费者也可以是提供者。

    注册中心

    保存服务提供者的服务地址等信息,一般由zookeeper、redis等实现。

    监控运维(可选)

    监控接口的响应时间、统计请求数量等,及时发现系统问题并发出告警通知。

    三、实现

    本RPC框架rpc-spring-boot-starter涉及技术栈如下:

    • 使用zookeeper作为注册中心
    • 使用netty作为通信框架
    • 消息编解码:protostuff、kryo、java
    • spring
    • 使用SPI来根据配置动态选择负载均衡算法等

    由于代码过多,这里只讲几处改动点。

    3.1动态负载均衡算法

    1.编写LoadBalance的实现类

    负载均衡算法实现类
    负载均衡算法实现类

    2.自定义注解 @LoadBalanceAno

    1. /** 
    2. * 负载均衡注解 
    3. */ 
    4. @Target(ElementType.TYPE) 
    5. @Retention(RetentionPolicy.RUNTIME) 
    6. @Documented 
    7. public @interface LoadBalanceAno { 
    8.  
    9. String value() default ""; 
    10. } 
    11.  
    12. /** 
    13. * 轮询算法 
    14. */ 
    15. @LoadBalanceAno(RpcConstant.BALANCE_ROUND) 
    16. public class FullRoundBalance implements LoadBalance { 
    17.  
    18. private static Logger logger = LoggerFactory.getLogger(FullRoundBalance.class); 
    19.  
    20. private volatile int index; 
    21.  
    22. @Override 
    23. public synchronized Service chooseOne(List<Service> services) { 
    24. // 加锁防止多线程情况下,index超出services.size() 
    25. if (index == services.size()) { 
    26. index = 0; 
    27. } 
    28. return services.get(index++); 
    29. } 
    30. } 

    3.新建在resource目录下META-INF/servers文件夹并创建文件

    enter description here
    enter description here

    4.RpcConfig增加配置项loadBalance

    1. /** 
    2. * @author 2YSP 
    3. * @date 2020/7/26 15:13 
    4. */ 
    5. @ConfigurationProperties(prefix = "sp.rpc") 
    6. public class RpcConfig { 
    7.  
    8. /** 
    9. * 服务注册中心地址 
    10. */ 
    11. private String registerAddress = "127.0.0.1:2181"; 
    12.  
    13. /** 
    14. * 服务暴露端口 
    15. */ 
    16. private Integer serverPort = 9999; 
    17. /** 
    18. * 服务协议 
    19. */ 
    20. private String protocol = "java"; 
    21. /** 
    22. * 负载均衡算法 
    23. */ 
    24. private String loadBalance = "random"; 
    25. /** 
    26. * 权重,默认为1 
    27. */ 
    28. private Integer weight = 1; 
    29.  
    30. // 省略getter setter 
    31. } 

    5.在自动配置类RpcAutoConfiguration根据配置选择对应的算法实现类

    1. /** 
    2. * 使用spi匹配符合配置的负载均衡算法 
    3. * 
    4. * @param name 
    5. * @return 
    6. */ 
    7. private LoadBalance getLoadBalance(String name) { 
    8. ServiceLoader<LoadBalance> loader = ServiceLoader.load(LoadBalance.class); 
    9. Iterator<LoadBalance> iterator = loader.iterator(); 
    10. while (iterator.hasNext()) { 
    11. LoadBalance loadBalance = iterator.next(); 
    12. LoadBalanceAno ano = loadBalance.getClass().getAnnotation(LoadBalanceAno.class); 
    13. Assert.notNull(ano, "load balance name can not be empty!"); 
    14. if (name.equals(ano.value())) { 
    15. return loadBalance; 
    16. } 
    17. } 
    18. throw new RpcException("invalid load balance config"); 
    19. } 
    20.  
    21. @Bean 
    22. public ClientProxyFactory proxyFactory(@Autowired RpcConfig rpcConfig) { 
    23. ClientProxyFactory clientProxyFactory = new ClientProxyFactory(); 
    24. // 设置服务发现着 
    25. clientProxyFactory.setServerDiscovery(new ZookeeperServerDiscovery(rpcConfig.getRegisterAddress())); 
    26.  
    27. // 设置支持的协议 
    28. Map<String, MessageProtocol> supportMessageProtocols = buildSupportMessageProtocols(); 
    29. clientProxyFactory.setSupportMessageProtocols(supportMessageProtocols); 
    30. // 设置负载均衡算法 
    31. LoadBalance loadBalance = getLoadBalance(rpcConfig.getLoadBalance()); 
    32. clientProxyFactory.setLoadBalance(loadBalance); 
    33. // 设置网络层实现 
    34. clientProxyFactory.setNetClient(new NettyNetClient()); 
    35.  
    36. return clientProxyFactory; 
    37. } 

    3.2本地服务列表缓存

    使用Map来缓存数据

    1. /** 
    2. * 服务发现本地缓存 
    3. */ 
    4. public class ServerDiscoveryCache { 
    5. /** 
    6. * key: serviceName 
    7. */ 
    8. private static final Map<String, List<Service>> SERVER_MAP = new ConcurrentHashMap<>(); 
    9. /** 
    10. * 客户端注入的远程服务service class 
    11. */ 
    12. public static final List<String> SERVICE_CLASS_NAMES = new ArrayList<>(); 
    13.  
    14. public static void put(String serviceName, List<Service> serviceList) { 
    15. SERVER_MAP.put(serviceName, serviceList); 
    16. } 
    17.  
    18. /** 
    19. * 去除指定的值 
    20. * @param serviceName 
    21. * @param service 
    22. */ 
    23. public static void remove(String serviceName, Service service) { 
    24. SERVER_MAP.computeIfPresent(serviceName, (key, value) -> 
    25. value.stream().filter(o -> !o.toString().equals(service.toString())).collect(Collectors.toList()) 
    26. ); 
    27. } 
    28.  
    29. public static void removeAll(String serviceName) { 
    30. SERVER_MAP.remove(serviceName); 
    31. } 
    32.  
    33.  
    34. public static boolean isEmpty(String serviceName) { 
    35. return SERVER_MAP.get(serviceName) == null || SERVER_MAP.get(serviceName).size() == 0; 
    36. } 
    37.  
    38. public static List<Service> get(String serviceName) { 
    39. return SERVER_MAP.get(serviceName); 
    40. } 
    41. } 

    ClientProxyFactory,先查本地缓存,缓存没有再查询zookeeper。

    1. /** 
    2. * 根据服务名获取可用的服务地址列表 
    3. * @param serviceName 
    4. * @return 
    5. */ 
    6. private List<Service> getServiceList(String serviceName) { 
    7. List<Service> services; 
    8. synchronized (serviceName){ 
    9. if (ServerDiscoveryCache.isEmpty(serviceName)) { 
    10. services = serverDiscovery.findServiceList(serviceName); 
    11. if (services == null || services.size() == 0) { 
    12. throw new RpcException("No provider available!"); 
    13. } 
    14. ServerDiscoveryCache.put(serviceName, services); 
    15. } else { 
    16. services = ServerDiscoveryCache.get(serviceName); 
    17. } 
    18. } 
    19. return services; 
    20. } 

    问题: 如果服务端因为宕机或网络问题下线了,缓存却还在就会导致客户端请求已经不可用的服务端,增加请求失败率。
    解决方案:由于服务端注册的是临时节点,所以如果服务端下线节点会被移除。只要监听zookeeper的子节点,如果新增或删除子节点就直接清空本地缓存即可。
    DefaultRpcProcessor

    1. /** 
    2. * Rpc处理者,支持服务启动暴露,自动注入Service 
    3. * @author 2YSP 
    4. * @date 2020/7/26 14:46 
    5. */ 
    6. public class DefaultRpcProcessor implements ApplicationListener<ContextRefreshedEvent> { 
    7.  
    8.  
    9.  
    10. @Override 
    11. public void onApplicationEvent(ContextRefreshedEvent event) { 
    12. // Spring启动完毕过后会收到一个事件通知 
    13. if (Objects.isNull(event.getApplicationContext().getParent())){ 
    14. ApplicationContext context = event.getApplicationContext(); 
    15. // 开启服务 
    16. startServer(context); 
    17. // 注入Service 
    18. injectService(context); 
    19. } 
    20. } 
    21.  
    22. private void injectService(ApplicationContext context) { 
    23. String[] names = context.getBeanDefinitionNames(); 
    24. for(String name : names){ 
    25. Class<?> clazz = context.getType(name); 
    26. if (Objects.isNull(clazz)){ 
    27. continue; 
    28. } 
    29.  
    30. Field[] declaredFields = clazz.getDeclaredFields(); 
    31. for(Field field : declaredFields){ 
    32. // 找出标记了InjectService注解的属性 
    33. InjectService injectService = field.getAnnotation(InjectService.class); 
    34. if (injectService == null){ 
    35. continue; 
    36. } 
    37.  
    38. Class<?> fieldClass = field.getType(); 
    39. Object object = context.getBean(name); 
    40. field.setAccessible(true); 
    41. try { 
    42. field.set(object,clientProxyFactory.getProxy(fieldClass)); 
    43. } catch (IllegalAccessException e) { 
    44. e.printStackTrace(); 
    45. } 
    46. // 添加本地服务缓存 
    47. ServerDiscoveryCache.SERVICE_CLASS_NAMES.add(fieldClass.getName()); 
    48. } 
    49. } 
    50. // 注册子节点监听 
    51. if (clientProxyFactory.getServerDiscovery() instanceof ZookeeperServerDiscovery){ 
    52. ZookeeperServerDiscovery serverDiscovery = (ZookeeperServerDiscovery) clientProxyFactory.getServerDiscovery(); 
    53. ZkClient zkClient = serverDiscovery.getZkClient(); 
    54. ServerDiscoveryCache.SERVICE_CLASS_NAMES.forEach(name ->{ 
    55. String servicePath = RpcConstant.ZK_SERVICE_PATH + RpcConstant.PATH_DELIMITER + name + "/service"; 
    56. zkClient.subscribeChildChanges(servicePath, new ZkChildListenerImpl()); 
    57. }); 
    58. logger.info("subscribe service zk node successfully"); 
    59. } 
    60.  
    61. } 
    62.  
    63. private void startServer(ApplicationContext context) { 
    64. ... 
    65.  
    66. } 
    67. } 
    68.  

    ZkChildListenerImpl

    1. /** 
    2. * 子节点事件监听处理类 
    3. */ 
    4. public class ZkChildListenerImpl implements IZkChildListener { 
    5.  
    6. private static Logger logger = LoggerFactory.getLogger(ZkChildListenerImpl.class); 
    7.  
    8. /** 
    9. * 监听子节点的删除和新增事件 
    10. * @param parentPath /rpc/serviceName/service 
    11. * @param childList 
    12. * @throws Exception 
    13. */ 
    14. @Override 
    15. public void handleChildChange(String parentPath, List<String> childList) throws Exception { 
    16. logger.debug("Child change parentPath:[{}] -- childList:[{}]", parentPath, childList); 
    17. // 只要子节点有改动就清空缓存 
    18. String[] arr = parentPath.split("/"); 
    19. ServerDiscoveryCache.removeAll(arr[2]); 
    20. } 
    21. } 

    3.3nettyClient支持TCP长连接

    这部分的改动最多,先增加新的sendRequest接口。

    添加接口
    添加接口

    实现类NettyNetClient

    1. /** 
    2. * @author 2YSP 
    3. * @date 2020/7/25 20:12 
    4. */ 
    5. public class NettyNetClient implements NetClient { 
    6.  
    7. private static Logger logger = LoggerFactory.getLogger(NettyNetClient.class); 
    8.  
    9. private static ExecutorService threadPool = new ThreadPoolExecutor(4, 10, 200, 
    10. TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000), new ThreadFactoryBuilder() 
    11. .setNameFormat("rpcClient-%d") 
    12. .build()); 
    13.  
    14. private EventLoopGroup loopGroup = new NioEventLoopGroup(4); 
    15.  
    16. /** 
    17. * 已连接的服务缓存 
    18. * key: 服务地址,格式:ip:port 
    19. */ 
    20. public static Map<String, SendHandlerV2> connectedServerNodes = new ConcurrentHashMap<>(); 
    21.  
    22. @Override 
    23. public byte[] sendRequest(byte[] data, Service service) throws InterruptedException { 
    24. .... 
    25. return respData; 
    26. } 
    27.  
    28. @Override 
    29. public RpcResponse sendRequest(RpcRequest rpcRequest, Service service, MessageProtocol messageProtocol) { 
    30.  
    31. String address = service.getAddress(); 
    32. synchronized (address) { 
    33. if (connectedServerNodes.containsKey(address)) { 
    34. SendHandlerV2 handler = connectedServerNodes.get(address); 
    35. logger.info("使用现有的连接"); 
    36. return handler.sendRequest(rpcRequest); 
    37. } 
    38.  
    39. String[] addrInfo = address.split(":"); 
    40. final String serverAddress = addrInfo[0]; 
    41. final String serverPort = addrInfo[1]; 
    42. final SendHandlerV2 handler = new SendHandlerV2(messageProtocol, address); 
    43. threadPool.submit(() -> { 
    44. // 配置客户端 
    45. Bootstrap b = new Bootstrap(); 
    46. b.group(loopGroup).channel(NioSocketChannel.class) 
    47. .option(ChannelOption.TCP_NODELAY, true) 
    48. .handler(new ChannelInitializer<SocketChannel>() { 
    49. @Override 
    50. protected void initChannel(SocketChannel socketChannel) throws Exception { 
    51. ChannelPipeline pipeline = socketChannel.pipeline(); 
    52. pipeline 
    53. .addLast(handler); 
    54. } 
    55. }); 
    56. // 启用客户端连接 
    57. ChannelFuture channelFuture = b.connect(serverAddress, Integer.parseInt(serverPort)); 
    58. channelFuture.addListener(new ChannelFutureListener() { 
    59. @Override 
    60. public void operationComplete(ChannelFuture channelFuture) throws Exception { 
    61. connectedServerNodes.put(address, handler); 
    62. } 
    63. }); 
    64. } 
    65. ); 
    66. logger.info("使用新的连接。。。"); 
    67. return handler.sendRequest(rpcRequest); 
    68. } 
    69. } 
    70. } 
    71.  

    每次请求都会调用sendRequest()方法,用线程池异步和服务端创建TCP长连接,连接成功后将SendHandlerV2缓存到ConcurrentHashMap中方便复用,后续请求的请求地址(ip+port)如果在connectedServerNodes中存在则使用connectedServerNodes中的handler处理不再重新建立连接。

    SendHandlerV2

    1. /** 
    2. * @author 2YSP 
    3. * @date 2020/8/19 20:06 
    4. */ 
    5. public class SendHandlerV2 extends ChannelInboundHandlerAdapter { 
    6.  
    7. private static Logger logger = LoggerFactory.getLogger(SendHandlerV2.class); 
    8.  
    9. /** 
    10. * 等待通道建立最大时间 
    11. */ 
    12. static final int CHANNEL_WAIT_TIME = 4; 
    13. /** 
    14. * 等待响应最大时间 
    15. */ 
    16. static final int RESPONSE_WAIT_TIME = 8; 
    17.  
    18. private volatile Channel channel; 
    19.  
    20. private String remoteAddress; 
    21.  
    22. private static Map<String, RpcFuture<RpcResponse>> requestMap = new ConcurrentHashMap<>(); 
    23.  
    24. private MessageProtocol messageProtocol; 
    25.  
    26. private CountDownLatch latch = new CountDownLatch(1); 
    27.  
    28. public SendHandlerV2(MessageProtocol messageProtocol,String remoteAddress) { 
    29. this.messageProtocol = messageProtocol; 
    30. this.remoteAddress = remoteAddress; 
    31. } 
    32.  
    33. @Override 
    34. public void channelRegistered(ChannelHandlerContext ctx) throws Exception { 
    35. this.channel = ctx.channel(); 
    36. latch.countDown(); 
    37. } 
    38.  
    39. @Override 
    40. public void channelActive(ChannelHandlerContext ctx) throws Exception { 
    41. logger.debug("Connect to server successfully:{}", ctx); 
    42. } 
    43.  
    44. @Override 
    45. public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { 
    46. logger.debug("Client reads message:{}", msg); 
    47. ByteBuf byteBuf = (ByteBuf) msg; 
    48. byte[] resp = new byte[byteBuf.readableBytes()]; 
    49. byteBuf.readBytes(resp); 
    50. // 手动回收 
    51. ReferenceCountUtil.release(byteBuf); 
    52. RpcResponse response = messageProtocol.unmarshallingResponse(resp); 
    53. RpcFuture<RpcResponse> future = requestMap.get(response.getRequestId()); 
    54. future.setResponse(response); 
    55. } 
    56.  
    57. @Override 
    58. public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { 
    59. cause.printStackTrace(); 
    60. logger.error("Exception occurred:{}", cause.getMessage()); 
    61. ctx.close(); 
    62. } 
    63.  
    64. @Override 
    65. public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { 
    66. ctx.flush(); 
    67. } 
    68.  
    69. @Override 
    70. public void channelInactive(ChannelHandlerContext ctx) throws Exception { 
    71. super.channelInactive(ctx); 
    72. logger.error("channel inactive with remoteAddress:[{}]",remoteAddress); 
    73. NettyNetClient.connectedServerNodes.remove(remoteAddress); 
    74.  
    75. } 
    76.  
    77. @Override 
    78. public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { 
    79. super.userEventTriggered(ctx, evt); 
    80. } 
    81.  
    82. public RpcResponse sendRequest(RpcRequest request) { 
    83. RpcResponse response; 
    84. RpcFuture<RpcResponse> future = new RpcFuture<>(); 
    85. requestMap.put(request.getRequestId(), future); 
    86. try { 
    87. byte[] data = messageProtocol.marshallingRequest(request); 
    88. ByteBuf reqBuf = Unpooled.buffer(data.length); 
    89. reqBuf.writeBytes(data); 
    90. if (latch.await(CHANNEL_WAIT_TIME,TimeUnit.SECONDS)){ 
    91. channel.writeAndFlush(reqBuf); 
    92. // 等待响应 
    93. response = future.get(RESPONSE_WAIT_TIME, TimeUnit.SECONDS); 
    94. }else { 
    95. throw new RpcException("establish channel time out"); 
    96. } 
    97. } catch (Exception e) { 
    98. throw new RpcException(e.getMessage()); 
    99. } finally { 
    100. requestMap.remove(request.getRequestId()); 
    101. } 
    102. return response; 
    103. } 
    104. } 
    105.  

    RpcFuture

    1. package cn.sp.rpc.client.net; 
    2.  
    3. import java.util.concurrent.*; 
    4.  
    5. /** 
    6. * @author 2YSP 
    7. * @date 2020/8/19 22:31 
    8. */ 
    9. public class RpcFuture<T> implements Future<T> { 
    10.  
    11. private T response; 
    12. /** 
    13. * 因为请求和响应是一一对应的,所以这里是1 
    14. */ 
    15. private CountDownLatch countDownLatch = new CountDownLatch(1); 
    16. /** 
    17. * Future的请求时间,用于计算Future是否超时 
    18. */ 
    19. private long beginTime = System.currentTimeMillis(); 
    20.  
    21. @Override 
    22. public boolean cancel(boolean mayInterruptIfRunning) { 
    23. return false; 
    24. } 
    25.  
    26. @Override 
    27. public boolean isCancelled() { 
    28. return false; 
    29. } 
    30.  
    31. @Override 
    32. public boolean isDone() { 
    33. if (response != null) { 
    34. return true; 
    35. } 
    36. return false; 
    37. } 
    38.  
    39. /** 
    40. * 获取响应,直到有结果才返回 
    41. * @return 
    42. * @throws InterruptedException 
    43. * @throws ExecutionException 
    44. */ 
    45. @Override 
    46. public T get() throws InterruptedException, ExecutionException { 
    47. countDownLatch.await(); 
    48. return response; 
    49. } 
    50.  
    51. @Override 
    52. public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { 
    53. if (countDownLatch.await(timeout,unit)){ 
    54. return response; 
    55. } 
    56. return null; 
    57. } 
    58.  
    59. public void setResponse(T response) { 
    60. this.response = response; 
    61. countDownLatch.countDown(); 
    62. } 
    63.  
    64. public long getBeginTime() { 
    65. return beginTime; 
    66. } 
    67. } 
    68.  

    此处逻辑,第一次执行 SendHandlerV2#sendRequest() 时channel需要等待通道建立好之后才能发送请求,所以用CountDownLatch来控制,等待通道建立。
    自定义Future+requestMap缓存来实现netty的请求和阻塞等待响应,RpcRequest对象在创建时会生成一个请求的唯一标识requestId,发送请求前先将RpcFuture缓存到requestMap中,key为requestId,读取到服务端的响应信息后(channelRead方法),将响应结果放入对应的RpcFuture中。
    SendHandlerV2#channelInactive() 方法中,如果连接的服务端异常断开连接了,则及时清理缓存中对应的serverNode。

    四、压力测试

    测试环境:
    (英特尔)Intel(R) Core(TM) i5-6300HQ CPU @ 2.30GHz
    4核
    windows10家庭版(64位)
    16G内存

    1.本地启动zookeeper
    2.本地启动一个消费者,两个服务端,轮询算法
    3.使用ab进行压力测试,4个线程发送10000个请求

    ab -c 4 -n 10000 http://localhost:8080/test/user?id=1

    测试结果

    测试结果
    测试结果

    从图片可以看出,10000个请求只用了11s,比之前的130+秒耗时减少了10倍以上。

    代码地址:
    https://github.com/2YSP/rpc-spring-boot-starter
    https://github.com/2YSP/rpc-example

    参考:
    看了这篇你就会手写RPC框架了

  • 相关阅读:
    cocos2d-x CSV文件读取 (Excel生成csv文件)
    cocos2d-x 中 xml 文件读取
    String 类的实现
    json 文件解析与应用
    设计模式 之 《简单工厂模式》
    C++ 0X 新特性实例(比较常用的) (转)
    CCSpriteBatchNode CCSpriteFrameCache
    LongAdder
    ConcurrentHashMap源码
    HashMap源码
  • 原文地址:https://www.cnblogs.com/2YSP/p/13545217.html
Copyright © 2011-2022 走看看