diff --git a/examples/chat/chat/broadcast.go b/examples/chat/chat/broadcast.go index 7e61aff..3042704 100644 --- a/examples/chat/chat/broadcast.go +++ b/examples/chat/chat/broadcast.go @@ -5,20 +5,22 @@ import ( "chat/ws" "context" "fmt" - "github.com/google/uuid" "github.com/maddalax/htmgo/framework/h" "github.com/maddalax/htmgo/framework/service" + "time" ) type Manager struct { socketManager *ws.SocketManager queries *db.Queries + service *Service } -func NewManager(loader *service.Locator) *Manager { +func NewManager(locator *service.Locator) *Manager { return &Manager{ - socketManager: service.Get[ws.SocketManager](loader), - queries: service.Get[db.Queries](loader), + socketManager: service.Get[ws.SocketManager](locator), + queries: service.Get[db.Queries](locator), + service: NewService(locator), } } @@ -32,61 +34,58 @@ func (m *Manager) StartListener() { switch event.Type { case ws.ConnectedEvent: fmt.Printf("User %s connected\n", event.Id) - m.backFill(event.Id) + m.backFill(event.Id, event.RoomId) case ws.DisconnectedEvent: fmt.Printf("User %s disconnected\n", event.Id) case ws.MessageEvent: - m.onMessage(event.Id, event.Payload) + m.onMessage(event) } } } } -func (m *Manager) backFill(socketId string) { +func (m *Manager) backFill(socketId string, roomId string) { messages, _ := m.queries.GetLastMessages(context.Background(), db.GetLastMessagesParams{ - ChatRoomID: "4ccc3f90a27c9375c98477571034b2e1", - Limit: 50, + ChatRoomID: roomId, + Limit: 200, }) for _, message := range messages { + parsed, _ := time.Parse("2006-01-02 15:04:05", message.CreatedAt) m.socketManager.SendText(socketId, - h.Render(MessageRow(message.Message)), + h.Render(MessageRow(&Message{ + UserId: message.UserID, + UserName: message.UserName, + Message: message.Message, + CreatedAt: parsed, + })), ) } } -func (m *Manager) onMessage(socketId string, payload map[string]any) { - fmt.Printf("Received message from %s: %v\n", socketId, payload) - message := payload["message"].(string) +func (m *Manager) onMessage(e ws.SocketEvent) { + fmt.Printf("Received message from %s: %v\n", e.Id, e.Payload) + message := e.Payload["message"].(string) if message == "" { return } - ctx := context.Background() - - user, err := m.queries.CreateUser(ctx, uuid.NewString()) + user, err := m.queries.GetUserBySessionId(context.Background(), e.Id) if err != nil { - fmt.Printf("Error creating user: %v\n", err) - return - } - //chat, _ := m.queries.CreateChatRoom(ctx, "General") - - err = m.queries.InsertMessage( - context.Background(), - db.InsertMessageParams{ - ChatRoomID: "4ccc3f90a27c9375c98477571034b2e1", - UserID: user.ID, - Message: message, - }, - ) - - if err != nil { - fmt.Printf("Error inserting message: %v\n", err) + fmt.Printf("Error getting user: %v\n", err) return } - m.socketManager.BroadcastText( - h.Render(MessageRow(message)), + saved := m.service.InsertMessage( + &user, + e.RoomId, + message, ) + + if saved != nil { + m.socketManager.BroadcastText( + h.Render(MessageRow(saved)), + ) + } } diff --git a/examples/chat/chat/component.go b/examples/chat/chat/component.go index 2a8c066..4d59429 100644 --- a/examples/chat/chat/component.go +++ b/examples/chat/chat/component.go @@ -1,12 +1,20 @@ package chat -import "github.com/maddalax/htmgo/framework/h" +import ( + "github.com/maddalax/htmgo/framework/h" + "time" +) -func MessageRow(text string) *h.Element { +func MessageRow(message *Message) *h.Element { return h.Div( h.Attribute("hx-swap-oob", "beforeend"), h.Class("flex flex-col gap-2 w-full"), h.Id("messages"), - h.Pf(text), + h.Div( + h.Class("flex gap-2 items-center"), + h.Pf(message.UserName), + h.Pf(message.CreatedAt.In(time.Local).Format("01/02 03:04 PM")), + h.Pf(message.Message), + ), ) } diff --git a/examples/chat/chat/service.go b/examples/chat/chat/service.go index 465b2ec..57adf34 100644 --- a/examples/chat/chat/service.go +++ b/examples/chat/chat/service.go @@ -3,9 +3,20 @@ package chat import ( "chat/internal/db" "context" + "fmt" + "github.com/google/uuid" "github.com/maddalax/htmgo/framework/service" + "log" + "time" ) +type Message struct { + UserId int64 `json:"userId"` + UserName string `json:"userName"` + Message string `json:"message"` + CreatedAt time.Time `json:"createdAt"` +} + type Service struct { queries *db.Queries } @@ -16,6 +27,54 @@ func NewService(locator *service.Locator) *Service { } } +func (s *Service) InsertMessage(user *db.User, roomId string, message string) *Message { + err := s.queries.InsertMessage(context.Background(), db.InsertMessageParams{ + UserID: user.ID, + Username: user.Name, + ChatRoomID: roomId, + Message: message, + }) + if err != nil { + log.Printf("Failed to insert message: %v\n", err) + return nil + } + return &Message{ + UserId: user.ID, + UserName: user.Name, + Message: message, + CreatedAt: time.Now(), + } +} + +func (s *Service) GetUserBySession(sessionId string) (*db.User, error) { + user, err := s.queries.GetUserBySessionId(context.Background(), sessionId) + return &user, err +} + +func (s *Service) CreateUser(name string) (*db.CreateUserRow, error) { + nameWithHash := fmt.Sprintf("%s#%s", name, uuid.NewString()[0:4]) + sessionId := fmt.Sprintf("session-%s-%s", uuid.NewString(), uuid.NewString()) + user, err := s.queries.CreateUser(context.Background(), db.CreateUserParams{ + Name: nameWithHash, + SessionID: sessionId, + }) + if err != nil { + return nil, err + } + return &user, nil +} + +func (s *Service) CreateRoom(name string) (*db.CreateChatRoomRow, error) { + room, err := s.queries.CreateChatRoom(context.Background(), db.CreateChatRoomParams{ + ID: fmt.Sprintf("room-%s-%s", uuid.NewString()[0:8], name), + Name: name, + }) + if err != nil { + return nil, err + } + return &room, nil +} + func (s *Service) GetRoom(id string) (*db.ChatRoom, error) { room, err := s.queries.GetChatRoom(context.Background(), id) if err != nil { diff --git a/examples/chat/internal/db/models.go b/examples/chat/internal/db/models.go index 71aadfb..7b3a08a 100644 --- a/examples/chat/internal/db/models.go +++ b/examples/chat/internal/db/models.go @@ -20,6 +20,7 @@ type Message struct { ID int64 ChatRoomID string UserID int64 + Username string Message string CreatedAt string UpdatedAt string @@ -30,4 +31,5 @@ type User struct { Name string CreatedAt string UpdatedAt string + SessionID string } diff --git a/examples/chat/internal/db/queries.sql b/examples/chat/internal/db/queries.sql index 597223e..46e51d8 100644 --- a/examples/chat/internal/db/queries.sql +++ b/examples/chat/internal/db/queries.sql @@ -1,12 +1,12 @@ -- name: CreateChatRoom :one -INSERT INTO chat_rooms (name, created_at, updated_at) -VALUES (?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) +INSERT INTO chat_rooms (id, name, created_at, updated_at) +VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) RETURNING id, name, created_at, updated_at, last_message_sent_at; -- name: InsertMessage :exec -INSERT INTO messages (chat_room_id, user_id, message, created_at, updated_at) -VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) -RETURNING id, chat_room_id, user_id, message, created_at, updated_at; +INSERT INTO messages (chat_room_id, user_id, username, message, created_at, updated_at) +VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) +RETURNING id, chat_room_id, user_id, username, message, created_at, updated_at; -- name: UpdateChatRoomLastMessageSentAt :exec UPDATE chat_rooms @@ -24,9 +24,9 @@ FROM chat_rooms WHERE chat_rooms.id = ?; -- name: CreateUser :one -INSERT INTO users (name, created_at, updated_at) -VALUES (?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) -RETURNING id, name, created_at, updated_at; +INSERT INTO users (name, session_id, created_at, updated_at) +VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) +RETURNING id, name, session_id, created_at, updated_at; -- name: GetLastMessages :many SELECT @@ -42,3 +42,6 @@ FROM messages WHERE messages.chat_room_id = ? ORDER BY messages.created_at LIMIT ?; + +-- name: GetUserBySessionId :one +SELECT * FROM users WHERE session_id = ?; diff --git a/examples/chat/internal/db/queries.sql.go b/examples/chat/internal/db/queries.sql.go index 8d7cc09..42ee1af 100644 --- a/examples/chat/internal/db/queries.sql.go +++ b/examples/chat/internal/db/queries.sql.go @@ -11,11 +11,16 @@ import ( ) const createChatRoom = `-- name: CreateChatRoom :one -INSERT INTO chat_rooms (name, created_at, updated_at) -VALUES (?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) +INSERT INTO chat_rooms (id, name, created_at, updated_at) +VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) RETURNING id, name, created_at, updated_at, last_message_sent_at ` +type CreateChatRoomParams struct { + ID string + Name string +} + type CreateChatRoomRow struct { ID string Name string @@ -24,8 +29,8 @@ type CreateChatRoomRow struct { LastMessageSentAt sql.NullString } -func (q *Queries) CreateChatRoom(ctx context.Context, name string) (CreateChatRoomRow, error) { - row := q.db.QueryRowContext(ctx, createChatRoom, name) +func (q *Queries) CreateChatRoom(ctx context.Context, arg CreateChatRoomParams) (CreateChatRoomRow, error) { + row := q.db.QueryRowContext(ctx, createChatRoom, arg.ID, arg.Name) var i CreateChatRoomRow err := row.Scan( &i.ID, @@ -38,17 +43,31 @@ func (q *Queries) CreateChatRoom(ctx context.Context, name string) (CreateChatRo } const createUser = `-- name: CreateUser :one -INSERT INTO users (name, created_at, updated_at) -VALUES (?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) -RETURNING id, name, created_at, updated_at +INSERT INTO users (name, session_id, created_at, updated_at) +VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) +RETURNING id, name, session_id, created_at, updated_at ` -func (q *Queries) CreateUser(ctx context.Context, name string) (User, error) { - row := q.db.QueryRowContext(ctx, createUser, name) - var i User +type CreateUserParams struct { + Name string + SessionID string +} + +type CreateUserRow struct { + ID int64 + Name string + SessionID string + CreatedAt string + UpdatedAt string +} + +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (CreateUserRow, error) { + row := q.db.QueryRowContext(ctx, createUser, arg.Name, arg.SessionID) + var i CreateUserRow err := row.Scan( &i.ID, &i.Name, + &i.SessionID, &i.CreatedAt, &i.UpdatedAt, ) @@ -141,20 +160,43 @@ func (q *Queries) GetLastMessages(ctx context.Context, arg GetLastMessagesParams return items, nil } +const getUserBySessionId = `-- name: GetUserBySessionId :one +SELECT id, name, created_at, updated_at, session_id FROM users WHERE session_id = ? +` + +func (q *Queries) GetUserBySessionId(ctx context.Context, sessionID string) (User, error) { + row := q.db.QueryRowContext(ctx, getUserBySessionId, sessionID) + var i User + err := row.Scan( + &i.ID, + &i.Name, + &i.CreatedAt, + &i.UpdatedAt, + &i.SessionID, + ) + return i, err +} + const insertMessage = `-- name: InsertMessage :exec -INSERT INTO messages (chat_room_id, user_id, message, created_at, updated_at) -VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) -RETURNING id, chat_room_id, user_id, message, created_at, updated_at +INSERT INTO messages (chat_room_id, user_id, username, message, created_at, updated_at) +VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) +RETURNING id, chat_room_id, user_id, username, message, created_at, updated_at ` type InsertMessageParams struct { ChatRoomID string UserID int64 + Username string Message string } func (q *Queries) InsertMessage(ctx context.Context, arg InsertMessageParams) error { - _, err := q.db.ExecContext(ctx, insertMessage, arg.ChatRoomID, arg.UserID, arg.Message) + _, err := q.db.ExecContext(ctx, insertMessage, + arg.ChatRoomID, + arg.UserID, + arg.Username, + arg.Message, + ) return err } diff --git a/examples/chat/internal/db/schema.sql b/examples/chat/internal/db/schema.sql index 01c0f80..faf0c14 100644 --- a/examples/chat/internal/db/schema.sql +++ b/examples/chat/internal/db/schema.sql @@ -3,12 +3,13 @@ CREATE TABLE IF NOT EXISTS users id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL UNIQUE, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP + updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, + session_id TEXT NOT NULL ) STRICT; CREATE TABLE IF NOT EXISTS chat_rooms ( - id TEXT PRIMARY KEY DEFAULT (lower(hex(randomblob(16)))), -- Generates a UUID + id TEXT PRIMARY KEY, name TEXT NOT NULL, last_message_sent_at TEXT, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -20,6 +21,7 @@ CREATE TABLE IF NOT EXISTS messages id INTEGER PRIMARY KEY AUTOINCREMENT, chat_room_id TEXT NOT NULL, user_id INTEGER NOT NULL, + username TEXT NOT NULL, message TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, diff --git a/examples/chat/main.go b/examples/chat/main.go index dc16ca7..103077c 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -39,7 +39,7 @@ func main() { http.FileServerFS(sub) app.Router.Handle("/public/*", http.StripPrefix("/public", http.FileServerFS(sub))) - app.Router.Handle("/chat", ws.Handle()) + app.Router.Handle("/ws/chat/{id}", ws.Handle()) __htmgo.Register(app.Router) }, diff --git a/examples/chat/pages/chat.$id.go b/examples/chat/pages/chat.$id.go index f695264..86d3373 100644 --- a/examples/chat/pages/chat.$id.go +++ b/examples/chat/pages/chat.$id.go @@ -1,12 +1,15 @@ package pages import ( + "fmt" + "github.com/go-chi/chi/v5" "github.com/maddalax/htmgo/framework/h" "github.com/maddalax/htmgo/framework/hx" "github.com/maddalax/htmgo/framework/js" ) func ChatRoom(ctx *h.RequestContext) *h.Page { + roomId := chi.URLParam(ctx.Request, "id") return h.NewPage( RootPage( h.Div( @@ -14,13 +17,32 @@ func ChatRoom(ctx *h.RequestContext) *h.Page { h.TriggerChildren(), h.HxExtension("ws"), ), - h.Attribute("ws-connect", "/chat"), - h.Class("flex flex-col gap-4 items-center pt-24 min-h-screen bg-neutral-100"), - Form(ctx), + h.Attribute("ws-connect", fmt.Sprintf("/ws/chat/%s", roomId)), + h.Class("flex flex-row gap-4 min-h-screen bg-neutral-100"), + + // Sidebar for connected users + UserSidebar(), + + // Chat Area h.Div( + h.Class("flex flex-col flex-grow gap-4 bg-white shadow-md rounded-lg p-4"), + + h.OnEvent("hx-on::ws-after-message", + // language=JavaScript + js.EvalJsOnSibling("#messages", ` + element.scrollTop = element.scrollHeight; + `)), + + // Chat Messages h.Div( h.Id("messages"), - h.Class("flex flex-col gap-2 w-full"), + h.Class("flex flex-col gap-2 overflow-auto grow w-full"), + ), + + // Chat Input at the bottom + h.Div( + h.Class("mt-auto"), + Form(ctx), ), ), ), @@ -28,13 +50,27 @@ func ChatRoom(ctx *h.RequestContext) *h.Page { ) } +func UserSidebar() *h.Element { + return h.Div( + h.Class("w-64 bg-slate-200 p-4 flex flex-col gap-4 rounded-l-lg"), + h.H2F("Connected Users", h.Class("text-lg font-bold")), + h.Ul( + h.Class("flex flex-col gap-2"), + // This would be populated dynamically with connected users + h.Li(h.Text("User 1"), h.Class("text-slate-700")), + h.Li(h.Text("User 2"), h.Class("text-slate-700")), + h.Li(h.Text("User 3"), h.Class("text-slate-700")), + ), + ) +} + func MessageInput() *h.Element { return h.Input("text", h.Id("message-input"), h.Required(), - h.Class("p-4 rounded-md border border-slate-200"), + h.Class("p-4 rounded-md border border-slate-200 w-full"), h.Name("message"), - h.Placeholder("Message"), + h.Placeholder("Type a message..."), h.HxBeforeWsSend( js.SetValue(""), ), @@ -44,40 +80,11 @@ func MessageInput() *h.Element { func Form(ctx *h.RequestContext) *h.Element { return h.Div( - h.Class("flex flex-col items-center justify-center p-4 gap-6"), - h.H2F("Form submission with ws example", h.Class("text-2xl font-bold")), + h.Class("flex gap-4 items-center"), h.Form( h.Attribute("ws-send", ""), - h.Class("flex flex-col gap-2"), - h.LabelFor("name", "Your Message"), + h.Class("flex flex-grow"), MessageInput(), - SubmitButton(), - ), - ) -} - -func SubmitButton() *h.Element { - buttonClasses := "rounded items-center px-3 py-2 bg-slate-800 text-white w-full text-center" - return h.Div( - h.HxBeforeRequest( - js.RemoveClassOnChildren(".loading", "hidden"), - js.SetClassOnChildren(".submit", "hidden"), - ), - h.HxAfterRequest( - js.SetClassOnChildren(".loading", "hidden"), - js.RemoveClassOnChildren(".submit", "hidden"), - ), - h.Class("flex gap-2 justify-center"), - h.Button( - h.Class("loading hidden relative text-center", buttonClasses), - Spinner(), - h.Disabled(), - h.Text("Submitting..."), - ), - h.Button( - h.Type("submit"), - h.Class("submit", buttonClasses), - h.Text("Submit"), ), ) } @@ -85,7 +92,7 @@ func SubmitButton() *h.Element { func Spinner(children ...h.Ren) *h.Element { return h.Div( h.Children(children...), - h.Class("absolute left-1 spinner spinner-border animate-spin inline-block w-6 h-6 border-4 rounded-full border-slate-200 border-t-transparent"), + h.Class("spinner spinner-border animate-spin w-4 h-4 border-2 border-t-transparent"), h.Attribute("role", "status"), ) } diff --git a/examples/chat/partials/index.go b/examples/chat/partials/index.go index d9b3c98..008bb2f 100644 --- a/examples/chat/partials/index.go +++ b/examples/chat/partials/index.go @@ -4,26 +4,59 @@ import ( "chat/chat" "chat/components" "github.com/maddalax/htmgo/framework/h" + "net/http" ) func CreateOrJoinRoom(ctx *h.RequestContext) *h.Partial { locator := ctx.ServiceLocator() service := chat.NewService(locator) - chatRoomId := ctx.FormValue("join-chat-room") + chatRoomId := ctx.Request.FormValue("join-chat-room") + username := ctx.Request.FormValue("username") + + if username == "" { + return h.SwapPartial(ctx, components.FormError("Username is required")) + } + + user, err := service.CreateUser(username) + + if err != nil { + return h.SwapPartial(ctx, components.FormError("Failed to create user")) + } + + var redirect = func(path string) *h.Partial { + cookie := &http.Cookie{ + Name: "session_id", + Value: user.SessionID, + Path: "/", + } + return h.SwapManyPartialWithHeaders( + ctx, + h.NewHeaders( + "Set-Cookie", cookie.String(), + "HX-Redirect", path, + ), + h.Fragment(), + ) + } if chatRoomId != "" { room, _ := service.GetRoom(chatRoomId) if room == nil { return h.SwapPartial(ctx, components.FormError("Room not found")) } else { - return h.RedirectPartial("/chat/" + chatRoomId) + return redirect("/chat/" + chatRoomId) } } - chatRoomName := ctx.FormValue("chat-room-name") + chatRoomName := ctx.Request.FormValue("new-chat-room") if chatRoomName != "" { - // create room + room, _ := service.CreateRoom(chatRoomName) + if room == nil { + return h.SwapPartial(ctx, components.FormError("Failed to create room")) + } else { + return redirect("/chat/" + room.ID) + } } return h.SwapPartial(ctx, components.FormError("Create a new room or join an existing one")) diff --git a/examples/chat/ws/handler.go b/examples/chat/ws/handler.go index cbe04be..8ca07de 100644 --- a/examples/chat/ws/handler.go +++ b/examples/chat/ws/handler.go @@ -4,7 +4,7 @@ import ( "context" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" - "github.com/google/uuid" + "github.com/go-chi/chi/v5" "github.com/maddalax/htmgo/framework/h" "github.com/maddalax/htmgo/framework/service" "net/http" @@ -12,8 +12,20 @@ import ( func Handle() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, nil) cc := r.Context().Value(h.RequestContextKey).(*h.RequestContext) + + sessionCookie, err := r.Cookie("session_id") + + cookies := r.Cookies() + + println(cookies) + // no session + if err != nil { + return + } + + c, err := websocket.Accept(w, r, nil) + locator := cc.ServiceLocator() manager := service.Get[SocketManager](locator) @@ -21,22 +33,30 @@ func Handle() http.HandlerFunc { return } - id := uuid.NewString() - manager.Add(id, c) + sessionId := sessionCookie.Value + + roomId := chi.URLParam(r, "id") + + if roomId == "" { + manager.CloseWithError(sessionId, "invalid room") + return + } + + manager.Add(roomId, sessionId, c) defer func() { - manager.Disconnect(id) + manager.Disconnect(sessionId) }() for { var v map[string]any err = wsjson.Read(context.Background(), c, &v) if err != nil { - manager.CloseWithError(id, "failed to read message") + manager.CloseWithError(sessionId, "failed to read message") return } if v != nil { - manager.OnMessage(id, v) + manager.OnMessage(sessionId, v) } } diff --git a/examples/chat/ws/manager.go b/examples/chat/ws/manager.go index 1cd60b8..4098dee 100644 --- a/examples/chat/ws/manager.go +++ b/examples/chat/ws/manager.go @@ -17,18 +17,25 @@ const ( type SocketEvent struct { Id string + RoomId string Type EventType Payload map[string]any } +type SocketConnection struct { + Id string + Conn *websocket.Conn + RoomId string +} + type SocketManager struct { - sockets *xsync.MapOf[string, *websocket.Conn] + sockets *xsync.MapOf[string, SocketConnection] listeners []chan SocketEvent } func NewSocketManager() *SocketManager { return &SocketManager{ - sockets: xsync.NewMapOf[string, *websocket.Conn](), + sockets: xsync.NewMapOf[string, SocketConnection](), } } @@ -46,26 +53,41 @@ func (manager *SocketManager) dispatch(event SocketEvent) { } func (manager *SocketManager) OnMessage(id string, message map[string]any) { + socket := manager.Get(id) + if socket == nil { + return + } manager.dispatch(SocketEvent{ Id: id, Type: MessageEvent, Payload: message, + RoomId: socket.RoomId, }) } -func (manager *SocketManager) Add(id string, conn *websocket.Conn) { - manager.sockets.Store(id, conn) +func (manager *SocketManager) Add(roomId string, id string, conn *websocket.Conn) { + manager.sockets.Store(id, SocketConnection{ + Id: id, + Conn: conn, + RoomId: roomId, + }) manager.dispatch(SocketEvent{ Id: id, Type: ConnectedEvent, + RoomId: roomId, Payload: map[string]any{}, }) } func (manager *SocketManager) OnClose(id string) { + socket := manager.Get(id) + if socket == nil { + return + } manager.dispatch(SocketEvent{ Id: id, Type: DisconnectedEvent, + RoomId: socket.RoomId, Payload: map[string]any{}, }) manager.sockets.Delete(id) @@ -75,7 +97,7 @@ func (manager *SocketManager) CloseWithError(id string, message string) { conn := manager.Get(id) if conn != nil { defer manager.OnClose(id) - conn.Close(websocket.StatusInternalError, message) + conn.Conn.Close(websocket.StatusInternalError, message) } } @@ -83,19 +105,22 @@ func (manager *SocketManager) Disconnect(id string) { conn := manager.Get(id) if conn != nil { defer manager.OnClose(id) - _ = conn.CloseNow() + _ = conn.Conn.CloseNow() } } -func (manager *SocketManager) Get(id string) *websocket.Conn { - conn, _ := manager.sockets.Load(id) - return conn +func (manager *SocketManager) Get(id string) *SocketConnection { + conn, ok := manager.sockets.Load(id) + if !ok { + return nil + } + return &conn } func (manager *SocketManager) Broadcast(message []byte, messageType websocket.MessageType) { ctx := context.Background() - manager.sockets.Range(func(id string, conn *websocket.Conn) bool { - err := conn.Write(ctx, messageType, message) + manager.sockets.Range(func(id string, conn SocketConnection) bool { + err := conn.Conn.Write(ctx, messageType, message) if err != nil { manager.Disconnect(id) } @@ -111,6 +136,6 @@ func (manager *SocketManager) BroadcastText(message string) { func (manager *SocketManager) SendText(id string, message string) { conn := manager.Get(id) if conn != nil { - _ = conn.Write(context.Background(), websocket.MessageText, []byte(message)) + _ = conn.Conn.Write(context.Background(), websocket.MessageText, []byte(message)) } } diff --git a/framework/h/app.go b/framework/h/app.go index 1780a76..51547cf 100644 --- a/framework/h/app.go +++ b/framework/h/app.go @@ -16,7 +16,8 @@ import ( ) type RequestContext struct { - *http.Request + Request *http.Request + Response http.ResponseWriter locator *service.Locator isBoosted bool currentBrowserUrl string @@ -118,9 +119,10 @@ func (app *App) start() { app.Router.Use(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cc := &RequestContext{ - locator: app.Opts.ServiceLocator, - Request: r, - kv: make(map[string]interface{}), + locator: app.Opts.ServiceLocator, + Request: r, + Response: w, + kv: make(map[string]interface{}), } populateHxFields(cc) ctx := context.WithValue(r.Context(), RequestContextKey, cc) diff --git a/framework/h/base.go b/framework/h/base.go index 9301ef8..01ecd76 100644 --- a/framework/h/base.go +++ b/framework/h/base.go @@ -1,7 +1,6 @@ package h import ( - "github.com/maddalax/htmgo/framework/hx" "html" "net/http" "reflect" @@ -57,10 +56,6 @@ func SwapPartial(ctx *RequestContext, swap *Element) *Partial { SwapMany(ctx, swap)) } -func RedirectPartial(url string) *Partial { - return NewPartialWithHeaders(NewHeaders(hx.RedirectHeader, url), Fragment()) -} - func SwapManyPartial(ctx *RequestContext, swaps ...*Element) *Partial { return NewPartial( SwapMany(ctx, swaps...), diff --git a/framework/h/header.go b/framework/h/header.go index 70645db..aba7c0a 100644 --- a/framework/h/header.go +++ b/framework/h/header.go @@ -34,7 +34,7 @@ func CombineHeaders(headers ...*Headers) *Headers { } func CurrentPath(ctx *RequestContext) string { - current := ctx.Header.Get(hx.CurrentUrlHeader) + current := ctx.Request.Header.Get(hx.CurrentUrlHeader) parsed, err := url.Parse(current) if err != nil { return "" diff --git a/framework/h/qs.go b/framework/h/qs.go index 31b2b93..75fad83 100644 --- a/framework/h/qs.go +++ b/framework/h/qs.go @@ -49,7 +49,7 @@ func (q *Qs) ToString() string { } func GetQueryParam(ctx *RequestContext, key string) string { - value, ok := ctx.URL.Query()[key] + value, ok := ctx.Request.URL.Query()[key] if value == nil || !ok { current := ctx.currentBrowserUrl if current != "" {