fix manager ws delete, add manager tests

This commit is contained in:
maddalax 2024-11-04 09:34:20 -06:00
parent 546b787a3b
commit c052a83ece
13 changed files with 363 additions and 27 deletions

View file

@ -5,6 +5,7 @@ import (
ws2 "github.com/maddalax/htmgo/extensions/websocket/opts" ws2 "github.com/maddalax/htmgo/extensions/websocket/opts"
"github.com/maddalax/htmgo/framework/h" "github.com/maddalax/htmgo/framework/h"
"github.com/maddalax/htmgo/framework/service" "github.com/maddalax/htmgo/framework/service"
"github.com/maddalax/htmgo/framework/session"
"io/fs" "io/fs"
"net/http" "net/http"
"ws-example/__htmgo" "ws-example/__htmgo"
@ -17,6 +18,11 @@ func main() {
ServiceLocator: locator, ServiceLocator: locator,
LiveReload: true, LiveReload: true,
Register: func(app *h.App) { Register: func(app *h.App) {
app.Use(func(ctx *h.RequestContext) {
session.CreateSession(ctx)
})
websocket.EnableExtension(app, ws2.ExtensionOpts{ websocket.EnableExtension(app, ws2.ExtensionOpts{
WsPath: "/ws", WsPath: "/ws",
SessionId: func(ctx *h.RequestContext) string { SessionId: func(ctx *h.RequestContext) string {

View file

@ -5,21 +5,39 @@ import (
"github.com/maddalax/htmgo/extensions/websocket/ws" "github.com/maddalax/htmgo/extensions/websocket/ws"
"github.com/maddalax/htmgo/framework/h" "github.com/maddalax/htmgo/framework/h"
"github.com/maddalax/htmgo/framework/session" "github.com/maddalax/htmgo/framework/session"
"time"
"ws-example/partials" "ws-example/partials"
) )
func IndexPage(ctx *h.RequestContext) *h.Page { func IndexPage(ctx *h.RequestContext) *h.Page {
sessionId := session.GetSessionId(ctx) sessionId := session.GetSessionId(ctx)
ws.Every(ctx, time.Second, func() bool {
return ws.PushElementCtx(
ctx,
h.Div(
h.Attribute("hx-swap-oob", "true"),
h.Id("current-time"),
h.TextF("Current time: %s", time.Now().Format("15:04:05")),
),
)
})
return h.NewPage( return h.NewPage(
RootPage( RootPage(
ctx, ctx,
h.Div( h.Div(
h.Attribute("ws-connect", fmt.Sprintf("/ws?sessionId=%s", sessionId)), h.Attribute("ws-connect", fmt.Sprintf("/ws?sessionId=%s", sessionId)),
h.Class("flex flex-col gap-4 items-center pt-24 min-h-screen bg-neutral-100"), h.Class("flex flex-col gap-4 items-center pt-24 min-h-screen bg-neutral-100"),
h.H3(h.Id("intro-text"), h.Text("Repeater Example"), h.Class("text-2xl")), h.H3(
h.Id("intro-text"),
h.Text("Repeater Example"),
h.Class("text-2xl"),
),
h.Div(
h.Id("current-time"),
),
partials.CounterForm(ctx, partials.CounterProps{Id: "counter-1"}), partials.CounterForm(ctx, partials.CounterProps{Id: "counter-1"}),
partials.Repeater(ctx, partials.RepeaterProps{ partials.Repeater(ctx, partials.RepeaterProps{
Id: "repeater-1", Id: "repeater-1",
OnAdd: func(data ws.HandlerData) { OnAdd: func(data ws.HandlerData) {
@ -38,9 +56,11 @@ func IndexPage(ctx *h.RequestContext) *h.Page {
) )
}, },
Item: func(index int) *h.Element { Item: func(index int) *h.Element {
return h.Input("text", return h.Input(
"text",
h.Class("border border-gray-300 rounded p-2"), h.Class("border border-gray-300 rounded p-2"),
h.Value(fmt.Sprintf("item %d", index))) h.Value(fmt.Sprintf("item %d", index)),
)
}, },
}), }),
), ),

View file

@ -59,11 +59,11 @@ func CounterForm(ctx *h.RequestContext, props CounterProps) *h.Element {
h.Class("bg-rose-400 hover:bg-rose-500 text-white font-bold py-2 px-4 rounded"), h.Class("bg-rose-400 hover:bg-rose-500 text-white font-bold py-2 px-4 rounded"),
h.Type("submit"), h.Type("submit"),
h.Text("Increment"), h.Text("Increment"),
ws.OnServerSideEvent(ctx, "increment", func(data ws.HandlerData) { ws.OnServerEvent(ctx, "increment", func(data ws.HandlerData) {
counter.Increment() counter.Increment()
ws.PushElement(data, CounterForm(ctx, props)) ws.PushElement(data, CounterForm(ctx, props))
}), }),
ws.OnServerSideEvent(ctx, "decrement", func(data ws.HandlerData) { ws.OnServerEvent(ctx, "decrement", func(data ws.HandlerData) {
counter.Decrement() counter.Decrement()
ws.PushElement(data, CounterForm(ctx, props)) ws.PushElement(data, CounterForm(ctx, props))
}), }),

View file

@ -31,13 +31,18 @@ func repeaterItem(ctx *h.RequestContext, item *h.Element, index int, props *Repe
h.Class("flex gap-2 items-center"), h.Class("flex gap-2 items-center"),
h.Id(id), h.Id(id),
item, item,
props.RemoveButton(index, props.RemoveButton(
index,
h.ClassIf(index == 0, "opacity-0 disabled"), h.ClassIf(index == 0, "opacity-0 disabled"),
h.If(index == 0, h.Disabled()), h.If(
index == 0,
h.Disabled(),
),
ws.OnClick(ctx, func(data ws.HandlerData) { ws.OnClick(ctx, func(data ws.HandlerData) {
props.OnRemove(data, index) props.OnRemove(data, index)
props.currentIndex-- props.currentIndex--
ws.PushElement(data, ws.PushElement(
data,
h.Div( h.Div(
h.Attribute("hx-swap-oob", fmt.Sprintf("delete:#%s", id)), h.Attribute("hx-swap-oob", fmt.Sprintf("delete:#%s", id)),
h.Div(), h.Div(),
@ -63,7 +68,8 @@ func Repeater(ctx *h.RequestContext, props RepeaterProps) *h.Element {
props.AddButton, props.AddButton,
ws.OnClick(ctx, func(data ws.HandlerData) { ws.OnClick(ctx, func(data ws.HandlerData) {
props.OnAdd(data) props.OnAdd(data)
ws.PushElement(data, ws.PushElement(
data,
h.Div( h.Div(
h.Attribute("hx-swap-oob", "beforebegin:#"+props.addButtonId()), h.Attribute("hx-swap-oob", "beforebegin:#"+props.addButtonId()),
repeaterItem( repeaterItem(

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/maddalax/htmgo/extensions/websocket/opts" "github.com/maddalax/htmgo/extensions/websocket/opts"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"log/slog"
"strings" "strings"
"time" "time"
) )
@ -139,13 +140,25 @@ func (manager *SocketManager) OnClose(id string) {
if socket == nil { if socket == nil {
return return
} }
slog.Debug("ws-extension: removing socket from manager", slog.String("socketId", id))
manager.dispatch(SocketEvent{ manager.dispatch(SocketEvent{
SessionId: id, SessionId: id,
Type: DisconnectedEvent, Type: DisconnectedEvent,
RoomId: socket.RoomId, RoomId: socket.RoomId,
Payload: map[string]any{}, Payload: map[string]any{},
}) })
manager.sockets.Delete(id) roomId, ok := manager.idToRoom.Load(id)
if !ok {
return
}
sockets, ok := manager.sockets.Load(roomId)
if !ok {
return
}
sockets.Delete(id)
manager.idToRoom.Delete(id)
slog.Debug("ws-extension: removed socket from manager", slog.String("socketId", id))
} }
func (manager *SocketManager) CloseWithMessage(id string, message string) { func (manager *SocketManager) CloseWithMessage(id string, message string) {
@ -178,11 +191,12 @@ func (manager *SocketManager) Get(id string) *SocketConnection {
return &conn return &conn
} }
func (manager *SocketManager) Ping(id string) { func (manager *SocketManager) Ping(id string) bool {
conn := manager.Get(id) conn := manager.Get(id)
if conn != nil { if conn != nil {
manager.writeText(*conn, "ping") return manager.writeText(*conn, "ping")
} }
return false
} }
func (manager *SocketManager) writeCloseRaw(writer WriterChan, message string) { func (manager *SocketManager) writeCloseRaw(writer WriterChan, message string) {
@ -198,11 +212,12 @@ func (manager *SocketManager) writeTextRaw(writer WriterChan, message string) {
} }
} }
func (manager *SocketManager) writeText(socket SocketConnection, message string) { func (manager *SocketManager) writeText(socket SocketConnection, message string) bool {
if socket.Writer == nil { if socket.Writer == nil {
return return false
} }
manager.writeTextRaw(socket.Writer, message) manager.writeTextRaw(socket.Writer, message)
return true
} }
func (manager *SocketManager) BroadcastText(roomId string, message string, predicate func(conn SocketConnection) bool) { func (manager *SocketManager) BroadcastText(roomId string, message string, predicate func(conn SocketConnection) bool) {
@ -220,19 +235,21 @@ func (manager *SocketManager) BroadcastText(roomId string, message string, predi
}) })
} }
func (manager *SocketManager) SendHtml(id string, message string) { func (manager *SocketManager) SendHtml(id string, message string) bool {
conn := manager.Get(id) conn := manager.Get(id)
minified := strings.ReplaceAll(message, "\n", "") minified := strings.ReplaceAll(message, "\n", "")
minified = strings.ReplaceAll(minified, "\t", "") minified = strings.ReplaceAll(minified, "\t", "")
minified = strings.TrimSpace(minified) minified = strings.TrimSpace(minified)
if conn != nil { if conn != nil {
manager.writeText(*conn, minified) return manager.writeText(*conn, minified)
} }
return false
} }
func (manager *SocketManager) SendText(id string, message string) { func (manager *SocketManager) SendText(id string, message string) bool {
conn := manager.Get(id) conn := manager.Get(id)
if conn != nil { if conn != nil {
manager.writeText(*conn, message) return manager.writeText(*conn, message)
} }
return false
} }

View file

@ -0,0 +1,202 @@
package wsutil
import (
ws2 "github.com/maddalax/htmgo/extensions/websocket/opts"
"github.com/maddalax/htmgo/framework/h"
"github.com/stretchr/testify/assert"
"testing"
)
func createManager() *SocketManager {
return NewSocketManager(&ws2.ExtensionOpts{
WsPath: "/ws",
SessionId: func(ctx *h.RequestContext) string {
return "test"
},
})
}
func addSocket(manager *SocketManager, roomId string, id string) (socketId string, writer WriterChan, done DoneChan) {
writer = make(chan string, 10)
done = make(chan bool, 10)
manager.Add(roomId, id, writer, done)
return id, writer, done
}
func TestManager(t *testing.T) {
manager := createManager()
socketId, _, _ := addSocket(manager, "123", "456")
socket := manager.Get(socketId)
assert.NotNil(t, socket)
assert.Equal(t, socketId, socket.Id)
manager.OnClose(socketId)
socket = manager.Get(socketId)
assert.Nil(t, socket)
}
func TestManagerForEachSocket(t *testing.T) {
manager := createManager()
addSocket(manager, "all", "456")
addSocket(manager, "all", "789")
var count int
manager.ForEachSocket("all", func(conn SocketConnection) {
count++
})
assert.Equal(t, 2, count)
}
func TestSendText(t *testing.T) {
manager := createManager()
socketId, writer, done := addSocket(manager, "all", "456")
manager.SendText(socketId, "hello")
assert.Equal(t, "hello", <-writer)
manager.SendText(socketId, "hello2")
assert.Equal(t, "hello2", <-writer)
done <- true
assert.Equal(t, true, <-done)
}
func TestBroadcastText(t *testing.T) {
manager := createManager()
_, w1, d1 := addSocket(manager, "all", "456")
_, w2, d2 := addSocket(manager, "all", "789")
manager.BroadcastText("all", "hello", func(conn SocketConnection) bool {
return true
})
assert.Equal(t, "hello", <-w1)
assert.Equal(t, "hello", <-w2)
d1 <- true
d2 <- true
assert.Equal(t, true, <-d1)
assert.Equal(t, true, <-d2)
}
func TestBroadcastTextWithPredicate(t *testing.T) {
manager := createManager()
_, w1, _ := addSocket(manager, "all", "456")
_, w2, _ := addSocket(manager, "all", "789")
manager.BroadcastText("all", "hello", func(conn SocketConnection) bool {
return conn.Id != "456"
})
assert.Equal(t, 0, len(w1))
assert.Equal(t, 1, len(w2))
}
func TestSendHtml(t *testing.T) {
manager := createManager()
socketId, writer, _ := addSocket(manager, "all", "456")
rendered := h.Render(
h.Div(
h.P(
h.Text("hello"),
),
))
manager.SendHtml(socketId, rendered)
assert.Equal(t, "<div><p>hello</p></div>", <-writer)
}
func TestOnMessage(t *testing.T) {
manager := createManager()
socketId, _, _ := addSocket(manager, "all", "456")
listener := make(chan SocketEvent, 10)
manager.Listen(listener)
manager.OnMessage(socketId, map[string]any{
"message": "hello",
})
event := <-listener
assert.Equal(t, "hello", event.Payload["message"])
assert.Equal(t, "456", event.SessionId)
assert.Equal(t, MessageEvent, event.Type)
assert.Equal(t, "all", event.RoomId)
}
func TestOnClose(t *testing.T) {
manager := createManager()
socketId, _, _ := addSocket(manager, "all", "456")
listener := make(chan SocketEvent, 10)
manager.Listen(listener)
manager.OnClose(socketId)
event := <-listener
assert.Equal(t, "456", event.SessionId)
assert.Equal(t, DisconnectedEvent, event.Type)
assert.Equal(t, "all", event.RoomId)
}
func TestOnAdd(t *testing.T) {
manager := createManager()
listener := make(chan SocketEvent, 10)
manager.Listen(listener)
socketId, _, _ := addSocket(manager, "all", "456")
event := <-listener
assert.Equal(t, socketId, event.SessionId)
assert.Equal(t, ConnectedEvent, event.Type)
assert.Equal(t, "all", event.RoomId)
}
func TestCloseWithMessage(t *testing.T) {
manager := createManager()
socketId, w, _ := addSocket(manager, "all", "456")
manager.CloseWithMessage(socketId, "internal error")
assert.Equal(t, "internal error", <-w)
assert.Nil(t, manager.Get(socketId))
}
func TestDisconnect(t *testing.T) {
manager := createManager()
socketId, _, _ := addSocket(manager, "all", "456")
manager.Disconnect(socketId)
assert.Nil(t, manager.Get(socketId))
}
func TestPing(t *testing.T) {
manager := createManager()
socketId, w, _ := addSocket(manager, "all", "456")
manager.Ping(socketId)
assert.Equal(t, "ping", <-w)
}
func TestMultipleRooms(t *testing.T) {
manager := createManager()
socketId1, _, _ := addSocket(manager, "room1", "456")
socketId2, _, _ := addSocket(manager, "room2", "789")
room1Count := 0
room2Count := 0
manager.ForEachSocket("room1", func(conn SocketConnection) {
room1Count++
})
manager.ForEachSocket("room2", func(conn SocketConnection) {
room2Count++
})
assert.Equal(t, 1, room1Count)
assert.Equal(t, 1, room2Count)
room1Count = 0
room2Count = 0
manager.OnClose(socketId1)
manager.OnClose(socketId2)
manager.ForEachSocket("room1", func(conn SocketConnection) {
room1Count++
})
manager.ForEachSocket("room2", func(conn SocketConnection) {
room2Count++
})
assert.Equal(t, 0, room1Count)
assert.Equal(t, 0, room2Count)
}

View file

@ -6,7 +6,11 @@ func OnClick(ctx *h.RequestContext, handler Handler) *h.AttributeMapOrdered {
return AddClientSideHandler(ctx, "click", handler) return AddClientSideHandler(ctx, "click", handler)
} }
func OnServerSideEvent(ctx *h.RequestContext, eventName string, handler Handler) h.Ren { func OnClientEvent(ctx *h.RequestContext, eventName string, handler Handler) *h.AttributeMapOrdered {
return AddClientSideHandler(ctx, eventName, handler)
}
func OnServerEvent(ctx *h.RequestContext, eventName string, handler Handler) h.Ren {
AddServerSideHandler(ctx, eventName, handler) AddServerSideHandler(ctx, eventName, handler)
return h.Attribute("data-handler-id", "") return h.Attribute("data-handler-id", "")
} }

View file

@ -1,6 +1,11 @@
package ws package ws
import "github.com/maddalax/htmgo/framework/h" import (
"github.com/maddalax/htmgo/extensions/websocket/internal/wsutil"
"github.com/maddalax/htmgo/framework/h"
"github.com/maddalax/htmgo/framework/service"
"github.com/maddalax/htmgo/framework/session"
)
// PushServerSideEvent sends a server side event this specific session // PushServerSideEvent sends a server side event this specific session
func PushServerSideEvent(data HandlerData, event string, value map[string]any) { func PushServerSideEvent(data HandlerData, event string, value map[string]any) {
@ -21,6 +26,22 @@ func BroadcastServerSideEvent(event string, value map[string]any) {
} }
// PushElement sends an element to the current session and swaps it into the page // PushElement sends an element to the current session and swaps it into the page
func PushElement(data HandlerData, el *h.Element) { func PushElement(data HandlerData, el *h.Element) bool {
data.Manager.SendHtml(data.Socket.Id, h.Render(el)) return data.Manager.SendHtml(data.Socket.Id, h.Render(el))
}
// PushElementCtx sends an element to the current session and swaps it into the page
func PushElementCtx(ctx *h.RequestContext, el *h.Element) bool {
locator := ctx.ServiceLocator()
socketManager := service.Get[wsutil.SocketManager](locator)
socketId := session.GetSessionId(ctx)
socket := socketManager.Get(string(socketId))
if socket == nil {
return false
}
return PushElement(HandlerData{
Socket: socket,
Manager: socketManager,
SessionId: socketId,
}, el)
} }

View file

@ -0,0 +1,44 @@
package ws
import (
"github.com/maddalax/htmgo/extensions/websocket/internal/wsutil"
"github.com/maddalax/htmgo/framework/h"
"github.com/maddalax/htmgo/framework/service"
"github.com/maddalax/htmgo/framework/session"
"log/slog"
"time"
)
// Every executes the given callback every interval, until the socket is disconnected, or the callback returns false.
func Every(ctx *h.RequestContext, interval time.Duration, cb func() bool) {
socketId := session.GetSessionId(ctx)
socketIdSlog := slog.String("socketId", string(socketId))
slog.Debug("ws-extension: starting every loop", socketIdSlog, slog.Duration("duration", interval))
go func() {
tries := 0
for {
locator := ctx.ServiceLocator()
socketManager := service.Get[wsutil.SocketManager](locator)
socket := socketManager.Get(string(socketId))
// This can run before the socket is established, lets try a few times and kill it if socket isn't connected after a bit.
if socket == nil {
if tries > 5 {
slog.Debug("ws-extension: socket disconnected, killing goroutine", socketIdSlog)
return
} else {
time.Sleep(time.Second)
tries++
slog.Debug("ws-extension: socket not connected yet, trying again", socketIdSlog, slog.Int("attempt", tries))
continue
}
}
success := cb()
if !success {
return
}
time.Sleep(interval)
}
}()
}

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/maddalax/htmgo/extensions/websocket/internal/wsutil" "github.com/maddalax/htmgo/extensions/websocket/internal/wsutil"
"github.com/maddalax/htmgo/framework/session" "github.com/maddalax/htmgo/framework/session"
"log/slog"
"sync" "sync"
) )
@ -78,6 +79,7 @@ func (h *MessageHandler) OnDomElementRemoved(handlerId string) {
} }
func (h *MessageHandler) OnSocketDisconnected(event wsutil.SocketEvent) { func (h *MessageHandler) OnSocketDisconnected(event wsutil.SocketEvent) {
slog.Info("ws-extension: socket disconnected", slog.String("socketId", event.SessionId))
sessionId := session.Id(event.SessionId) sessionId := session.Id(event.SessionId)
hashes, ok := sessionIdToHashes.Load(sessionId) hashes, ok := sessionIdToHashes.Load(sessionId)
if ok { if ok {

View file

@ -174,6 +174,16 @@ func (app *App) UseWithContext(h func(w http.ResponseWriter, r *http.Request, co
}) })
} }
func (app *App) Use(h func(ctx *RequestContext)) {
app.Router.Use(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cc := r.Context().Value(RequestContextKey).(*RequestContext)
h(cc)
handler.ServeHTTP(w, r)
})
})
}
func GetLogLevel() slog.Level { func GetLogLevel() slog.Level {
// Get the log level from the environment variable // Get the log level from the environment variable
logLevel := os.Getenv("LOG_LEVEL") logLevel := os.Getenv("LOG_LEVEL")

View file

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"github.com/maddalax/htmgo/framework/datastructure/orderedmap" "github.com/maddalax/htmgo/framework/datastructure/orderedmap"
"github.com/maddalax/htmgo/framework/hx" "github.com/maddalax/htmgo/framework/hx"
"github.com/maddalax/htmgo/framework/internal/datastructure"
"github.com/maddalax/htmgo/framework/internal/util" "github.com/maddalax/htmgo/framework/internal/util"
"strings" "strings"
) )

View file

@ -22,13 +22,18 @@ func NewState(ctx *h.RequestContext) *State {
} }
} }
func CreateSession(ctx *h.RequestContext) Id {
sessionId := fmt.Sprintf("session-id-%s", h.GenId(30))
ctx.Set("session-id", sessionId)
return Id(sessionId)
}
func GetSessionId(ctx *h.RequestContext) Id { func GetSessionId(ctx *h.RequestContext) Id {
sessionIdRaw := ctx.Get("session-id") sessionIdRaw := ctx.Get("session-id")
sessionId := "" sessionId := ""
if sessionIdRaw == "" || sessionIdRaw == nil { if sessionIdRaw == "" || sessionIdRaw == nil {
sessionId = fmt.Sprintf("session-id-%s", h.GenId(30)) panic("session id is not set, please use session.CreateSession(ctx) in middleware to create a session id")
ctx.Set("session-id", sessionId)
} else { } else {
sessionId = sessionIdRaw.(string) sessionId = sessionIdRaw.(string)
} }