diff --git a/websocket/pool.go b/websocket/pool.go index 79c94c2ec..a6a1f2b2f 100644 --- a/websocket/pool.go +++ b/websocket/pool.go @@ -70,7 +70,15 @@ func (pool *Pool) Start() { } func (pool *Pool) SendTicketMessage(message TicketMessage) error { + + if pool == nil { + return fmt.Errorf("pool is nil") + } + if message.BroadcastType == "direct" { + if message.SourceSessionID == "" { + return fmt.Errorf("client not found") + } // check if client if client, ok := pool.Clients[message.SourceSessionID]; ok { return client.Client.Conn.WriteJSON(message) diff --git a/websocket/pool_test.go b/websocket/pool_test.go index 019097af5..41f1d6c51 100644 --- a/websocket/pool_test.go +++ b/websocket/pool_test.go @@ -1,8 +1,12 @@ package websocket import ( + "net/http" + "net/http/httptest" + "strings" "testing" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) @@ -67,3 +71,252 @@ func TestNewPool(t *testing.T) { }) } } + +func TestSendTicketMessage(t *testing.T) { + t.Run("Direct Broadcast with Valid Client", func(t *testing.T) { + pool := NewPool() + ws, server := setupTestWebsocket(t) + defer server.Close() + defer ws.Close() + + client := &Client{ + Host: "test-client", + Conn: ws, + Pool: pool, + } + + pool.Clients = make(map[string]*ClientData) + pool.Clients[client.Host] = &ClientData{ + Client: client, + Status: true, + } + + message := TicketMessage{ + BroadcastType: "direct", + SourceSessionID: "test-client", + Message: "Test message", + } + + err := pool.SendTicketMessage(message) + assert.NoError(t, err) + }) + + t.Run("Non-Direct Broadcast", func(t *testing.T) { + pool := NewPool() + ws, server := setupTestWebsocket(t) + defer server.Close() + defer ws.Close() + + client := &Client{ + Host: "test-client", + Conn: ws, + Pool: pool, + } + + pool.Clients = make(map[string]*ClientData) + pool.Clients[client.Host] = &ClientData{ + Client: client, + Status: true, + } + + message := TicketMessage{ + BroadcastType: "broadcast", + SourceSessionID: "test-client", + Message: "Test broadcast message", + } + + err := pool.SendTicketMessage(message) + assert.NoError(t, err) + }) + + t.Run("Empty SourceSessionID", func(t *testing.T) { + pool := NewPool() + message := TicketMessage{ + BroadcastType: "direct", + SourceSessionID: "", + Message: "Test message", + } + + err := pool.SendTicketMessage(message) + assert.Error(t, err) + assert.Contains(t, err.Error(), "client not found") + }) + + t.Run("Empty BroadcastType", func(t *testing.T) { + pool := NewPool() + ws, server := setupTestWebsocket(t) + defer server.Close() + defer ws.Close() + + client := &Client{ + Host: "test-client", + Conn: ws, + Pool: pool, + } + + pool.Clients = make(map[string]*ClientData) + pool.Clients[client.Host] = &ClientData{ + Client: client, + Status: true, + } + + message := TicketMessage{ + BroadcastType: "", + SourceSessionID: "test-client", + Message: "Test message", + } + + err := pool.SendTicketMessage(message) + assert.NoError(t, err) + }) + + t.Run("Client Not Found", func(t *testing.T) { + pool := NewPool() + message := TicketMessage{ + BroadcastType: "direct", + SourceSessionID: "non-existent-client", + Message: "Test message", + } + + err := pool.SendTicketMessage(message) + assert.Error(t, err) + assert.Contains(t, err.Error(), "client not found") + }) + + t.Run("WriteJSON Error", func(t *testing.T) { + pool := NewPool() + + ws, server := setupTestWebsocket(t) + server.Close() + ws.Close() + + client := &Client{ + Host: "test-client", + Conn: ws, + Pool: pool, + } + + pool.Clients = make(map[string]*ClientData) + pool.Clients[client.Host] = &ClientData{ + Client: client, + Status: true, + } + + message := TicketMessage{ + BroadcastType: "direct", + SourceSessionID: "test-client", + Message: "Test message", + } + + err := pool.SendTicketMessage(message) + assert.Error(t, err) + }) + + t.Run("Large Message Payload", func(t *testing.T) { + pool := NewPool() + ws, server := setupTestWebsocket(t) + defer server.Close() + defer ws.Close() + + client := &Client{ + Host: "test-client", + Conn: ws, + Pool: pool, + } + + pool.Clients = make(map[string]*ClientData) + pool.Clients[client.Host] = &ClientData{ + Client: client, + Status: true, + } + + largeMessage := strings.Repeat("a", 1024*1024) + message := TicketMessage{ + BroadcastType: "direct", + SourceSessionID: "test-client", + Message: largeMessage, + } + + err := pool.SendTicketMessage(message) + assert.NoError(t, err) + }) + + t.Run("Multiple Clients with Same SessionID", func(t *testing.T) { + pool := NewPool() + ws1, server1 := setupTestWebsocket(t) + ws2, server2 := setupTestWebsocket(t) + defer server1.Close() + defer server2.Close() + defer ws1.Close() + defer ws2.Close() + + client1 := &Client{ + Host: "same-session-id", + Conn: ws1, + Pool: pool, + } + + client2 := &Client{ + Host: "same-session-id", + Conn: ws2, + Pool: pool, + } + + pool.Clients = make(map[string]*ClientData) + pool.Clients[client1.Host] = &ClientData{ + Client: client1, + Status: true, + } + pool.Clients[client2.Host] = &ClientData{ + Client: client2, + Status: true, + } + + message := TicketMessage{ + BroadcastType: "direct", + SourceSessionID: "same-session-id", + Message: "Test message", + } + + err := pool.SendTicketMessage(message) + assert.NoError(t, err) + }) + + t.Run("Null or Uninitialized Pool", func(t *testing.T) { + var pool *Pool + message := TicketMessage{ + BroadcastType: "direct", + SourceSessionID: "test-client", + Message: "Test message", + } + + err := pool.SendTicketMessage(message) + assert.Error(t, err) + }) +} + +func setupTestWebsocket(t *testing.T) (*websocket.Conn, *httptest.Server) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + for { + _, _, err := conn.ReadMessage() + if err != nil { + break + } + } + })) + + wsURL := "ws" + server.URL[4:] + ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatal(err) + } + + return ws, server +}