diff --git a/webterm/server.go b/webterm/server.go index 8d07d4c..f9a0c77 100644 --- a/webterm/server.go +++ b/webterm/server.go @@ -207,7 +207,10 @@ func (c *localClientConnector) OnMeta(meta map[string]any) { func (c *localClientConnector) OnClose() { c.server.sessionManager.OnSessionEnd(c.sessionID) - c.server.stopWSClient(c.routeKey) + if activeSessionID, ok := c.server.sessionManager.GetSessionIDByRouteKey(c.routeKey); ok && activeSessionID != c.sessionID { + return + } + c.server.stopWSClient(c.routeKey, nil) } func NewLocalServer(config Config, options ServerOptions) *LocalServer { @@ -331,15 +334,21 @@ func (s *LocalServer) enqueueWSFrame(routeKey string, messageType int, data []by } } -func (s *LocalServer) stopWSClient(routeKey string) { +func (s *LocalServer) stopWSClient(routeKey string, expected *wsClient) { s.mu.Lock() client := s.wsClients[routeKey] + if expected != nil && client != expected { + s.mu.Unlock() + return + } delete(s.wsClients, routeKey) s.mu.Unlock() if client == nil { return } - client.closed.Store(true) + if client.closed.Swap(true) { + return + } close(client.send) <-client.done } @@ -480,7 +489,7 @@ func (s *LocalServer) handleWebSocket(w http.ResponseWriter, r *http.Request) { s.wsClients[routeKey] = client s.mu.Unlock() go s.wsSender(client) - defer s.stopWSClient(routeKey) + defer s.stopWSClient(routeKey, client) // Helper to send JSON through the send channel (avoids concurrent conn writes) sendJSON := func(v any) { @@ -526,6 +535,7 @@ func (s *LocalServer) handleWebSocket(w http.ResponseWriter, r *http.Request) { } return } + _ = conn.SetReadDeadline(time.Now().Add(wsReadTimeout)) if messageType != websocket.TextMessage { continue } diff --git a/webterm/server_test.go b/webterm/server_test.go index b18052b..93cb167 100644 --- a/webterm/server_test.go +++ b/webterm/server_test.go @@ -182,6 +182,102 @@ func TestWebSocketReplayOnReconnect(t *testing.T) { } } +func TestWebSocketOldConnectionCloseDoesNotDropNewClient(t *testing.T) { + _, httpServer, _ := newServerForTests(t, false) + wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "/ws/shell" + + conn1, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("first dial error = %v", err) + } + if err := conn1.WriteJSON([]any{"resize", map[string]any{"width": 80, "height": 24}}); err != nil { + t.Fatalf("resize write: %v", err) + } + time.Sleep(20 * time.Millisecond) + + conn2, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("second dial error = %v", err) + } + defer conn2.Close() + + _ = conn1.Close() + time.Sleep(100 * time.Millisecond) + + if err := conn2.WriteJSON([]any{"ping", "still-open"}); err != nil { + t.Fatalf("conn2 write ping after conn1 close: %v", err) + } + _ = conn2.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, payload, err := conn2.ReadMessage() + if err != nil { + t.Fatalf("conn2 read pong after conn1 close: %v", err) + } + var pong []any + if err := json.Unmarshal(payload, &pong); err != nil { + t.Fatalf("decode pong: %v", err) + } + if pong[0] != "pong" || pong[1] != "still-open" { + t.Fatalf("unexpected pong payload: %v", pong) + } +} + +func TestStaleSessionConnectorCloseDoesNotDropReassignedRouteClient(t *testing.T) { + server, httpServer, _ := newServerForTests(t, false) + wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "/ws/shell" + + conn1, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("first dial error = %v", err) + } + defer conn1.Close() + if err := conn1.WriteJSON([]any{"resize", map[string]any{"width": 80, "height": 24}}); err != nil { + t.Fatalf("resize write: %v", err) + } + var sessionID string + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if sid, ok := server.sessionManager.GetSessionIDByRouteKey("shell"); ok { + sessionID = sid + break + } + time.Sleep(10 * time.Millisecond) + } + if sessionID == "" { + t.Fatalf("expected initial session id") + } + + conn2, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("second dial error = %v", err) + } + defer conn2.Close() + + // Simulate route reassignment before stale connector close callback runs. + server.sessionManager.OnSessionEnd(sessionID) + if _, err := server.sessionManager.NewSession("shell", "replacement-session", "shell", 80, 24); err != nil { + t.Fatalf("replacement session create failed: %v", err) + } + + staleConnector := &localClientConnector{server: server, sessionID: sessionID, routeKey: "shell"} + staleConnector.OnClose() + + if err := conn2.WriteJSON([]any{"ping", "route-still-open"}); err != nil { + t.Fatalf("conn2 write ping after stale close: %v", err) + } + _ = conn2.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, payload, err := conn2.ReadMessage() + if err != nil { + t.Fatalf("conn2 read pong after stale close: %v", err) + } + var pong []any + if err := json.Unmarshal(payload, &pong); err != nil { + t.Fatalf("decode pong: %v", err) + } + if pong[0] != "pong" || pong[1] != "route-still-open" { + t.Fatalf("unexpected pong payload: %v", pong) + } +} + func TestScreenshotAndETag(t *testing.T) { server, httpServer, _ := newServerForTests(t, false) if _, err := server.sessionManager.NewSession("shell", "sid", "shell", 80, 24); err != nil {