diff --git a/examples/ws-example/main.go b/examples/ws-example/main.go index 3e2b9ff..576d8d9 100644 --- a/examples/ws-example/main.go +++ b/examples/ws-example/main.go @@ -5,6 +5,7 @@ import ( ws2 "github.com/maddalax/htmgo/extensions/websocket/opts" "github.com/maddalax/htmgo/framework/h" "github.com/maddalax/htmgo/framework/service" + "github.com/maddalax/htmgo/framework/session" "io/fs" "net/http" "ws-example/__htmgo" @@ -17,6 +18,11 @@ func main() { ServiceLocator: locator, LiveReload: true, Register: func(app *h.App) { + + app.Use(func(ctx *h.RequestContext) { + session.CreateSession(ctx) + }) + websocket.EnableExtension(app, ws2.ExtensionOpts{ WsPath: "/ws", SessionId: func(ctx *h.RequestContext) string { diff --git a/examples/ws-example/pages/index.go b/examples/ws-example/pages/index.go index 4cf1601..2bf090e 100644 --- a/examples/ws-example/pages/index.go +++ b/examples/ws-example/pages/index.go @@ -5,21 +5,39 @@ import ( "github.com/maddalax/htmgo/extensions/websocket/ws" "github.com/maddalax/htmgo/framework/h" "github.com/maddalax/htmgo/framework/session" + "time" "ws-example/partials" ) func IndexPage(ctx *h.RequestContext) *h.Page { sessionId := session.GetSessionId(ctx) + + ws.Every(ctx, time.Second, func() bool { + return ws.PushElementCtx( + ctx, + h.Div( + h.Attribute("hx-swap-oob", "true"), + h.Id("current-time"), + h.TextF("Current time: %s", time.Now().Format("15:04:05")), + ), + ) + }) + return h.NewPage( RootPage( ctx, h.Div( h.Attribute("ws-connect", fmt.Sprintf("/ws?sessionId=%s", sessionId)), h.Class("flex flex-col gap-4 items-center pt-24 min-h-screen bg-neutral-100"), - h.H3(h.Id("intro-text"), h.Text("Repeater Example"), h.Class("text-2xl")), - + h.H3( + h.Id("intro-text"), + h.Text("Repeater Example"), + h.Class("text-2xl"), + ), + h.Div( + h.Id("current-time"), + ), partials.CounterForm(ctx, partials.CounterProps{Id: "counter-1"}), - partials.Repeater(ctx, partials.RepeaterProps{ Id: "repeater-1", OnAdd: func(data ws.HandlerData) { @@ -38,9 +56,11 @@ func IndexPage(ctx *h.RequestContext) *h.Page { ) }, Item: func(index int) *h.Element { - return h.Input("text", + return h.Input( + "text", h.Class("border border-gray-300 rounded p-2"), - h.Value(fmt.Sprintf("item %d", index))) + h.Value(fmt.Sprintf("item %d", index)), + ) }, }), ), diff --git a/examples/ws-example/partials/index.go b/examples/ws-example/partials/index.go index 15b4466..fae8de9 100644 --- a/examples/ws-example/partials/index.go +++ b/examples/ws-example/partials/index.go @@ -59,11 +59,11 @@ func CounterForm(ctx *h.RequestContext, props CounterProps) *h.Element { h.Class("bg-rose-400 hover:bg-rose-500 text-white font-bold py-2 px-4 rounded"), h.Type("submit"), h.Text("Increment"), - ws.OnServerSideEvent(ctx, "increment", func(data ws.HandlerData) { + ws.OnServerEvent(ctx, "increment", func(data ws.HandlerData) { counter.Increment() ws.PushElement(data, CounterForm(ctx, props)) }), - ws.OnServerSideEvent(ctx, "decrement", func(data ws.HandlerData) { + ws.OnServerEvent(ctx, "decrement", func(data ws.HandlerData) { counter.Decrement() ws.PushElement(data, CounterForm(ctx, props)) }), diff --git a/examples/ws-example/partials/repeater.go b/examples/ws-example/partials/repeater.go index 858bc70..c357c83 100644 --- a/examples/ws-example/partials/repeater.go +++ b/examples/ws-example/partials/repeater.go @@ -31,13 +31,18 @@ func repeaterItem(ctx *h.RequestContext, item *h.Element, index int, props *Repe h.Class("flex gap-2 items-center"), h.Id(id), item, - props.RemoveButton(index, + props.RemoveButton( + index, h.ClassIf(index == 0, "opacity-0 disabled"), - h.If(index == 0, h.Disabled()), + h.If( + index == 0, + h.Disabled(), + ), ws.OnClick(ctx, func(data ws.HandlerData) { props.OnRemove(data, index) props.currentIndex-- - ws.PushElement(data, + ws.PushElement( + data, h.Div( h.Attribute("hx-swap-oob", fmt.Sprintf("delete:#%s", id)), h.Div(), @@ -63,7 +68,8 @@ func Repeater(ctx *h.RequestContext, props RepeaterProps) *h.Element { props.AddButton, ws.OnClick(ctx, func(data ws.HandlerData) { props.OnAdd(data) - ws.PushElement(data, + ws.PushElement( + data, h.Div( h.Attribute("hx-swap-oob", "beforebegin:#"+props.addButtonId()), repeaterItem( diff --git a/extensions/websocket/internal/wsutil/manager.go b/extensions/websocket/internal/wsutil/manager.go index 1c11daf..b9eb566 100644 --- a/extensions/websocket/internal/wsutil/manager.go +++ b/extensions/websocket/internal/wsutil/manager.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/maddalax/htmgo/extensions/websocket/opts" "github.com/puzpuzpuz/xsync/v3" + "log/slog" "strings" "time" ) @@ -139,13 +140,25 @@ func (manager *SocketManager) OnClose(id string) { if socket == nil { return } + slog.Debug("ws-extension: removing socket from manager", slog.String("socketId", id)) manager.dispatch(SocketEvent{ SessionId: id, Type: DisconnectedEvent, RoomId: socket.RoomId, Payload: map[string]any{}, }) - manager.sockets.Delete(id) + roomId, ok := manager.idToRoom.Load(id) + if !ok { + return + } + sockets, ok := manager.sockets.Load(roomId) + if !ok { + return + } + sockets.Delete(id) + manager.idToRoom.Delete(id) + slog.Debug("ws-extension: removed socket from manager", slog.String("socketId", id)) + } func (manager *SocketManager) CloseWithMessage(id string, message string) { @@ -178,11 +191,12 @@ func (manager *SocketManager) Get(id string) *SocketConnection { return &conn } -func (manager *SocketManager) Ping(id string) { +func (manager *SocketManager) Ping(id string) bool { conn := manager.Get(id) if conn != nil { - manager.writeText(*conn, "ping") + return manager.writeText(*conn, "ping") } + return false } func (manager *SocketManager) writeCloseRaw(writer WriterChan, message string) { @@ -198,11 +212,12 @@ func (manager *SocketManager) writeTextRaw(writer WriterChan, message string) { } } -func (manager *SocketManager) writeText(socket SocketConnection, message string) { +func (manager *SocketManager) writeText(socket SocketConnection, message string) bool { if socket.Writer == nil { - return + return false } manager.writeTextRaw(socket.Writer, message) + return true } func (manager *SocketManager) BroadcastText(roomId string, message string, predicate func(conn SocketConnection) bool) { @@ -220,19 +235,21 @@ func (manager *SocketManager) BroadcastText(roomId string, message string, predi }) } -func (manager *SocketManager) SendHtml(id string, message string) { +func (manager *SocketManager) SendHtml(id string, message string) bool { conn := manager.Get(id) minified := strings.ReplaceAll(message, "\n", "") minified = strings.ReplaceAll(minified, "\t", "") minified = strings.TrimSpace(minified) if conn != nil { - manager.writeText(*conn, minified) + return manager.writeText(*conn, minified) } + return false } -func (manager *SocketManager) SendText(id string, message string) { +func (manager *SocketManager) SendText(id string, message string) bool { conn := manager.Get(id) if conn != nil { - manager.writeText(*conn, message) + return manager.writeText(*conn, message) } + return false } diff --git a/extensions/websocket/internal/wsutil/manager_test.go b/extensions/websocket/internal/wsutil/manager_test.go new file mode 100644 index 0000000..0ba95d4 --- /dev/null +++ b/extensions/websocket/internal/wsutil/manager_test.go @@ -0,0 +1,202 @@ +package wsutil + +import ( + ws2 "github.com/maddalax/htmgo/extensions/websocket/opts" + "github.com/maddalax/htmgo/framework/h" + "github.com/stretchr/testify/assert" + "testing" +) + +func createManager() *SocketManager { + return NewSocketManager(&ws2.ExtensionOpts{ + WsPath: "/ws", + SessionId: func(ctx *h.RequestContext) string { + return "test" + }, + }) +} + +func addSocket(manager *SocketManager, roomId string, id string) (socketId string, writer WriterChan, done DoneChan) { + writer = make(chan string, 10) + done = make(chan bool, 10) + manager.Add(roomId, id, writer, done) + return id, writer, done +} + +func TestManager(t *testing.T) { + manager := createManager() + socketId, _, _ := addSocket(manager, "123", "456") + socket := manager.Get(socketId) + assert.NotNil(t, socket) + assert.Equal(t, socketId, socket.Id) + + manager.OnClose(socketId) + socket = manager.Get(socketId) + assert.Nil(t, socket) +} + +func TestManagerForEachSocket(t *testing.T) { + manager := createManager() + addSocket(manager, "all", "456") + addSocket(manager, "all", "789") + var count int + manager.ForEachSocket("all", func(conn SocketConnection) { + count++ + }) + assert.Equal(t, 2, count) +} + +func TestSendText(t *testing.T) { + manager := createManager() + socketId, writer, done := addSocket(manager, "all", "456") + manager.SendText(socketId, "hello") + assert.Equal(t, "hello", <-writer) + manager.SendText(socketId, "hello2") + assert.Equal(t, "hello2", <-writer) + done <- true + assert.Equal(t, true, <-done) +} + +func TestBroadcastText(t *testing.T) { + manager := createManager() + _, w1, d1 := addSocket(manager, "all", "456") + _, w2, d2 := addSocket(manager, "all", "789") + manager.BroadcastText("all", "hello", func(conn SocketConnection) bool { + return true + }) + assert.Equal(t, "hello", <-w1) + assert.Equal(t, "hello", <-w2) + d1 <- true + d2 <- true + assert.Equal(t, true, <-d1) + assert.Equal(t, true, <-d2) +} + +func TestBroadcastTextWithPredicate(t *testing.T) { + manager := createManager() + _, w1, _ := addSocket(manager, "all", "456") + _, w2, _ := addSocket(manager, "all", "789") + manager.BroadcastText("all", "hello", func(conn SocketConnection) bool { + return conn.Id != "456" + }) + + assert.Equal(t, 0, len(w1)) + assert.Equal(t, 1, len(w2)) +} + +func TestSendHtml(t *testing.T) { + manager := createManager() + socketId, writer, _ := addSocket(manager, "all", "456") + rendered := h.Render( + h.Div( + h.P( + h.Text("hello"), + ), + )) + manager.SendHtml(socketId, rendered) + assert.Equal(t, "
hello