diff --git a/extensions/websocket/ws/manager.go b/extensions/websocket/ws/manager.go index a4f5a2d..8d5616b 100644 --- a/extensions/websocket/ws/manager.go +++ b/extensions/websocket/ws/manager.go @@ -55,15 +55,16 @@ type ManagerMetrics struct { } type SocketManager struct { - sockets *xsync.MapOf[string, *xsync.MapOf[string, SocketConnection]] - idToRoom *xsync.MapOf[string, string] - listeners []chan SocketEvent - goroutinesRunning atomic.Int32 - opts *opts.ExtensionOpts - lock sync.Mutex - totalMessages atomic.Int64 - messagesPerSecond int - secondsElapsed int + sockets *xsync.MapOf[string, *xsync.MapOf[string, SocketConnection]] + idToRoom *xsync.MapOf[string, string] + listeners []chan SocketEvent + listenersQueuedForDelete map[chan SocketEvent]bool + goroutinesRunning atomic.Int32 + opts *opts.ExtensionOpts + lock sync.Mutex + totalMessages atomic.Int64 + messagesPerSecond int + secondsElapsed int } func (manager *SocketManager) StartMetrics() { @@ -127,10 +128,11 @@ func SocketManagerFromCtx(ctx *h.RequestContext) *SocketManager { func NewSocketManager(opts *opts.ExtensionOpts) *SocketManager { return &SocketManager{ - sockets: xsync.NewMapOf[string, *xsync.MapOf[string, SocketConnection]](), - idToRoom: xsync.NewMapOf[string, string](), - opts: opts, - goroutinesRunning: atomic.Int32{}, + sockets: xsync.NewMapOf[string, *xsync.MapOf[string, SocketConnection]](), + idToRoom: xsync.NewMapOf[string, string](), + listenersQueuedForDelete: make(map[chan SocketEvent]bool), + opts: opts, + goroutinesRunning: atomic.Int32{}, } } @@ -186,14 +188,7 @@ func (manager *SocketManager) Listen(listener chan SocketEvent) { } func (manager *SocketManager) RemoveListener(listener chan SocketEvent) { - for i, l := range manager.listeners { - if l == listener { - slog.Debug("ws-extension: removed listener from manager") - manager.listeners = append(manager.listeners[:i], manager.listeners[i+1:]...) - slog.Debug("ws-extension: total listeners", slog.Int("count", len(manager.listeners))) - return - } - } + manager.listenersQueuedForDelete[listener] = true } func (manager *SocketManager) dispatch(event SocketEvent) { @@ -208,9 +203,27 @@ func (manager *SocketManager) dispatch(event SocketEvent) { } } }() - for _, listener := range manager.listeners { - listener <- event + + if len(manager.listenersQueuedForDelete) > 0 { + newListener := make([]chan SocketEvent, 0) + for _, listener := range manager.listeners { + if _, ok := manager.listenersQueuedForDelete[listener]; !ok { + newListener = append(newListener, listener) + } + } + manager.listeners = newListener } + + wg := sync.WaitGroup{} + for _, listener := range manager.listeners { + wg.Add(1) + go func() { + defer wg.Done() + listener <- event + }() + } + + wg.Wait() done <- struct{}{} }