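In the previous article I showed how the server can send messages to a specific user's client and how to handle the case where that user is offline. In this article we tackle another important question: if the project runs in a distributed environment, where Nginx's reverse proxy spreads logged-in users across several servers, how can a user whose WebSocket connection lives on one server send a message to a user whose WebSocket connection lives on another server?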
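Solving this requires a distributed WebSocket setup, which is generally implemented in one of two ways:
- Approach 1: Publish every message (<user ID, message content>) to a single topic on a message queue (Redis, Kafka, etc.) and have every application node subscribe to that topic. When a node receives a message, it reads the recipient's user ID/username and checks whether that user has a WebSocket connection on this node. If so, it pushes the message to the user; otherwise it discards the message (the node the recipient is connected to will handle it).
- Approach 2: When a user establishes a WebSocket connection, record in Redis which application node the connection lives on, and then use a message queue to deliver each message only to the node where the recipient is connected. (This is more complex to implement than Approach 1, but generates less network traffic.)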
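To make the traffic difference concrete, here is a small in-memory model of both approaches in plain Java. It is only a sketch under stated assumptions: the `SketchNode`/`SketchCluster` classes and the `queueDeliveries` counter are hypothetical, the message queue is simulated by ordinary method calls, and the `userToNode` map stands in for the Redis record that Approach 2 keeps.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// In-memory model only: the "queue" is simulated by method calls, nothing goes over a network.
class SketchNode {
    final String id;
    final Set<String> localUsers;                      // users whose WebSocket connection lives on this node
    final List<String> delivered = new ArrayList<>(); // messages this node pushed to a local client

    SketchNode(String id, Set<String> localUsers) {
        this.id = id;
        this.localUsers = localUsers;
    }

    // Invoked for every message the queue hands to this node
    void onQueueMessage(String receiver, String content) {
        if (localUsers.contains(receiver)) {
            delivered.add(receiver + ":" + content);
        }
        // otherwise drop it: the recipient is not connected to this node
    }
}

class SketchCluster {
    final List<SketchNode> nodes;
    // Approach 2 only: user -> node record (hypothetical stand-in for the Redis record)
    final Map<String, SketchNode> userToNode = new HashMap<>();
    int queueDeliveries = 0; // node-level deliveries performed by the simulated queue

    SketchCluster(List<SketchNode> nodes) {
        this.nodes = nodes;
        for (SketchNode node : nodes) {
            for (String user : node.localUsers) {
                userToNode.put(user, node);
            }
        }
    }

    // Approach 1: one shared topic; every node receives every message and filters locally
    void broadcast(String receiver, String content) {
        for (SketchNode node : nodes) {
            queueDeliveries++;
            node.onQueueMessage(receiver, content);
        }
    }

    // Approach 2: look up the recipient's node and deliver only there
    void route(String receiver, String content) {
        SketchNode owner = userToNode.get(receiver);
        if (owner != null) {
            queueDeliveries++;
            owner.onQueueMessage(receiver, content);
        }
        // unknown/offline recipient: not routed (the article stores such messages as unread instead)
    }

    // Three nodes with one connected user each
    static SketchCluster threeNodes() {
        return new SketchCluster(Arrays.asList(
                new SketchNode("A", new HashSet<>(Arrays.asList("alice"))),
                new SketchNode("B", new HashSet<>(Arrays.asList("bob"))),
                new SketchNode("C", new HashSet<>(Arrays.asList("carol")))));
    }
}
```

With N nodes, Approach 1 costs N queue deliveries per message while Approach 2 costs one. The price of Approach 2 is keeping the user-to-node record accurate as users connect and disconnect and as nodes restart.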
Note: the complete source code for this article can be found at:
In the example below I implement the simpler Approach 1, as follows:
(1) Define a WebSocket channel enum:
```java
package cn.zifangsky.mqwebsocket.enums;

import org.apache.commons.lang3.StringUtils;

/**
 * WebSocket channel enum
 *
 * @author zifangsky
 * @date 2018/10/16
 * @since 1.0.0
 */
public enum WebSocketChannelEnum {
    // Simple point-to-point chat used for testing
    CHAT("CHAT", "Simple point-to-point chat for testing", "/topic/reply");

    WebSocketChannelEnum(String code, String description, String subscribeUrl) {
        this.code = code;
        this.description = description;
        this.subscribeUrl = subscribeUrl;
    }

    /**
     * Unique code
     */
    private final String code;
    /**
     * Description
     */
    private final String description;
    /**
     * URL that the WebSocket client subscribes to
     */
    private final String subscribeUrl;

    public String getCode() {
        return code;
    }

    public String getDescription() {
        return description;
    }

    public String getSubscribeUrl() {
        return subscribeUrl;
    }

    /**
     * Look up an enum constant by its code
     */
    public static WebSocketChannelEnum fromCode(String code) {
        if (StringUtils.isNoneBlank(code)) {
            for (WebSocketChannelEnum channelEnum : values()) {
                if (channelEnum.code.equals(code)) {
                    return channelEnum;
                }
            }
        }
        return null;
    }
}
```
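`fromCode` scans `values()` on every call. That is perfectly fine with a single constant, but if more channels are added, a precomputed map is a common alternative. The sketch below is a hypothetical variant (`ChannelSketch` is not part of the article's code) with the same contract: it returns `null` for null, blank, or unknown codes. It uses only the JDK, so `StringUtils.isNoneBlank` is replaced by a plain trim check.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical variant of WebSocketChannelEnum with a map-based fromCode
enum ChannelSketch {
    CHAT("CHAT", "Simple point-to-point chat for testing", "/topic/reply");

    private final String code;
    private final String description;
    private final String subscribeUrl;

    // Static fields of an enum are initialized after its constants, so values() is safe here
    private static final Map<String, ChannelSketch> BY_CODE = Collections.unmodifiableMap(
            Arrays.stream(values()).collect(Collectors.toMap(c -> c.code, Function.identity())));

    ChannelSketch(String code, String description, String subscribeUrl) {
        this.code = code;
        this.description = description;
        this.subscribeUrl = subscribeUrl;
    }

    public String getDescription() {
        return description;
    }

    public String getSubscribeUrl() {
        return subscribeUrl;
    }

    // Same contract as the original: null for null, blank, or unknown codes
    public static ChannelSketch fromCode(String code) {
        if (code == null || code.trim().isEmpty()) {
            return null;
        }
        return BY_CODE.get(code);
    }
}
```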
(2) Configure a Redis-based message queue:
For how to implement a message queue with Redis, see my earlier article:
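Note that a Redis-based message queue is not recommended for medium or large production projects. In my testing it was not particularly reliable (Redis Pub/Sub does not persist messages, so anything published while a subscriber is disconnected is simply lost), so consider dedicated message-queue middleware such as Kafka or RabbitMQ instead. (PS: Redis 5.0 added a new data type, Streams, which supports persistence and consumer groups and is said to make Redis considerably more capable as a message queue.)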
```java
package cn.zifangsky.mqwebsocket.config;

import cn.zifangsky.mqwebsocket.mq.MessageReceiver;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisClusterConfiguration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import redis.clients.jedis.JedisCluster;
import redis.clients.jedis.JedisPoolConfig;

import java.util.Arrays;

/**
 * Redis configuration
 *
 * @author zifangsky
 * @date 2018/7/30
 * @since 1.0.0
 */
@Configuration
@ConditionalOnClass({JedisCluster.class})
public class RedisConfig {

    @Value("${spring.redis.timeout}")
    private String timeOut;

    @Value("${spring.redis.cluster.nodes}")
    private String nodes;

    @Value("${spring.redis.cluster.max-redirects}")
    private int maxRedirects;

    @Value("${spring.redis.jedis.pool.max-active}")
    private int maxActive;

    @Value("${spring.redis.jedis.pool.max-wait}")
    private int maxWait;

    @Value("${spring.redis.jedis.pool.max-idle}")
    private int maxIdle;

    @Value("${spring.redis.jedis.pool.min-idle}")
    private int minIdle;

    @Value("${spring.redis.message.topic-name}")
    private String topicName;

    @Bean
    public JedisPoolConfig jedisPoolConfig() {
        JedisPoolConfig config = new JedisPoolConfig();
        config.setMaxTotal(maxActive);
        config.setMaxIdle(maxIdle);
        config.setMinIdle(minIdle);
        config.setMaxWaitMillis(maxWait);
        return config;
    }

    @Bean
    public RedisClusterConfiguration redisClusterConfiguration() {
        // nodes is a comma-separated string ("host1:port1,host2:port2,..."), so split it into one entry per node
        RedisClusterConfiguration configuration = new RedisClusterConfiguration(Arrays.asList(nodes.split(",")));
        configuration.setMaxRedirects(maxRedirects);
        return configuration;
    }

    /**
     * JedisConnectionFactory
     */
    @Bean
    public JedisConnectionFactory jedisConnectionFactory(RedisClusterConfiguration configuration, JedisPoolConfig jedisPoolConfig) {
        return new JedisConnectionFactory(configuration, jedisPoolConfig);
    }

    /**
     * Serialize objects with Jackson
     */
    @Bean
    public Jackson2JsonRedisSerializer ...
```
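`spring.redis.cluster.nodes` is injected as one comma-separated string, so how it is turned into a list matters: `Arrays.asList(nodes)` on the raw string yields a one-element list holding the whole string, while `split(",")` yields one `host:port` entry per node. The hypothetical JDK-only helper below (not part of the article's code) shows the difference. It also trims whitespace and skips empty entries, which is a convenience assumption.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical helper illustrating how the nodes property can be turned into a list
class ClusterNodesParsing {

    // Wrapping the raw string: a single element containing every node
    static List<String> asSingleElement(String nodes) {
        return Arrays.asList(nodes);
    }

    // One "host:port" entry per node; surrounding whitespace trimmed, empty entries dropped
    static List<String> split(String nodes) {
        return Arrays.stream(nodes.split(","))
                .map(String::trim)
                .filter(s -> !s.isEmpty())
                .collect(Collectors.toList());
    }
}
```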
The configuration used here is as follows:

```yaml
spring:
  ...
  #redis
  redis:
    cluster:
      nodes: namenode22:6379,datanode23:6379,datanode24:6379
      max-redirects: 6
    timeout: 300000
    jedis:
      pool:
        max-active: 8
        max-wait: 100000
        max-idle: 8
        min-idle: 0
    # Custom topic to listen on
    message:
      topic-name: topic-test
```
(3) Define a handler for messages received from Redis:
```java
package cn.zifangsky.mqwebsocket.mq;

import cn.zifangsky.mqwebsocket.enums.WebSocketChannelEnum;
import cn.zifangsky.mqwebsocket.model.websocket.RedisWebsocketMsg;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.user.SimpUser;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.stereotype.Component;

import java.text.MessageFormat;

/**
 * Handler for WebSocket messages received from Redis
 *
 * @author zifangsky
 * @date 2018/10/16
 * @since 1.0.0
 */
@Component
public class MessageReceiver {
    private final Logger logger = LoggerFactory.getLogger(getClass());

    @Autowired
    private SimpMessagingTemplate messagingTemplate;

    @Autowired
    private SimpUserRegistry userRegistry;

    /**
     * Handle a WebSocket message
     */
    public void receiveMessage(RedisWebsocketMsg redisWebsocketMsg) {
        logger.info(MessageFormat.format("Received Message: {0}", redisWebsocketMsg));
        //1. Look up the recipient and check whether they are connected to this node's WebSocket
        SimpUser simpUser = userRegistry.getUser(redisWebsocketMsg.getReceiver());

        if (simpUser != null && StringUtils.isNoneBlank(simpUser.getName())) {
            //2. Resolve the URL the WebSocket client subscribes to
            WebSocketChannelEnum channelEnum = WebSocketChannelEnum.fromCode(redisWebsocketMsg.getChannelCode());

            if (channelEnum != null) {
                //3. Send the message to the WebSocket client
                messagingTemplate.convertAndSendToUser(redisWebsocketMsg.getReceiver(), channelEnum.getSubscribeUrl(), redisWebsocketMsg.getContent());
            }
        }
        // Otherwise the recipient is not connected to this node: drop the message, the owning node delivers it
    }
}
```
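The `RedisWebsocketMsg` class itself is not shown. A minimal sketch, assuming only what the code reveals (a generic class with a `(receiver, channelCode, content)` constructor and `getReceiver()`, `getChannelCode()`, `getContent()` getters), might look like the following. The no-arg constructor and setters are an assumption, added so that Jackson's default bean deserialization can rebuild the object on the receiving node.

```java
import java.io.Serializable;

/**
 * Hypothetical reconstruction of RedisWebsocketMsg; the article does not show this class.
 * Fields are inferred from the constructor and getters used in the controller and receiver.
 */
class RedisWebsocketMsg<T> implements Serializable {
    private static final long serialVersionUID = 1L;

    private String receiver;    // username of the recipient
    private String channelCode; // WebSocketChannelEnum code, e.g. "CHAT"
    private T content;          // message payload

    // No-arg constructor and setters so Jackson can deserialize the message (assumption)
    public RedisWebsocketMsg() {
    }

    public RedisWebsocketMsg(String receiver, String channelCode, T content) {
        this.receiver = receiver;
        this.channelCode = channelCode;
        this.content = content;
    }

    public String getReceiver() { return receiver; }
    public void setReceiver(String receiver) { this.receiver = receiver; }

    public String getChannelCode() { return channelCode; }
    public void setChannelCode(String channelCode) { this.channelCode = channelCode; }

    public T getContent() { return content; }
    public void setContent(T content) { this.content = content; }

    // Used by the "Received Message: {0}" log line in MessageReceiver
    @Override
    public String toString() {
        return "RedisWebsocketMsg{receiver='" + receiver + "', channelCode='" + channelCode
                + "', content=" + content + "}";
    }
}
```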
(4) Send WebSocket messages from a controller:
```java
package cn.zifangsky.mqwebsocket.controller;

import cn.zifangsky.mqwebsocket.common.Constants;
import cn.zifangsky.mqwebsocket.common.SpringContextUtils;
import cn.zifangsky.mqwebsocket.enums.ExpireEnum;
import cn.zifangsky.mqwebsocket.enums.WebSocketChannelEnum;
import cn.zifangsky.mqwebsocket.model.User;
import cn.zifangsky.mqwebsocket.model.websocket.HelloMessage;
import cn.zifangsky.mqwebsocket.model.websocket.RedisWebsocketMsg;
import cn.zifangsky.mqwebsocket.service.RedisService;
import cn.zifangsky.mqwebsocket.utils.JsonUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.user.SimpUser;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.ResponseBody;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import java.text.MessageFormat;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Demonstrates basic usage of {@link org.springframework.messaging.simp.SimpMessagingTemplate}
 * @author zifangsky
 * @date 2018/10/10
 * @since 1.0.0
 */
@Controller
@RequestMapping("/wsTemplate")
public class RedisMessageController {
    private final Logger logger = LoggerFactory.getLogger(getClass());

    @Value("${spring.redis.message.topic-name}")
    private String topicName;

    @Autowired
    private SimpMessagingTemplate messagingTemplate;

    @Autowired
    private SimpUserRegistry userRegistry;

    @Resource(name = "redisServiceImpl")
    private RedisService redisService;

    /**
     * Send a WebSocket message to a specific user
     */
    @PostMapping("/sendToUser")
    @ResponseBody
    public String chat(HttpServletRequest request) {
        // Message recipient
        String receiver = request.getParameter("receiver");
        // Message content
        String msg = request.getParameter("msg");
        HttpSession session = SpringContextUtils.getSession();
        User loginUser = (User) session.getAttribute(Constants.SESSION_USER);

        HelloMessage resultData = new HelloMessage(MessageFormat.format("{0} say: {1}", loginUser.getUsername(), msg));
        this.sendToUser(loginUser.getUsername(), receiver, WebSocketChannelEnum.CHAT.getSubscribeUrl(), JsonUtils.toJson(resultData));

        return "ok";
    }

    /**
     * Send a message to a specific user, handling the case where the recipient is offline
     * @param sender      message sender
     * @param receiver    message recipient
     * @param destination destination
     * @param payload     message body
     */
    private void sendToUser(String sender, String receiver, String destination, String payload) {
        SimpUser simpUser = userRegistry.getUser(receiver);

        // If the recipient is connected to this node, send the message directly
        if (simpUser != null && StringUtils.isNoneBlank(simpUser.getName())) {
            messagingTemplate.convertAndSendToUser(receiver, destination, payload);
        }
        // If the recipient is online, they are connected to another node in the cluster: publish to Redis so that node delivers it
        else if (redisService.isSetMember(Constants.REDIS_WEBSOCKET_USER_SET, receiver)) {
            RedisWebsocketMsg<String> redisWebsocketMsg = new RedisWebsocketMsg<>(receiver, WebSocketChannelEnum.CHAT.getCode(), payload);

            redisService.convertAndSend(topicName, redisWebsocketMsg);
        }
        // Otherwise store the message in Redis; the user pulls unread messages after coming online
        else {
            // Name of the Redis list that stores the messages
            String listKey = Constants.REDIS_UNREAD_MSG_PREFIX + receiver + ":" + destination;
            logger.info(MessageFormat.format("Recipient {0} has not established a WebSocket connection yet; the message [{2}] sent by {1} will be stored in the Redis list [{3}]", receiver, sender, payload, listKey));

            // Store the message in Redis
            redisService.addToListRight(listKey, ExpireEnum.UNREAD_MSG, payload);
        }
    }

    /**
     * Pull the unread WebSocket messages for the given destination
     * @param destination the subscribed destination
     * @return java.util.Map
     */
    @PostMapping("/pullUnreadMessage")
    @ResponseBody
    public Map<String, Object> pullUnreadMessage(String destination) {
        Map<String, Object> result = new HashMap<>();
        try {
            HttpSession session = SpringContextUtils.getSession();
            // Currently logged-in user
            User loginUser = (User) session.getAttribute(Constants.SESSION_USER);

            // Name of the Redis list that stores the messages
            String listKey = Constants.REDIS_UNREAD_MSG_PREFIX + loginUser.getUsername() + ":" + destination;
            // Fetch all unread messages from Redis
            List<?> messageList = redisService.rangeList(listKey, 0, -1);

            result.put("code", "200");
            if (messageList != null && messageList.size() > 0) {
                // Delete this unread-message list from Redis
                // (rangeList and delete are two separate calls, so this read-then-delete is not atomic)
                redisService.delete(listKey);
                // Add the messages to the response for the page to display
                result.put("result", messageList);
            }
        } catch (Exception e) {
            result.put("code", "500");
            result.put("msg", e.getMessage());
        }

        return result;
    }
}
```
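In `pullUnreadMessage`, `rangeList` and `delete` are two separate Redis calls, so a message that `sendToUser` appends between them would be deleted before anyone reads it. In Redis the usual remedy is to run LRANGE and DEL together in a MULTI/EXEC transaction or a Lua script. The plain-Java sketch below only illustrates the read-and-remove-in-one-step idea with a `ConcurrentHashMap`. The `UnreadStore` class and its `unread:` prefix are hypothetical; the real value of `Constants.REDIS_UNREAD_MSG_PREFIX` is not shown in the article.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// In-memory stand-in for the Redis unread-message lists (hypothetical, for illustration only)
class UnreadStore {
    // Hypothetical prefix; the article's Constants.REDIS_UNREAD_MSG_PREFIX value is not shown
    static final String PREFIX = "unread:";

    private final Map<String, List<String>> lists = new ConcurrentHashMap<>();

    // Same key layout as the controller: prefix + receiver + ":" + destination
    static String key(String receiver, String destination) {
        return PREFIX + receiver + ":" + destination;
    }

    // Counterpart of addToListRight: append to the tail; compute() runs atomically per key
    void push(String receiver, String destination, String payload) {
        lists.compute(key(receiver, destination), (k, v) -> {
            List<String> list = (v == null) ? new ArrayList<>() : v;
            list.add(payload);
            return list;
        });
    }

    // Read and delete in one step: remove() detaches the whole list atomically, so a push
    // racing with the pull either lands in the returned list or starts a fresh one
    List<String> pull(String receiver, String destination) {
        List<String> list = lists.remove(key(receiver, destination));
        return list == null ? Collections.emptyList() : list;
    }
}
```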
(5) Interceptors that handle WebSocket connection concerns:
i) AuthHandshakeInterceptor:
```java
package cn.zifangsky.mqwebsocket.interceptor.websocket;

import cn.zifangsky.mqwebsocket.common.Constants;
import cn.zifangsky.mqwebsocket.common.SpringContextUtils;
import cn.zifangsky.mqwebsocket.model.User;
import cn.zifangsky.mqwebsocket.service.RedisService;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import javax.annotation.Resource;
import javax.servlet.http.HttpSession;
import java.text.MessageFormat;
import java.util.Map;

/**
 * Custom {@link org.springframework.web.socket.server.HandshakeInterceptor} that only allows logged-in users to open a WebSocket connection
 *
 * @author zifangsky
 * @date 2018/10/11
 * @since 1.0.0
 */
@Component
public class AuthHandshakeInterceptor implements HandshakeInterceptor {
    private final Logger logger = LoggerFactory.getLogger(getClass());

    @Resource(name = "redisServiceImpl")
    private RedisService redisService;

    @Override
    public boolean beforeHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Map<String, Object> map) throws Exception {
        HttpSession session = SpringContextUtils.getSession();
        User loginUser = (User) session.getAttribute(Constants.SESSION_USER);

        // Check the login first: calling loginUser.getUsername() on a null user would throw a NullPointerException
        if (loginUser == null || StringUtils.isBlank(loginUser.getUsername())) {
            logger.error("Not logged in; WebSocket connection refused");
            return false;
        } else if (redisService.isSetMember(Constants.REDIS_WEBSOCKET_USER_SET, loginUser.getUsername())) {
            logger.error("The same user may not open more than one WebSocket connection");
            return false;
        } else {
            logger.debug(MessageFormat.format("User {0} requested a WebSocket connection", loginUser.getUsername()));
            return true;
        }
    }

    @Override
    public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Exception e) {

    }
}
```
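The order of the handshake checks matters: the login check has to run before the Redis set lookup, otherwise an anonymous request fails with a `NullPointerException` on `loginUser.getUsername()` instead of being rejected cleanly. The decision can be modeled in plain Java as below. `HandshakeCheck` and its `Result` enum are hypothetical, and the `onlineUsers` set stands in for the Redis set `Constants.REDIS_WEBSOCKET_USER_SET`.

```java
import java.util.Set;

// Plain-Java model of the AuthHandshakeInterceptor decision (hypothetical names)
class HandshakeCheck {
    enum Result { REJECT_NOT_LOGGED_IN, REJECT_ALREADY_CONNECTED, ACCEPT }

    // onlineUsers stands in for the Redis set of users that already have a WebSocket connection
    static Result check(String username, Set<String> onlineUsers) {
        // 1. Validate the login first, so a missing user never reaches the set lookup
        if (username == null || username.trim().isEmpty()) {
            return Result.REJECT_NOT_LOGGED_IN;
        }
        // 2. Then enforce "one WebSocket connection per user"
        if (onlineUsers.contains(username)) {
            return Result.REJECT_ALREADY_CONNECTED;
        }
        return Result.ACCEPT;
    }
}
```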
ii) MyHandshakeHandler:
```java
package cn.zifangsky.mqwebsocket.interceptor.websocket;

import cn.zifangsky.mqwebsocket.common.Constants;
import cn.zifangsky.mqwebsocket.common.SpringContextUtils;
import cn.zifangsky.mqwebsocket.model.User;
import cn.zifangsky.mqwebsocket.service.RedisService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;

import javax.annotation.Resource;
import javax.servlet.http.HttpSession;
import java.security.Principal;
import java.text.MessageFormat;
import java.util.Map;

/**
 * Custom {@link org.springframework.web.socket.server.support.DefaultHandshakeHandler} that creates a custom {@link java.security.Principal}
 *
 * @author zifangsky
 * @date 2018/10/11
 * @since 1.0.0
 */
@Component
public class MyHandshakeHandler extends DefaultHandshakeHandler {
    private final Logger logger = LoggerFactory.getLogger(getClass());

    @Resource(name = "redisServiceImpl")
    private RedisService redisService;

    @Override
    protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler, Map<String, Object> attributes) {
        HttpSession session = SpringContextUtils.getSession();
        User loginUser = (User) session.getAttribute(Constants.SESSION_USER);

        if (loginUser != null) {
            logger.debug(MessageFormat.format("Creating the Principal for the WebSocket connection, user: {0}", loginUser.getUsername()));

            //1. Store the username in Redis
            redisService.addToSet(Constants.REDIS_WEBSOCKET_USER_SET, loginUser.getUsername());

            //2. Return the custom Principal
            return new MyPrincipal(loginUser.getUsername());
        } else {
            logger.error("Not logged in; WebSocket connection refused");
            return null;
        }
    }
}
```
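`MyPrincipal` is used here but not shown. A minimal sketch, assuming it only needs to carry the login username, could implement `java.security.Principal` directly. The name returned by `getName()` is what `SimpUserRegistry.getUser(...)` and `convertAndSendToUser(...)` use to identify the user. The `equals`/`hashCode`/`toString` overrides and the null check are my own additions.

```java
import java.security.Principal;
import java.util.Objects;

/**
 * Hypothetical sketch of MyPrincipal; the article uses this class but does not show it.
 */
class MyPrincipal implements Principal {
    private final String loginName;

    MyPrincipal(String loginName) {
        // Fail fast instead of registering a user without a name
        this.loginName = Objects.requireNonNull(loginName, "loginName");
    }

    // SimpUserRegistry and convertAndSendToUser(...) identify the user by this name
    @Override
    public String getName() {
        return loginName;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof MyPrincipal)) {
            return false;
        }
        return loginName.equals(((MyPrincipal) o).loginName);
    }

    @Override
    public int hashCode() {
        return loginName.hashCode();
    }

    @Override
    public String toString() {
        return "MyPrincipal{" + loginName + "}";
    }
}
```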
iii) MyChannelInterceptor:
```java
package cn.zifangsky.mqwebsocket.interceptor.websocket;

import cn.zifangsky.mqwebsocket.common.Constants;
import cn.zifangsky.mqwebsocket.service.RedisService;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.security.Principal;
import java.text.MessageFormat;

/**
 * Custom {@link org.springframework.messaging.support.ChannelInterceptor} that handles disconnects
 *
 * @author zifangsky
 * @date 2018/10/10
 * @since 1.0.0
 */
@Component
public class MyChannelInterceptor implements ChannelInterceptor {
    private final Logger logger = LoggerFactory.getLogger(getClass());

    @Resource(name = "redisServiceImpl")
    private RedisService redisService;

    @Override
    public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, Exception ex) {
        StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message);
        StompCommand command = accessor.getCommand();

        // The user has disconnected
        if (StompCommand.DISCONNECT.equals(command)) {
            String user = "";
            Principal principal = accessor.getUser();
            if (principal != null && StringUtils.isNoneBlank(principal.getName())) {
                user = principal.getName();

                // Remove the user from the Redis set of connected users
                redisService.removeFromSet(Constants.REDIS_WEBSOCKET_USER_SET, user);
            } else {
                user = accessor.getSessionId();
            }

            logger.debug(MessageFormat.format("The WebSocket connection of user {0} has been closed", user));
        }
    }
}
```
(6) WebSocket configuration:
```java
package cn.zifangsky.mqwebsocket.config;

import cn.zifangsky.mqwebsocket.interceptor.websocket.MyHandshakeHandler;
import cn.zifangsky.mqwebsocket.interceptor.websocket.AuthHandshakeInterceptor;
import cn.zifangsky.mqwebsocket.interceptor.websocket.MyChannelInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;

/**
 * WebSocket configuration
 *
 * @author zifangsky
 * @date 2018/9/30
 * @since 1.0.0
 */
@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

    @Autowired
    private AuthHandshakeInterceptor authHandshakeInterceptor;

    @Autowired
    private MyHandshakeHandler myHandshakeHandler;

    @Autowired
    private MyChannelInterceptor myChannelInterceptor;

    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        registry.addEndpoint("/chat-websocket")
                .addInterceptors(authHandshakeInterceptor)
                .setHandshakeHandler(myHandshakeHandler)
                .withSockJS();
    }

    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        // Clients send their messages to /message/xxx
        registry.setApplicationDestinationPrefixes("/message");
        // Prefix for messages broadcast by the server; clients subscribe to /topic/yyy to receive them
        registry.enableSimpleBroker("/topic");
        // Prefix for messages sent to a specific user (the default is /user/)
        registry.setUserDestinationPrefix("/user/");
    }

    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        registration.interceptors(myChannelInterceptor);
    }
}
```
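Roughly, the three prefixes combine with the controller's destinations as shown below. This plain-Java sketch only reproduces the string concatenation, and the class and method names are hypothetical. In the running application, Spring's user-destination handling then maps `/user/bob/...` onto bob's own session(s), and it may also escape special characters in user names, which is omitted here.

```java
// Hypothetical illustration of how the WebSocketConfig prefixes combine (strings only)
class DestinationSketch {
    static final String APP_PREFIX = "/message"; // setApplicationDestinationPrefixes
    static final String USER_PREFIX = "/user/";  // setUserDestinationPrefix

    private static String withLeadingSlash(String path) {
        return path.startsWith("/") ? path : "/" + path;
    }

    // Destination that convertAndSendToUser(user, destination, payload) targets before resolution
    static String serverSideUserDestination(String user, String destination) {
        return USER_PREFIX + user + withLeadingSlash(destination);
    }

    // What a browser client subscribes to in order to receive its own user messages
    static String clientSubscription(String destination) {
        return USER_PREFIX + withLeadingSlash(destination).substring(1);
    }

    // Where a client sends a frame meant for an application handler mapped to `mapping`
    static String clientSendDestination(String mapping) {
        return APP_PREFIX + withLeadingSlash(mapping);
    }
}
```

For the chat example, the controller sends to `/topic/reply` for user `bob`, which becomes `/user/bob/topic/reply` on the server side, while bob's page subscribes to `/user/topic/reply`.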
(7) Example page:
Page title: Chat With STOMP Message
The test results are omitted here; you can run the sample source code on two different servers to see the result for yourself.