fixes with managing the sse connections

This commit is contained in:
maddalax 2024-10-02 10:48:41 -05:00
parent 5b60b9e915
commit 33b4b3299e
4 changed files with 112 additions and 68 deletions

View file

@ -25,7 +25,7 @@ func NewManager(locator *service.Locator) *Manager {
} }
func (m *Manager) StartListener() { func (m *Manager) StartListener() {
c := make(chan ws.SocketEvent) c := make(chan ws.SocketEvent, 1)
m.socketManager.Listen(c) m.socketManager.Listen(c)
for { for {
@ -33,11 +33,13 @@ func (m *Manager) StartListener() {
case event := <-c: case event := <-c:
switch event.Type { switch event.Type {
case ws.ConnectedEvent: case ws.ConnectedEvent:
m.OnConnected(event) go m.OnConnected(event)
case ws.DisconnectedEvent: case ws.DisconnectedEvent:
m.OnDisconnected(event) go m.OnDisconnected(event)
case ws.MessageEvent: case ws.MessageEvent:
m.onMessage(event) go m.onMessage(event)
default:
fmt.Printf("Unknown event type: %s\n", event.Type)
} }
} }
} }
@ -80,7 +82,7 @@ func (m *Manager) OnConnected(e ws.SocketEvent) {
}, },
) )
go m.backFill(e.Id, e.RoomId) m.backFill(e.Id, e.RoomId)
} }
func (m *Manager) OnDisconnected(e ws.SocketEvent) { func (m *Manager) OnDisconnected(e ws.SocketEvent) {

View file

@ -5,10 +5,13 @@ import (
"chat/chat" "chat/chat"
"chat/internal/db" "chat/internal/db"
"chat/ws" "chat/ws"
"fmt"
"github.com/maddalax/htmgo/framework/h" "github.com/maddalax/htmgo/framework/h"
"github.com/maddalax/htmgo/framework/service" "github.com/maddalax/htmgo/framework/service"
"io/fs" "io/fs"
"net/http" "net/http"
"runtime"
"time"
) )
func main() { func main() {
@ -22,6 +25,14 @@ func main() {
chatManager := chat.NewManager(locator) chatManager := chat.NewManager(locator)
go chatManager.StartListener() go chatManager.StartListener()
go func() {
for {
count := runtime.NumGoroutine()
fmt.Printf("goroutines: %d\n", count)
time.Sleep(10 * time.Second)
}
}()
h.Start(h.AppOpts{ h.Start(h.AppOpts{
ServiceLocator: locator, ServiceLocator: locator,
LiveReload: true, LiveReload: true,

View file

@ -7,6 +7,7 @@ import (
"github.com/maddalax/htmgo/framework/service" "github.com/maddalax/htmgo/framework/service"
"log/slog" "log/slog"
"net/http" "net/http"
"sync"
"time" "time"
) )
@ -21,55 +22,79 @@ func Handle() http.HandlerFunc {
cc := r.Context().Value(h.RequestContextKey).(*h.RequestContext) cc := r.Context().Value(h.RequestContextKey).(*h.RequestContext)
locator := cc.ServiceLocator() locator := cc.ServiceLocator()
manager := service.Get[SocketManager](locator) manager := service.Get[SocketManager](locator)
// Flush the headers immediately
flusher, ok := w.(http.Flusher)
sessionCookie, _ := r.Cookie("session_id") sessionCookie, _ := r.Cookie("session_id")
sessionId := ""
if sessionCookie == nil { if sessionCookie != nil {
manager.writeCloseRaw(w, flusher, "no session") sessionId = sessionCookie.Value
return
} }
sessionId := sessionCookie.Value ctx := r.Context()
done := make(chan CloseEvent, 1)
writer := make(WriterChan, 1)
roomId := chi.URLParam(r, "id") wg := sync.WaitGroup{}
wg.Add(1)
if roomId == "" { /*
slog.Error("invalid room", slog.String("room_id", roomId)) * This goroutine is responsible for writing messages to the client
manager.writeCloseRaw(w, flusher, "invalid room") */
return go func() {
} defer wg.Done()
defer manager.Disconnect(sessionId)
done := make(chan CloseEvent, 50)
flush := make(chan bool, 50)
manager.Add(roomId, sessionId, w, done, flush)
defer func() {
manager.Disconnect(sessionId)
}()
if !ok {
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
return
}
ticker := time.NewTicker(5 * time.Second) ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ctx.Done():
return
case reason := <-done:
fmt.Printf("closing connection: %s\n", reason.Reason)
return
case <-ticker.C: case <-ticker.C:
manager.Ping(sessionId) manager.Ping(sessionId)
case <-flush: case message := <-writer:
if flusher != nil { _, err := fmt.Fprintf(w, message)
if err != nil {
done <- CloseEvent{
Code: -1,
Reason: err.Error(),
}
} else {
flusher, ok := w.(http.Flusher)
if ok {
flusher.Flush() flusher.Flush()
} }
case <-done: // Client closed the connection }
fmt.Println("Client disconnected") }
}
}()
/**
* This goroutine is responsible for adding the client to the room
*/
wg.Add(1)
go func() {
defer wg.Done()
if sessionId == "" {
manager.writeCloseRaw(writer, "no session")
return return
} }
}
roomId := chi.URLParam(r, "id")
if roomId == "" {
slog.Error("invalid room", slog.String("room_id", roomId))
manager.writeCloseRaw(writer, "invalid room")
return
}
manager.Add(roomId, sessionId, writer, done)
}()
wg.Wait()
} }
} }

View file

@ -3,10 +3,12 @@ package ws
import ( import (
"fmt" "fmt"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"net/http" "time"
) )
type EventType string type EventType string
type WriterChan chan string
type DoneChan chan CloseEvent
const ( const (
ConnectedEvent EventType = "connected" ConnectedEvent EventType = "connected"
@ -28,10 +30,9 @@ type CloseEvent struct {
type SocketConnection struct { type SocketConnection struct {
Id string Id string
Writer http.ResponseWriter
RoomId string RoomId string
Done chan CloseEvent Done DoneChan
Flush chan bool Writer WriterChan
} }
type SocketManager struct { type SocketManager struct {
@ -62,13 +63,29 @@ func (manager *SocketManager) Listen(listener chan SocketEvent) {
if manager.listeners == nil { if manager.listeners == nil {
manager.listeners = make([]chan SocketEvent, 0) manager.listeners = make([]chan SocketEvent, 0)
} }
if listener != nil {
manager.listeners = append(manager.listeners, listener) manager.listeners = append(manager.listeners, listener)
} }
}
func (manager *SocketManager) dispatch(event SocketEvent) { func (manager *SocketManager) dispatch(event SocketEvent) {
fmt.Printf("dispatching event: %s\n", event.Type)
done := make(chan struct{}, 1)
go func() {
for {
select {
case <-done:
fmt.Printf("dispatched event: %s\n", event.Type)
return
case <-time.After(5 * time.Second):
fmt.Printf("havent dispatched event after 5s, chan blocked: %s\n", event.Type)
}
}
}()
for _, listener := range manager.listeners { for _, listener := range manager.listeners {
listener <- event listener <- event
} }
done <- struct{}{}
} }
func (manager *SocketManager) OnMessage(id string, message map[string]any) { func (manager *SocketManager) OnMessage(id string, message map[string]any) {
@ -84,7 +101,7 @@ func (manager *SocketManager) OnMessage(id string, message map[string]any) {
}) })
} }
func (manager *SocketManager) Add(roomId string, id string, writer http.ResponseWriter, done chan CloseEvent, flush chan bool) { func (manager *SocketManager) Add(roomId string, id string, writer chan string, done chan CloseEvent) {
manager.idToRoom.Store(id, roomId) manager.idToRoom.Store(id, roomId)
sockets, ok := manager.sockets.LoadOrCompute(roomId, func() *xsync.MapOf[string, SocketConnection] { sockets, ok := manager.sockets.LoadOrCompute(roomId, func() *xsync.MapOf[string, SocketConnection] {
@ -96,7 +113,6 @@ func (manager *SocketManager) Add(roomId string, id string, writer http.Response
Writer: writer, Writer: writer,
RoomId: roomId, RoomId: roomId,
Done: done, Done: done,
Flush: flush,
}) })
s, ok := sockets.Load(id) s, ok := sockets.Load(id)
@ -110,6 +126,8 @@ func (manager *SocketManager) Add(roomId string, id string, writer http.Response
RoomId: s.RoomId, RoomId: s.RoomId,
Payload: map[string]any{}, Payload: map[string]any{},
}) })
fmt.Printf("User %s connected to %s\n", id, roomId)
} }
func (manager *SocketManager) OnClose(id string) { func (manager *SocketManager) OnClose(id string) {
@ -141,7 +159,7 @@ func (manager *SocketManager) CloseWithMessage(id string, message string) {
func (manager *SocketManager) Disconnect(id string) { func (manager *SocketManager) Disconnect(id string) {
conn := manager.Get(id) conn := manager.Get(id)
if conn != nil { if conn != nil {
go manager.OnClose(id) manager.OnClose(id)
conn.Done <- CloseEvent{ conn.Done <- CloseEvent{
Code: -1, Code: -1,
Reason: "", Reason: "",
@ -169,35 +187,23 @@ func (manager *SocketManager) Ping(id string) {
} }
} }
func (manager *SocketManager) writeCloseRaw(writer http.ResponseWriter, flusher http.Flusher, message string) { func (manager *SocketManager) writeCloseRaw(writer WriterChan, message string) {
err := manager.writeTextRaw(writer, "close", message) manager.writeTextRaw(writer, "close", message)
if err == nil {
flusher.Flush()
}
} }
func (manager *SocketManager) writeTextRaw(writer http.ResponseWriter, event string, message string) error { func (manager *SocketManager) writeTextRaw(writer WriterChan, event string, message string) {
if writer == nil {
return nil
}
var err error
if event != "" { if event != "" {
_, err = fmt.Fprintf(writer, "event: %s\ndata: %s\n\n", event, message) writer <- fmt.Sprintf("event: %s\ndata: %s\n\n", event, message)
} else { } else {
_, err = fmt.Fprintf(writer, "data: %s\n\n", message) writer <- fmt.Sprintf("data: %s\n\n", message)
} }
return err
} }
func (manager *SocketManager) writeText(socket SocketConnection, event string, message string) { func (manager *SocketManager) writeText(socket SocketConnection, event string, message string) {
if socket.Writer == nil { if socket.Writer == nil {
return return
} }
err := manager.writeTextRaw(socket.Writer, event, message) manager.writeTextRaw(socket.Writer, event, message)
if err != nil && event != "close" {
manager.CloseWithMessage(socket.Id, "failed to write message")
}
socket.Flush <- true
} }
func (manager *SocketManager) BroadcastText(roomId string, message string, predicate func(conn SocketConnection) bool) { func (manager *SocketManager) BroadcastText(roomId string, message string, predicate func(conn SocketConnection) bool) {