From c052a83eceafed07c3ab984d75b070e07cf2b0b8 Mon Sep 17 00:00:00 2001 From: maddalax Date: Mon, 4 Nov 2024 09:34:20 -0600 Subject: [PATCH] fix manager ws delete, add manager tests --- examples/ws-example/main.go | 6 + examples/ws-example/pages/index.go | 30 ++- examples/ws-example/partials/index.go | 4 +- examples/ws-example/partials/repeater.go | 14 +- .../websocket/internal/wsutil/manager.go | 35 ++- .../websocket/internal/wsutil/manager_test.go | 202 ++++++++++++++++++ extensions/websocket/ws/attribute.go | 6 +- extensions/websocket/ws/dispatch.go | 27 ++- extensions/websocket/ws/every.go | 44 ++++ extensions/websocket/ws/handler.go | 2 + framework/h/app.go | 10 + framework/h/attribute.go | 1 - framework/session/state.go | 9 +- 13 files changed, 363 insertions(+), 27 deletions(-) create mode 100644 extensions/websocket/internal/wsutil/manager_test.go create mode 100644 extensions/websocket/ws/every.go diff --git a/examples/ws-example/main.go b/examples/ws-example/main.go index 3e2b9ff..576d8d9 100644 --- a/examples/ws-example/main.go +++ b/examples/ws-example/main.go @@ -5,6 +5,7 @@ import ( ws2 "github.com/maddalax/htmgo/extensions/websocket/opts" "github.com/maddalax/htmgo/framework/h" "github.com/maddalax/htmgo/framework/service" + "github.com/maddalax/htmgo/framework/session" "io/fs" "net/http" "ws-example/__htmgo" @@ -17,6 +18,11 @@ func main() { ServiceLocator: locator, LiveReload: true, Register: func(app *h.App) { + + app.Use(func(ctx *h.RequestContext) { + session.CreateSession(ctx) + }) + websocket.EnableExtension(app, ws2.ExtensionOpts{ WsPath: "/ws", SessionId: func(ctx *h.RequestContext) string { diff --git a/examples/ws-example/pages/index.go b/examples/ws-example/pages/index.go index 4cf1601..2bf090e 100644 --- a/examples/ws-example/pages/index.go +++ b/examples/ws-example/pages/index.go @@ -5,21 +5,39 @@ import ( "github.com/maddalax/htmgo/extensions/websocket/ws" "github.com/maddalax/htmgo/framework/h" "github.com/maddalax/htmgo/framework/session" + "time" "ws-example/partials" ) func IndexPage(ctx *h.RequestContext) *h.Page { sessionId := session.GetSessionId(ctx) + + ws.Every(ctx, time.Second, func() bool { + return ws.PushElementCtx( + ctx, + h.Div( + h.Attribute("hx-swap-oob", "true"), + h.Id("current-time"), + h.TextF("Current time: %s", time.Now().Format("15:04:05")), + ), + ) + }) + return h.NewPage( RootPage( ctx, h.Div( h.Attribute("ws-connect", fmt.Sprintf("/ws?sessionId=%s", sessionId)), h.Class("flex flex-col gap-4 items-center pt-24 min-h-screen bg-neutral-100"), - h.H3(h.Id("intro-text"), h.Text("Repeater Example"), h.Class("text-2xl")), - + h.H3( + h.Id("intro-text"), + h.Text("Repeater Example"), + h.Class("text-2xl"), + ), + h.Div( + h.Id("current-time"), + ), partials.CounterForm(ctx, partials.CounterProps{Id: "counter-1"}), - partials.Repeater(ctx, partials.RepeaterProps{ Id: "repeater-1", OnAdd: func(data ws.HandlerData) { @@ -38,9 +56,11 @@ func IndexPage(ctx *h.RequestContext) *h.Page { ) }, Item: func(index int) *h.Element { - return h.Input("text", + return h.Input( + "text", h.Class("border border-gray-300 rounded p-2"), - h.Value(fmt.Sprintf("item %d", index))) + h.Value(fmt.Sprintf("item %d", index)), + ) }, }), ), diff --git a/examples/ws-example/partials/index.go b/examples/ws-example/partials/index.go index 15b4466..fae8de9 100644 --- a/examples/ws-example/partials/index.go +++ b/examples/ws-example/partials/index.go @@ -59,11 +59,11 @@ func CounterForm(ctx *h.RequestContext, props CounterProps) *h.Element { h.Class("bg-rose-400 hover:bg-rose-500 text-white font-bold py-2 px-4 rounded"), h.Type("submit"), h.Text("Increment"), - ws.OnServerSideEvent(ctx, "increment", func(data ws.HandlerData) { + ws.OnServerEvent(ctx, "increment", func(data ws.HandlerData) { counter.Increment() ws.PushElement(data, CounterForm(ctx, props)) }), - ws.OnServerSideEvent(ctx, "decrement", func(data ws.HandlerData) { + ws.OnServerEvent(ctx, "decrement", func(data ws.HandlerData) { counter.Decrement() ws.PushElement(data, CounterForm(ctx, props)) }), diff --git a/examples/ws-example/partials/repeater.go b/examples/ws-example/partials/repeater.go index 858bc70..c357c83 100644 --- a/examples/ws-example/partials/repeater.go +++ b/examples/ws-example/partials/repeater.go @@ -31,13 +31,18 @@ func repeaterItem(ctx *h.RequestContext, item *h.Element, index int, props *Repe h.Class("flex gap-2 items-center"), h.Id(id), item, - props.RemoveButton(index, + props.RemoveButton( + index, h.ClassIf(index == 0, "opacity-0 disabled"), - h.If(index == 0, h.Disabled()), + h.If( + index == 0, + h.Disabled(), + ), ws.OnClick(ctx, func(data ws.HandlerData) { props.OnRemove(data, index) props.currentIndex-- - ws.PushElement(data, + ws.PushElement( + data, h.Div( h.Attribute("hx-swap-oob", fmt.Sprintf("delete:#%s", id)), h.Div(), @@ -63,7 +68,8 @@ func Repeater(ctx *h.RequestContext, props RepeaterProps) *h.Element { props.AddButton, ws.OnClick(ctx, func(data ws.HandlerData) { props.OnAdd(data) - ws.PushElement(data, + ws.PushElement( + data, h.Div( h.Attribute("hx-swap-oob", "beforebegin:#"+props.addButtonId()), repeaterItem( diff --git a/extensions/websocket/internal/wsutil/manager.go b/extensions/websocket/internal/wsutil/manager.go index 1c11daf..b9eb566 100644 --- a/extensions/websocket/internal/wsutil/manager.go +++ b/extensions/websocket/internal/wsutil/manager.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/maddalax/htmgo/extensions/websocket/opts" "github.com/puzpuzpuz/xsync/v3" + "log/slog" "strings" "time" ) @@ -139,13 +140,25 @@ func (manager *SocketManager) OnClose(id string) { if socket == nil { return } + slog.Debug("ws-extension: removing socket from manager", slog.String("socketId", id)) manager.dispatch(SocketEvent{ SessionId: id, Type: DisconnectedEvent, RoomId: socket.RoomId, Payload: map[string]any{}, }) - manager.sockets.Delete(id) + roomId, ok := manager.idToRoom.Load(id) + if !ok { + return + } + sockets, ok := manager.sockets.Load(roomId) + if !ok { + return + } + sockets.Delete(id) + manager.idToRoom.Delete(id) + slog.Debug("ws-extension: removed socket from manager", slog.String("socketId", id)) + } func (manager *SocketManager) CloseWithMessage(id string, message string) { @@ -178,11 +191,12 @@ func (manager *SocketManager) Get(id string) *SocketConnection { return &conn } -func (manager *SocketManager) Ping(id string) { +func (manager *SocketManager) Ping(id string) bool { conn := manager.Get(id) if conn != nil { - manager.writeText(*conn, "ping") + return manager.writeText(*conn, "ping") } + return false } func (manager *SocketManager) writeCloseRaw(writer WriterChan, message string) { @@ -198,11 +212,12 @@ func (manager *SocketManager) writeTextRaw(writer WriterChan, message string) { } } -func (manager *SocketManager) writeText(socket SocketConnection, message string) { +func (manager *SocketManager) writeText(socket SocketConnection, message string) bool { if socket.Writer == nil { - return + return false } manager.writeTextRaw(socket.Writer, message) + return true } func (manager *SocketManager) BroadcastText(roomId string, message string, predicate func(conn SocketConnection) bool) { @@ -220,19 +235,21 @@ func (manager *SocketManager) BroadcastText(roomId string, message string, predi }) } -func (manager *SocketManager) SendHtml(id string, message string) { +func (manager *SocketManager) SendHtml(id string, message string) bool { conn := manager.Get(id) minified := strings.ReplaceAll(message, "\n", "") minified = strings.ReplaceAll(minified, "\t", "") minified = strings.TrimSpace(minified) if conn != nil { - manager.writeText(*conn, minified) + return manager.writeText(*conn, minified) } + return false } -func (manager *SocketManager) SendText(id string, message string) { +func (manager *SocketManager) SendText(id string, message string) bool { conn := manager.Get(id) if conn != nil { - manager.writeText(*conn, message) + return manager.writeText(*conn, message) } + return false } diff --git a/extensions/websocket/internal/wsutil/manager_test.go b/extensions/websocket/internal/wsutil/manager_test.go new file mode 100644 index 0000000..0ba95d4 --- /dev/null +++ b/extensions/websocket/internal/wsutil/manager_test.go @@ -0,0 +1,202 @@ +package wsutil + +import ( + ws2 "github.com/maddalax/htmgo/extensions/websocket/opts" + "github.com/maddalax/htmgo/framework/h" + "github.com/stretchr/testify/assert" + "testing" +) + +func createManager() *SocketManager { + return NewSocketManager(&ws2.ExtensionOpts{ + WsPath: "/ws", + SessionId: func(ctx *h.RequestContext) string { + return "test" + }, + }) +} + +func addSocket(manager *SocketManager, roomId string, id string) (socketId string, writer WriterChan, done DoneChan) { + writer = make(chan string, 10) + done = make(chan bool, 10) + manager.Add(roomId, id, writer, done) + return id, writer, done +} + +func TestManager(t *testing.T) { + manager := createManager() + socketId, _, _ := addSocket(manager, "123", "456") + socket := manager.Get(socketId) + assert.NotNil(t, socket) + assert.Equal(t, socketId, socket.Id) + + manager.OnClose(socketId) + socket = manager.Get(socketId) + assert.Nil(t, socket) +} + +func TestManagerForEachSocket(t *testing.T) { + manager := createManager() + addSocket(manager, "all", "456") + addSocket(manager, "all", "789") + var count int + manager.ForEachSocket("all", func(conn SocketConnection) { + count++ + }) + assert.Equal(t, 2, count) +} + +func TestSendText(t *testing.T) { + manager := createManager() + socketId, writer, done := addSocket(manager, "all", "456") + manager.SendText(socketId, "hello") + assert.Equal(t, "hello", <-writer) + manager.SendText(socketId, "hello2") + assert.Equal(t, "hello2", <-writer) + done <- true + assert.Equal(t, true, <-done) +} + +func TestBroadcastText(t *testing.T) { + manager := createManager() + _, w1, d1 := addSocket(manager, "all", "456") + _, w2, d2 := addSocket(manager, "all", "789") + manager.BroadcastText("all", "hello", func(conn SocketConnection) bool { + return true + }) + assert.Equal(t, "hello", <-w1) + assert.Equal(t, "hello", <-w2) + d1 <- true + d2 <- true + assert.Equal(t, true, <-d1) + assert.Equal(t, true, <-d2) +} + +func TestBroadcastTextWithPredicate(t *testing.T) { + manager := createManager() + _, w1, _ := addSocket(manager, "all", "456") + _, w2, _ := addSocket(manager, "all", "789") + manager.BroadcastText("all", "hello", func(conn SocketConnection) bool { + return conn.Id != "456" + }) + + assert.Equal(t, 0, len(w1)) + assert.Equal(t, 1, len(w2)) +} + +func TestSendHtml(t *testing.T) { + manager := createManager() + socketId, writer, _ := addSocket(manager, "all", "456") + rendered := h.Render( + h.Div( + h.P( + h.Text("hello"), + ), + )) + manager.SendHtml(socketId, rendered) + assert.Equal(t, "

hello

", <-writer) +} + +func TestOnMessage(t *testing.T) { + manager := createManager() + socketId, _, _ := addSocket(manager, "all", "456") + + listener := make(chan SocketEvent, 10) + + manager.Listen(listener) + + manager.OnMessage(socketId, map[string]any{ + "message": "hello", + }) + + event := <-listener + assert.Equal(t, "hello", event.Payload["message"]) + assert.Equal(t, "456", event.SessionId) + assert.Equal(t, MessageEvent, event.Type) + assert.Equal(t, "all", event.RoomId) +} + +func TestOnClose(t *testing.T) { + manager := createManager() + socketId, _, _ := addSocket(manager, "all", "456") + listener := make(chan SocketEvent, 10) + manager.Listen(listener) + manager.OnClose(socketId) + event := <-listener + assert.Equal(t, "456", event.SessionId) + assert.Equal(t, DisconnectedEvent, event.Type) + assert.Equal(t, "all", event.RoomId) +} + +func TestOnAdd(t *testing.T) { + manager := createManager() + + listener := make(chan SocketEvent, 10) + manager.Listen(listener) + + socketId, _, _ := addSocket(manager, "all", "456") + event := <-listener + + assert.Equal(t, socketId, event.SessionId) + assert.Equal(t, ConnectedEvent, event.Type) + assert.Equal(t, "all", event.RoomId) +} + +func TestCloseWithMessage(t *testing.T) { + manager := createManager() + socketId, w, _ := addSocket(manager, "all", "456") + manager.CloseWithMessage(socketId, "internal error") + assert.Equal(t, "internal error", <-w) + assert.Nil(t, manager.Get(socketId)) +} + +func TestDisconnect(t *testing.T) { + manager := createManager() + socketId, _, _ := addSocket(manager, "all", "456") + manager.Disconnect(socketId) + assert.Nil(t, manager.Get(socketId)) +} + +func TestPing(t *testing.T) { + manager := createManager() + socketId, w, _ := addSocket(manager, "all", "456") + manager.Ping(socketId) + assert.Equal(t, "ping", <-w) +} + +func TestMultipleRooms(t *testing.T) { + manager := createManager() + socketId1, _, _ := addSocket(manager, "room1", "456") + socketId2, _, _ := addSocket(manager, "room2", "789") + + room1Count := 0 + room2Count := 0 + + manager.ForEachSocket("room1", func(conn SocketConnection) { + room1Count++ + }) + + manager.ForEachSocket("room2", func(conn SocketConnection) { + room2Count++ + }) + + assert.Equal(t, 1, room1Count) + assert.Equal(t, 1, room2Count) + + room1Count = 0 + room2Count = 0 + + manager.OnClose(socketId1) + manager.OnClose(socketId2) + + manager.ForEachSocket("room1", func(conn SocketConnection) { + room1Count++ + }) + + manager.ForEachSocket("room2", func(conn SocketConnection) { + room2Count++ + }) + + assert.Equal(t, 0, room1Count) + assert.Equal(t, 0, room2Count) +} diff --git a/extensions/websocket/ws/attribute.go b/extensions/websocket/ws/attribute.go index 696519a..40a048a 100644 --- a/extensions/websocket/ws/attribute.go +++ b/extensions/websocket/ws/attribute.go @@ -6,7 +6,11 @@ func OnClick(ctx *h.RequestContext, handler Handler) *h.AttributeMapOrdered { return AddClientSideHandler(ctx, "click", handler) } -func OnServerSideEvent(ctx *h.RequestContext, eventName string, handler Handler) h.Ren { +func OnClientEvent(ctx *h.RequestContext, eventName string, handler Handler) *h.AttributeMapOrdered { + return AddClientSideHandler(ctx, eventName, handler) +} + +func OnServerEvent(ctx *h.RequestContext, eventName string, handler Handler) h.Ren { AddServerSideHandler(ctx, eventName, handler) return h.Attribute("data-handler-id", "") } diff --git a/extensions/websocket/ws/dispatch.go b/extensions/websocket/ws/dispatch.go index 293bb82..e76fa39 100644 --- a/extensions/websocket/ws/dispatch.go +++ b/extensions/websocket/ws/dispatch.go @@ -1,6 +1,11 @@ package ws -import "github.com/maddalax/htmgo/framework/h" +import ( + "github.com/maddalax/htmgo/extensions/websocket/internal/wsutil" + "github.com/maddalax/htmgo/framework/h" + "github.com/maddalax/htmgo/framework/service" + "github.com/maddalax/htmgo/framework/session" +) // PushServerSideEvent sends a server side event this specific session func PushServerSideEvent(data HandlerData, event string, value map[string]any) { @@ -21,6 +26,22 @@ func BroadcastServerSideEvent(event string, value map[string]any) { } // PushElement sends an element to the current session and swaps it into the page -func PushElement(data HandlerData, el *h.Element) { - data.Manager.SendHtml(data.Socket.Id, h.Render(el)) +func PushElement(data HandlerData, el *h.Element) bool { + return data.Manager.SendHtml(data.Socket.Id, h.Render(el)) +} + +// PushElementCtx sends an element to the current session and swaps it into the page +func PushElementCtx(ctx *h.RequestContext, el *h.Element) bool { + locator := ctx.ServiceLocator() + socketManager := service.Get[wsutil.SocketManager](locator) + socketId := session.GetSessionId(ctx) + socket := socketManager.Get(string(socketId)) + if socket == nil { + return false + } + return PushElement(HandlerData{ + Socket: socket, + Manager: socketManager, + SessionId: socketId, + }, el) } diff --git a/extensions/websocket/ws/every.go b/extensions/websocket/ws/every.go new file mode 100644 index 0000000..1e73fbd --- /dev/null +++ b/extensions/websocket/ws/every.go @@ -0,0 +1,44 @@ +package ws + +import ( + "github.com/maddalax/htmgo/extensions/websocket/internal/wsutil" + "github.com/maddalax/htmgo/framework/h" + "github.com/maddalax/htmgo/framework/service" + "github.com/maddalax/htmgo/framework/session" + "log/slog" + "time" +) + +// Every executes the given callback every interval, until the socket is disconnected, or the callback returns false. +func Every(ctx *h.RequestContext, interval time.Duration, cb func() bool) { + socketId := session.GetSessionId(ctx) + socketIdSlog := slog.String("socketId", string(socketId)) + + slog.Debug("ws-extension: starting every loop", socketIdSlog, slog.Duration("duration", interval)) + + go func() { + tries := 0 + for { + locator := ctx.ServiceLocator() + socketManager := service.Get[wsutil.SocketManager](locator) + socket := socketManager.Get(string(socketId)) + // This can run before the socket is established, lets try a few times and kill it if socket isn't connected after a bit. + if socket == nil { + if tries > 5 { + slog.Debug("ws-extension: socket disconnected, killing goroutine", socketIdSlog) + return + } else { + time.Sleep(time.Second) + tries++ + slog.Debug("ws-extension: socket not connected yet, trying again", socketIdSlog, slog.Int("attempt", tries)) + continue + } + } + success := cb() + if !success { + return + } + time.Sleep(interval) + } + }() +} diff --git a/extensions/websocket/ws/handler.go b/extensions/websocket/ws/handler.go index 72cdff8..a4de0e0 100644 --- a/extensions/websocket/ws/handler.go +++ b/extensions/websocket/ws/handler.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/maddalax/htmgo/extensions/websocket/internal/wsutil" "github.com/maddalax/htmgo/framework/session" + "log/slog" "sync" ) @@ -78,6 +79,7 @@ func (h *MessageHandler) OnDomElementRemoved(handlerId string) { } func (h *MessageHandler) OnSocketDisconnected(event wsutil.SocketEvent) { + slog.Info("ws-extension: socket disconnected", slog.String("socketId", event.SessionId)) sessionId := session.Id(event.SessionId) hashes, ok := sessionIdToHashes.Load(sessionId) if ok { diff --git a/framework/h/app.go b/framework/h/app.go index 3c095a5..f45b43b 100644 --- a/framework/h/app.go +++ b/framework/h/app.go @@ -174,6 +174,16 @@ func (app *App) UseWithContext(h func(w http.ResponseWriter, r *http.Request, co }) } +func (app *App) Use(h func(ctx *RequestContext)) { + app.Router.Use(func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cc := r.Context().Value(RequestContextKey).(*RequestContext) + h(cc) + handler.ServeHTTP(w, r) + }) + }) +} + func GetLogLevel() slog.Level { // Get the log level from the environment variable logLevel := os.Getenv("LOG_LEVEL") diff --git a/framework/h/attribute.go b/framework/h/attribute.go index daab17f..806b408 100644 --- a/framework/h/attribute.go +++ b/framework/h/attribute.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/maddalax/htmgo/framework/datastructure/orderedmap" "github.com/maddalax/htmgo/framework/hx" - "github.com/maddalax/htmgo/framework/internal/datastructure" "github.com/maddalax/htmgo/framework/internal/util" "strings" ) diff --git a/framework/session/state.go b/framework/session/state.go index b3adbd2..24d35ef 100644 --- a/framework/session/state.go +++ b/framework/session/state.go @@ -22,13 +22,18 @@ func NewState(ctx *h.RequestContext) *State { } } +func CreateSession(ctx *h.RequestContext) Id { + sessionId := fmt.Sprintf("session-id-%s", h.GenId(30)) + ctx.Set("session-id", sessionId) + return Id(sessionId) +} + func GetSessionId(ctx *h.RequestContext) Id { sessionIdRaw := ctx.Get("session-id") sessionId := "" if sessionIdRaw == "" || sessionIdRaw == nil { - sessionId = fmt.Sprintf("session-id-%s", h.GenId(30)) - ctx.Set("session-id", sessionId) + panic("session id is not set, please use session.CreateSession(ctx) in middleware to create a session id") } else { sessionId = sessionIdRaw.(string) }