使用Netty搭建WebSocket服务器
1. WebSocketServer.java
public class WebSocketServer { private final ChannelGroup group = new DefaultChannelGroup(ImmediateEventExecutor.INSTANCE); private final EventLoopGroup workerGroup = new NioEventLoopGroup(); private Channel channel; public ChannelFuture start(InetSocketAddress address) { ServerBootstrap boot = new ServerBootstrap(); boot.group(workerGroup).channel(NioServerSocketChannel.class).childHandler(createInitializer(group)); ChannelFuture f = boot.bind(address).syncUninterruptibly(); channel = f.channel(); return f; } protected ChannelHandler createInitializer(ChannelGroup group2) { return new ChatServerInitializer(group2); } public void destroy() { if (channel != null) channel.close(); group.close(); workerGroup.shutdownGracefully(); } public static void main(String[] args) { final WebSocketServer server = new WebSocketServer(); ChannelFuture f = server.start(new InetSocketAddress(2048)); System.out.println("server start................"); Runtime.getRuntime().addShutdownHook(new Thread() { @Override public void run() { server.destroy(); } }); f.channel().closeFuture().syncUninterruptibly(); } private static WebSocketServer instance; private WebSocketServer() {} public static synchronized WebSocketServer getInstance() {// 懒汉,线程安全 if (instance == null) { instance = new WebSocketServer(); } return instance; } public void running(){ if(instance != null){ String port=null; port=BusinessConfigUtils.findProperty("websocket_port");//获取端口号 if(null==port||port.length()<0||!StringUtils.isNumeric(port)){ port="18080"; } instance.start(new InetSocketAddress(Integer.valueOf(port))); //ChannelFuture f = System.out.println("----------------------------------------WEBSOCKET SERVER START----------------------------------------"); /*Runtime.getRuntime().addShutdownHook(new Thread() { @Override public void run() { instance.destroy(); } }); f.channel().closeFuture().syncUninterruptibly();*/ } } }
2. ChatServerInitializer.java
/**
 * Installs the handler chain on each new channel: HTTP codec, chunked-write support,
 * HTTP message aggregation, the project's {@code HttpRequestHandler}, the WebSocket
 * protocol handler, and finally the text-frame handler that uses the shared channel group.
 */
public class ChatServerInitializer extends ChannelInitializer<Channel> {

    /** Path passed to both the HTTP request handler and the WebSocket protocol handler. */
    private static final String WEBSOCKET_PATH = "/ws";

    /** Size argument given to the HTTP object aggregator (64 KiB). */
    private static final int AGGREGATOR_LIMIT = 64 * 1024;

    private final ChannelGroup group;

    /**
     * @param group channel group handed to every {@code TextWebSocketFrameHandler} created here
     */
    public ChatServerInitializer(ChannelGroup group) {
        this.group = group;
    }

    /**
     * Appends the handlers to the channel's pipeline in the order listed in the class comment.
     *
     * @param ch the channel being initialized
     */
    @Override
    protected void initChannel(Channel ch) throws Exception {
        ch.pipeline()
                .addLast(new HttpServerCodec())
                .addLast(new ChunkedWriteHandler())
                .addLast(new HttpObjectAggregator(AGGREGATOR_LIMIT))
                .addLast(new HttpRequestHandler(WEBSOCKET_PATH))
                .addLast(new WebSocketServerProtocolHandler(WEBSOCKET_PATH))
                .addLast(new TextWebSocketFrameHandler(group));
    }
}
3. HttpRequestHandler.java
public class HttpRequestHandler extends SimpleChannelInboundHandler<FullHttpRequest> { private LoginTimeService loginTimeService = SpringContextHolder.getBean("loginTimeServiceImpl"); private final String wsUri; public HttpRequestHandler(String wsUri) { super(); this.wsUri = wsUri; } @Override @SuppressWarnings("deprecation") protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception { if (wsUri.equalsIgnoreCase(msg.getUri().substring(0, 3))) { String userId = findUserIdByUri(msg.getUri()); if (userId != null && userId.trim() != null && userId.trim().length() > 0) { ctx.channel().attr(AttributeKey.valueOf(ctx.channel().id().asShortText())).set(userId);// 写userid值 UserIdToWebSocketChannelShare.userIdToWebSocketChannelMap.put(userId, ctx.channel()); // 用户Id与Channel绑定 loginTimeService.onLine(userId, new Date());// 统计上线记录 } else { }// 没有获取到用户Id ctx.fireChannelRead(msg.setUri(wsUri).retain()); } } private String findUserIdByUri(String uri) {// 通过Uid获取用户Id--uri中包含userId String userId = ""; try { userId = uri.substring(uri.indexOf("userId") + 7); if (userId != null && userId.trim() != null && userId.trim().length() > 0) { userId = userId.trim(); } } catch (Exception e) { } return userId; } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { ctx.close(); cause.printStackTrace(System.err); } }
4. TextWebSocketFrameHandler.java
public class TextWebSocketFrameHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> { private LoginTimeService loginTimeService = SpringContextHolder.getBean("loginTimeServiceImpl"); private final ChannelGroup group; public TextWebSocketFrameHandler(ChannelGroup group) { super(); this.group = group; } @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt == WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) { ctx.pipeline().remove(HttpRequestHandler.class); // group.writeAndFlush(""); group.add(ctx.channel()); } else { super.userEventTriggered(ctx, evt); } } @Override protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception { group.writeAndFlush(msg.retain()); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { ctx.close(); cause.printStackTrace(); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { // (6) Channel incoming = ctx.channel(); String userId = (String) incoming.attr(AttributeKey.valueOf(incoming.id().asShortText())).get(); UserIdToWebSocketChannelShare.userIdToWebSocketChannelMap.remove(userId);// 删除缓存的通道 loginTimeService.outLine(userId, new Date());// 下线通过 } }