5
0
mirror of https://github.com/wailsapp/wails.git synced 2025-05-03 14:29:50 +08:00

add locker for websocket.WriteMessage

This commit is contained in:
unknown 2021-10-11 21:10:55 +08:00
parent 356774e3f7
commit 995fe38ee4

View File

@ -35,7 +35,7 @@ type DevWebServer struct {
dispatcher frontend.Dispatcher dispatcher frontend.Dispatcher
assetServer *assetserver.BrowserAssetServer assetServer *assetserver.BrowserAssetServer
socketMutex sync.Mutex socketMutex sync.Mutex
websocketClients map[*websocket.Conn]struct{} websocketClients map[*websocket.Conn]*sync.Mutex
menuManager *menumanager.Manager menuManager *menumanager.Manager
starttime string starttime string
@ -58,6 +58,7 @@ func (d *DevWebServer) Run(ctx context.Context) error {
d.server.Get("/wails/ipc", websocket.New(func(c *websocket.Conn) { d.server.Get("/wails/ipc", websocket.New(func(c *websocket.Conn) {
d.newWebsocketSession(c) d.newWebsocketSession(c)
locker := d.websocketClients[c]
// websocket.Conn bindings https://pkg.go.dev/github.com/fasthttp/websocket?tab=doc#pkg-index // websocket.Conn bindings https://pkg.go.dev/github.com/fasthttp/websocket?tab=doc#pkg-index
var ( var (
mt int mt int
@ -85,9 +86,12 @@ func (d *DevWebServer) Run(ctx context.Context) error {
d.logger.Error(err.Error()) d.logger.Error(err.Error())
} }
if result != "" { if result != "" {
locker.Lock()
if err = c.WriteMessage(mt, []byte(result)); err != nil { if err = c.WriteMessage(mt, []byte(result)); err != nil {
locker.Unlock()
break break
} }
locker.Unlock()
} }
} }
@ -293,7 +297,7 @@ func (d *DevWebServer) newWebsocketSession(c *websocket.Conn) {
d.LogDebug(fmt.Sprintf("Websocket client %p disconnected", c)) d.LogDebug(fmt.Sprintf("Websocket client %p disconnected", c))
return nil return nil
}) })
d.websocketClients[c] = struct{}{} d.websocketClients[c] = &sync.Mutex{}
d.LogDebug(fmt.Sprintf("Websocket client %p connected", c)) d.LogDebug(fmt.Sprintf("Websocket client %p connected", c))
} }
@ -305,16 +309,21 @@ type EventNotify struct {
func (d *DevWebServer) broadcast(message string) { func (d *DevWebServer) broadcast(message string) {
d.socketMutex.Lock() d.socketMutex.Lock()
defer d.socketMutex.Unlock() defer d.socketMutex.Unlock()
for client := range d.websocketClients { for client, locker := range d.websocketClients {
go func() {
if client == nil { if client == nil {
d.logger.Error("Lost connection to websocket server") d.logger.Error("Lost connection to websocket server")
return return
} }
locker.Lock()
err := client.WriteMessage(websocket.TextMessage, []byte(message)) err := client.WriteMessage(websocket.TextMessage, []byte(message))
if err != nil { if err != nil {
locker.Unlock()
d.logger.Error(err.Error()) d.logger.Error(err.Error())
return return
} }
locker.Unlock()
}()
} }
} }
@ -335,15 +344,20 @@ func (d *DevWebServer) notify(name string, data ...interface{}) {
func (d *DevWebServer) broadcastExcludingSender(message string, sender *websocket.Conn) { func (d *DevWebServer) broadcastExcludingSender(message string, sender *websocket.Conn) {
d.socketMutex.Lock() d.socketMutex.Lock()
defer d.socketMutex.Unlock() defer d.socketMutex.Unlock()
for client := range d.websocketClients { for client, locker := range d.websocketClients {
go func() {
if client == sender { if client == sender {
continue return
} }
locker.Lock()
err := client.WriteMessage(websocket.TextMessage, []byte(message)) err := client.WriteMessage(websocket.TextMessage, []byte(message))
if err != nil { if err != nil {
locker.Unlock()
d.logger.Error(err.Error()) d.logger.Error(err.Error())
return return
} }
locker.Unlock()
}()
} }
} }
@ -374,7 +388,7 @@ func NewFrontend(ctx context.Context, appoptions *options.App, myLogger *logger.
DisableStartupMessage: true, DisableStartupMessage: true,
}), }),
menuManager: menuManager, menuManager: menuManager,
websocketClients: make(map[*websocket.Conn]struct{}), websocketClients: make(map[*websocket.Conn]*sync.Mutex),
} }
return result return result
} }