From 1c0a763cd7a85c15a5c1d38e06f697480c94dabe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yingyi=20/=20=E9=A2=96=E9=80=B8?= <49649786+Zuoqiu-Yingyi@users.noreply.github.com> Date: Thu, 11 Jul 2024 09:05:35 +0800 Subject: [PATCH] :art: Fixed the issue of WebSocket asynchronous initialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * :bug: Fixed the issue where WebSocket broadcast was used before initialization was complete 修复 WebSocket 广播未初始化完成即使用的问题 * :art: Improved broadcast-related APIs 改进广播相关 API --- kernel/api/broadcast.go | 246 +++++++++++++++++++++++----------------- 1 file changed, 144 insertions(+), 102 deletions(-) diff --git a/kernel/api/broadcast.go b/kernel/api/broadcast.go index 3a9ddd147..976a323ba 100644 --- a/kernel/api/broadcast.go +++ b/kernel/api/broadcast.go @@ -17,9 +17,9 @@ package api import ( - "fmt" "net/http" "sync" + "time" "github.com/88250/gulu" "github.com/gin-gonic/gin" @@ -37,56 +37,88 @@ var ( BroadcastChannels = sync.Map{} ) -/* -broadcast create a broadcast channel WebSocket connection - -@param - - query.channel: channel name - -@example - - ws://localhost:6806/ws/broadcast?channel=test -*/ +// broadcast create a broadcast channel WebSocket connection +// +// @param +// +// { +// channel: string, // channel name +// } +// +// @example +// +// "ws://localhost:6806/ws/broadcast?channel=test" func broadcast(c *gin.Context) { var ( channel string = c.Query("channel") broadcastChannel *melody.Melody ) - if _broadcastChannel, exist := BroadcastChannels.Load(channel); exist { + _broadcastChannel, exist := BroadcastChannels.Load(channel) + if exist { // channel exists, use it broadcastChannel = _broadcastChannel.(*melody.Melody) - subscribe(c, broadcastChannel, channel) - } else { - // channel not found, create a new one - broadcastChannel := melody.New() - broadcastChannel.Config.MaxMessageSize = 1024 * 1024 * 128 // 128 MiB - BroadcastChannels.Store(channel, broadcastChannel) - subscribe(c, broadcastChannel, channel) + if broadcastChannel.IsClosed() { + BroadcastChannels.Delete(channel) + } else { + subscribe(c, broadcastChannel, channel) + return + } + } + initialize(c, channel) +} - // broadcast string message to other session - broadcastChannel.HandleMessage(func(s *melody.Session, msg []byte) { - broadcastChannel.BroadcastOthers(msg, s) - }) +// initialize initializes an broadcast session set +func initialize(c *gin.Context, channel string) { + // channel not found, create a new one + broadcastChannel := melody.New() + broadcastChannel.Config.MaxMessageSize = 1024 * 1024 * 128 // 128 MiB - // broadcast binary message to other session - broadcastChannel.HandleMessageBinary(func(s *melody.Session, msg []byte) { - broadcastChannel.BroadcastBinaryOthers(msg, s) - }) + // broadcast string message to other session + broadcastChannel.HandleMessage(func(s *melody.Session, msg []byte) { + broadcastChannel.BroadcastOthers(msg, s) + }) - // recycling - broadcastChannel.HandleClose(func(s *melody.Session, status int, reason string) error { - channel := s.Keys["channel"].(string) - logging.LogInfof("close broadcast session in channel [%s] with status code %d: %s", channel, status, reason) + // broadcast binary message to other session + broadcastChannel.HandleMessageBinary(func(s *melody.Session, msg []byte) { + broadcastChannel.BroadcastBinaryOthers(msg, s) + }) - count := broadcastChannel.Len() - if count == 0 { - BroadcastChannels.Delete(channel) - logging.LogInfof("dispose broadcast channel [%s]", channel) + // recycling + broadcastChannel.HandleClose(func(s *melody.Session, status int, reason string) error { + channel := s.Keys["channel"].(string) + logging.LogInfof("close broadcast session in channel [%s] with status code %d: %s", channel, status, reason) + + count := broadcastChannel.Len() + if count == 0 { + BroadcastChannels.Delete(channel) + broadcastChannel.Close() + logging.LogInfof("dispose broadcast channel [%s]", channel) + } + return nil + }) + + for { + // Melody Initialization is an asynchronous process, so we need to wait for it to complete + if broadcastChannel.IsClosed() { + time.Sleep(1 * time.Nanosecond) + } else { + _broadcastChannel, loaded := BroadcastChannels.LoadOrStore(channel, broadcastChannel) + __broadcastChannel := _broadcastChannel.(*melody.Melody) + if loaded { + // channel exists + if __broadcastChannel.IsClosed() { + // channel is closed, replace it + BroadcastChannels.Store(channel, broadcastChannel) + __broadcastChannel = broadcastChannel + } else { + // channel is open, close the new one + broadcastChannel.Close() + } } - return nil - }) + subscribe(c, __broadcastChannel, channel) + break + } } } @@ -104,19 +136,26 @@ func subscribe(c *gin.Context, broadcastChannel *melody.Melody, channel string) } } -/* -postMessage send string message to a broadcast channel - -@param - - body.channel: channel name - body.message: message payload - -@returns - - body.data.channel.name: channel name - body.data.channel.count: indicate how many websocket session received the message -*/ +// postMessage send string message to a broadcast channel +// @param +// +// { +// channel: string // channel name +// message: string // message payload +// } +// +// @returns +// +// { +// code: int, +// msg: string, +// data: { +// channel: { +// name: string, //channel name +// count: string, //listener count +// }, +// }, +// } func postMessage(c *gin.Context) { ret := gulu.Ret.NewResult() defer c.JSON(http.StatusOK, ret) @@ -126,16 +165,14 @@ func postMessage(c *gin.Context) { return } - channel := arg["channel"].(string) message := arg["message"].(string) + channel := &Channel{ + Name: arg["channel"].(string), + Count: 0, + } - if _broadcastChannel, ok := BroadcastChannels.Load(channel); !ok { - err := fmt.Errorf("broadcast channel [%s] not found", channel) - logging.LogWarnf(err.Error()) - - ret.Code = -1 - ret.Msg = err.Error() - return + if _broadcastChannel, ok := BroadcastChannels.Load(channel.Name); !ok { + channel.Count = 0 } else { var broadcastChannel = _broadcastChannel.(*melody.Melody) if err := broadcastChannel.Broadcast([]byte(message)); nil != err { @@ -146,27 +183,33 @@ func postMessage(c *gin.Context) { return } - count := broadcastChannel.Len() - ret.Data = map[string]interface{}{ - "channel": &Channel{ - Name: channel, - Count: count, - }, - } + channel.Count = broadcastChannel.Len() + } + ret.Data = map[string]interface{}{ + "channel": channel, } } -/* -getChannelInfo gets the information of a broadcast channel - -@param - - body.name: channel name - -@returns - - body.data.channel: the channel information -*/ +// getChannelInfo gets the information of a broadcast channel +// +// @param +// +// { +// name: string, // channel name +// } +// +// @returns +// +// { +// code: int, +// msg: string, +// data: { +// channel: { +// name: string, //channel name +// count: string, //listener count +// }, +// }, +// } func getChannelInfo(c *gin.Context) { ret := gulu.Ret.NewResult() defer c.JSON(http.StatusOK, ret) @@ -176,38 +219,37 @@ func getChannelInfo(c *gin.Context) { return } - name := arg["name"].(string) + channel := &Channel{ + Name: arg["name"].(string), + Count: 0, + } - if _broadcastChannel, ok := BroadcastChannels.Load(name); !ok { - err := fmt.Errorf("broadcast channel [%s] not found", name) - logging.LogWarnf(err.Error()) - - ret.Code = -1 - ret.Msg = err.Error() - return + if _broadcastChannel, ok := BroadcastChannels.Load(channel.Name); !ok { + channel.Count = 0 } else { var broadcastChannel = _broadcastChannel.(*melody.Melody) + channel.Count = broadcastChannel.Len() + } - count := broadcastChannel.Len() - ret.Data = map[string]interface{}{ - "channel": &Channel{ - Name: name, - Count: count, - }, - } + ret.Data = map[string]interface{}{ + "channel": channel, } } -/* -getChannels gets the channel name and lintener number of all broadcast chanel - -@returns - - body.data.channels: { - name: channel name - count: listener count - }[] -*/ +// getChannels gets the channel name and lintener number of all broadcast chanel +// +// @returns +// +// { +// code: int, +// msg: string, +// data: { +// channels: { +// name: string, //channel name +// count: string, //listener count +// }[], +// }, +// } func getChannels(c *gin.Context) { ret := gulu.Ret.NewResult() defer c.JSON(http.StatusOK, ret)