🎨 Fixed the issue of WebSocket asynchronous initialization

* 🐛 Fixed the issue where WebSocket broadcast was used before initialization was complete

修复 WebSocket 广播未初始化完成即使用的问题

* 🎨 Improved broadcast-related APIs

改进广播相关 API
This commit is contained in:
Yingyi / 颖逸 2024-07-11 09:05:35 +08:00 committed by GitHub
parent f6bcb165b8
commit 1c0a763cd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -17,9 +17,9 @@
package api package api
import ( import (
"fmt"
"net/http" "net/http"
"sync" "sync"
"time"
"github.com/88250/gulu" "github.com/88250/gulu"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -37,56 +37,88 @@ var (
BroadcastChannels = sync.Map{} BroadcastChannels = sync.Map{}
) )
/* // broadcast create a broadcast channel WebSocket connection
broadcast create a broadcast channel WebSocket connection //
// @param
@param //
// {
query.channel: channel name // channel: string, // channel name
// }
@example //
// @example
ws://localhost:6806/ws/broadcast?channel=test //
*/ // "ws://localhost:6806/ws/broadcast?channel=test"
func broadcast(c *gin.Context) { func broadcast(c *gin.Context) {
var ( var (
channel string = c.Query("channel") channel string = c.Query("channel")
broadcastChannel *melody.Melody broadcastChannel *melody.Melody
) )
if _broadcastChannel, exist := BroadcastChannels.Load(channel); exist { _broadcastChannel, exist := BroadcastChannels.Load(channel)
if exist {
// channel exists, use it // channel exists, use it
broadcastChannel = _broadcastChannel.(*melody.Melody) broadcastChannel = _broadcastChannel.(*melody.Melody)
subscribe(c, broadcastChannel, channel) if broadcastChannel.IsClosed() {
} else { BroadcastChannels.Delete(channel)
// channel not found, create a new one } else {
broadcastChannel := melody.New() subscribe(c, broadcastChannel, channel)
broadcastChannel.Config.MaxMessageSize = 1024 * 1024 * 128 // 128 MiB return
BroadcastChannels.Store(channel, broadcastChannel) }
subscribe(c, broadcastChannel, channel) }
initialize(c, channel)
}
// broadcast string message to other session // initialize initializes an broadcast session set
broadcastChannel.HandleMessage(func(s *melody.Session, msg []byte) { func initialize(c *gin.Context, channel string) {
broadcastChannel.BroadcastOthers(msg, s) // channel not found, create a new one
}) broadcastChannel := melody.New()
broadcastChannel.Config.MaxMessageSize = 1024 * 1024 * 128 // 128 MiB
// broadcast binary message to other session // broadcast string message to other session
broadcastChannel.HandleMessageBinary(func(s *melody.Session, msg []byte) { broadcastChannel.HandleMessage(func(s *melody.Session, msg []byte) {
broadcastChannel.BroadcastBinaryOthers(msg, s) broadcastChannel.BroadcastOthers(msg, s)
}) })
// recycling // broadcast binary message to other session
broadcastChannel.HandleClose(func(s *melody.Session, status int, reason string) error { broadcastChannel.HandleMessageBinary(func(s *melody.Session, msg []byte) {
channel := s.Keys["channel"].(string) broadcastChannel.BroadcastBinaryOthers(msg, s)
logging.LogInfof("close broadcast session in channel [%s] with status code %d: %s", channel, status, reason) })
count := broadcastChannel.Len() // recycling
if count == 0 { broadcastChannel.HandleClose(func(s *melody.Session, status int, reason string) error {
BroadcastChannels.Delete(channel) channel := s.Keys["channel"].(string)
logging.LogInfof("dispose broadcast channel [%s]", channel) logging.LogInfof("close broadcast session in channel [%s] with status code %d: %s", channel, status, reason)
count := broadcastChannel.Len()
if count == 0 {
BroadcastChannels.Delete(channel)
broadcastChannel.Close()
logging.LogInfof("dispose broadcast channel [%s]", channel)
}
return nil
})
for {
// Melody Initialization is an asynchronous process, so we need to wait for it to complete
if broadcastChannel.IsClosed() {
time.Sleep(1 * time.Nanosecond)
} else {
_broadcastChannel, loaded := BroadcastChannels.LoadOrStore(channel, broadcastChannel)
__broadcastChannel := _broadcastChannel.(*melody.Melody)
if loaded {
// channel exists
if __broadcastChannel.IsClosed() {
// channel is closed, replace it
BroadcastChannels.Store(channel, broadcastChannel)
__broadcastChannel = broadcastChannel
} else {
// channel is open, close the new one
broadcastChannel.Close()
}
} }
return nil subscribe(c, __broadcastChannel, channel)
}) break
}
} }
} }
@ -104,19 +136,26 @@ func subscribe(c *gin.Context, broadcastChannel *melody.Melody, channel string)
} }
} }
/* // postMessage send string message to a broadcast channel
postMessage send string message to a broadcast channel // @param
//
@param // {
// channel: string // channel name
body.channel: channel name // message: string // message payload
body.message: message payload // }
//
@returns // @returns
//
body.data.channel.name: channel name // {
body.data.channel.count: indicate how many websocket session received the message // code: int,
*/ // msg: string,
// data: {
// channel: {
// name: string, //channel name
// count: string, //listener count
// },
// },
// }
func postMessage(c *gin.Context) { func postMessage(c *gin.Context) {
ret := gulu.Ret.NewResult() ret := gulu.Ret.NewResult()
defer c.JSON(http.StatusOK, ret) defer c.JSON(http.StatusOK, ret)
@ -126,16 +165,14 @@ func postMessage(c *gin.Context) {
return return
} }
channel := arg["channel"].(string)
message := arg["message"].(string) message := arg["message"].(string)
channel := &Channel{
Name: arg["channel"].(string),
Count: 0,
}
if _broadcastChannel, ok := BroadcastChannels.Load(channel); !ok { if _broadcastChannel, ok := BroadcastChannels.Load(channel.Name); !ok {
err := fmt.Errorf("broadcast channel [%s] not found", channel) channel.Count = 0
logging.LogWarnf(err.Error())
ret.Code = -1
ret.Msg = err.Error()
return
} else { } else {
var broadcastChannel = _broadcastChannel.(*melody.Melody) var broadcastChannel = _broadcastChannel.(*melody.Melody)
if err := broadcastChannel.Broadcast([]byte(message)); nil != err { if err := broadcastChannel.Broadcast([]byte(message)); nil != err {
@ -146,27 +183,33 @@ func postMessage(c *gin.Context) {
return return
} }
count := broadcastChannel.Len() channel.Count = broadcastChannel.Len()
ret.Data = map[string]interface{}{ }
"channel": &Channel{ ret.Data = map[string]interface{}{
Name: channel, "channel": channel,
Count: count,
},
}
} }
} }
/* // getChannelInfo gets the information of a broadcast channel
getChannelInfo gets the information of a broadcast channel //
// @param
@param //
// {
body.name: channel name // name: string, // channel name
// }
@returns //
// @returns
body.data.channel: the channel information //
*/ // {
// code: int,
// msg: string,
// data: {
// channel: {
// name: string, //channel name
// count: string, //listener count
// },
// },
// }
func getChannelInfo(c *gin.Context) { func getChannelInfo(c *gin.Context) {
ret := gulu.Ret.NewResult() ret := gulu.Ret.NewResult()
defer c.JSON(http.StatusOK, ret) defer c.JSON(http.StatusOK, ret)
@ -176,38 +219,37 @@ func getChannelInfo(c *gin.Context) {
return return
} }
name := arg["name"].(string) channel := &Channel{
Name: arg["name"].(string),
Count: 0,
}
if _broadcastChannel, ok := BroadcastChannels.Load(name); !ok { if _broadcastChannel, ok := BroadcastChannels.Load(channel.Name); !ok {
err := fmt.Errorf("broadcast channel [%s] not found", name) channel.Count = 0
logging.LogWarnf(err.Error())
ret.Code = -1
ret.Msg = err.Error()
return
} else { } else {
var broadcastChannel = _broadcastChannel.(*melody.Melody) var broadcastChannel = _broadcastChannel.(*melody.Melody)
channel.Count = broadcastChannel.Len()
}
count := broadcastChannel.Len() ret.Data = map[string]interface{}{
ret.Data = map[string]interface{}{ "channel": channel,
"channel": &Channel{
Name: name,
Count: count,
},
}
} }
} }
/* // getChannels gets the channel name and lintener number of all broadcast chanel
getChannels gets the channel name and lintener number of all broadcast chanel //
// @returns
@returns //
// {
body.data.channels: { // code: int,
name: channel name // msg: string,
count: listener count // data: {
}[] // channels: {
*/ // name: string, //channel name
// count: string, //listener count
// }[],
// },
// }
func getChannels(c *gin.Context) { func getChannels(c *gin.Context) {
ret := gulu.Ret.NewResult() ret := gulu.Ret.NewResult()
defer c.JSON(http.StatusOK, ret) defer c.JSON(http.StatusOK, ret)