From 2d50731fa0fb9b4f4b4e529110362d9b0c733cb5 Mon Sep 17 00:00:00 2001 From: GitHub Copilot Date: Mon, 16 Feb 2026 23:04:55 +0000 Subject: [PATCH] Harden websocket and stdin backpressure handling Replace silent output frame dropping with fail-fast slow-client disconnects when websocket send queues saturate, and replace unbounded stdin write goroutine spawning with a bounded queue + worker and timeout-driven disconnect under input backlog. Also add targeted regression tests for queue saturation and stdin backlog disconnect behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- webterm/server.go | 34 ++++++++------ webterm/server_test.go | 103 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 15 deletions(-) diff --git a/webterm/server.go b/webterm/server.go index f9a0c77..7eaed74 100644 --- a/webterm/server.go +++ b/webterm/server.go @@ -322,15 +322,8 @@ func (s *LocalServer) enqueueWSFrame(routeKey string, messageType int, data []by select { case client.send <- frame: default: - // Drop oldest, try again - select { - case <-client.send: - default: - } - select { - case client.send <- frame: - default: - } + log.Printf("websocket send queue saturated route=%s: disconnecting slow client", routeKey) + s.stopWSClient(routeKey, client) } } @@ -526,6 +519,19 @@ func (s *LocalServer) handleWebSocket(w http.ResponseWriter, r *http.Request) { conn.SetPongHandler(func(string) error { return conn.SetReadDeadline(time.Now().Add(wsReadTimeout)) }) + type stdinWrite struct { + session Session + data string + } + stdinQueue := make(chan stdinWrite, wsSendQueueMax) + defer close(stdinQueue) + go func() { + for write := range stdinQueue { + if !write.session.SendBytes([]byte(write.data)) { + log.Printf("stdin write failed route=%s remote=%s", routeKey, r.RemoteAddr) + } + } + }() for { messageType, payload, err := conn.ReadMessage() @@ -553,14 +559,12 @@ func (s *LocalServer) handleWebSocket(w http.ResponseWriter, r *http.Request) { if len(envelope) > 1 { data, _ = envelope[1].(string) } - done := make(chan struct{}) - go func() { - defer close(done) - _ = session.SendBytes([]byte(data)) - }() select { - case <-done: + case stdinQueue <- stdinWrite{session: session, data: data}: case <-time.After(stdinWriteTimeout): + log.Printf("stdin queue saturated route=%s remote=%s: disconnecting client", routeKey, r.RemoteAddr) + sendJSON([]any{"error", "Input backlog detected"}) + return } } case "resize": diff --git a/webterm/server_test.go b/webterm/server_test.go index 93cb167..fbe0d41 100644 --- a/webterm/server_test.go +++ b/webterm/server_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/rcarmo/webterm/internal/terminalstate" ) type failingSSEWriter struct { @@ -62,6 +63,43 @@ type syncSessionMap struct { m map[string]*fakeSession } +type blockingSession struct { + mu sync.Mutex + running bool + blockCh <-chan struct{} +} + +func newBlockingSession(blockCh <-chan struct{}) *blockingSession { + return &blockingSession{running: true, blockCh: blockCh} +} + +func (b *blockingSession) Open(int, int) error { return nil } +func (b *blockingSession) Start(SessionConnector) error { return nil } +func (b *blockingSession) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + b.running = false + return nil +} +func (b *blockingSession) Wait() error { return nil } +func (b *blockingSession) SetTerminalSize(int, int) error { return nil } +func (b *blockingSession) SendMeta(map[string]any) bool { return true } +func (b *blockingSession) GetReplayBuffer() []byte { return nil } +func (b *blockingSession) ForceRedraw() error { return nil } +func (b *blockingSession) UpdateConnector(SessionConnector) {} +func (b *blockingSession) GetScreenSnapshot() terminalstate.Snapshot { + return terminalstate.Snapshot{Width: 80, Height: 24, Buffer: make([][]terminalstate.Cell, 24)} +} +func (b *blockingSession) SendBytes([]byte) bool { + <-b.blockCh + return true +} +func (b *blockingSession) IsRunning() bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.running +} + func TestHealthAndTilesEndpoints(t *testing.T) { _, httpServer, _ := newServerForTests(t, true) resp, err := http.Get(httpServer.URL + "/health") @@ -278,6 +316,71 @@ func TestStaleSessionConnectorCloseDoesNotDropReassignedRouteClient(t *testing.T } } +func TestEnqueueWSFrameQueueSaturationDisconnectsSlowClient(t *testing.T) { + server := NewLocalServer(Config{}, ServerOptions{}) + client := &wsClient{ + routeKey: "shell", + send: make(chan wsOutbound, 1), + done: make(chan struct{}), + } + client.send <- wsOutbound{messageType: websocket.BinaryMessage, payload: []byte("old")} + close(client.done) + + server.mu.Lock() + server.wsClients["shell"] = client + server.mu.Unlock() + + server.enqueueWSFrame("shell", websocket.BinaryMessage, []byte("new")) + + if !client.closed.Load() { + t.Fatalf("expected saturated client to be marked closed") + } + server.mu.RLock() + _, exists := server.wsClients["shell"] + server.mu.RUnlock() + if exists { + t.Fatalf("expected saturated client to be removed from wsClients") + } +} + +func TestWebSocketDisconnectsOnStdinBacklog(t *testing.T) { + blockCh := make(chan struct{}) + t.Cleanup(func() { close(blockCh) }) + config := Config{ + Apps: []App{{Name: "Shell", Slug: "shell", Command: "/bin/sh", Terminal: true}}, + } + server := NewLocalServer(config, ServerOptions{}) + server.sessionManager.SetSessionFactory(func(app App, sessionID string) Session { + return newBlockingSession(blockCh) + }) + httpServer := httptest.NewServer(server.Handler()) + t.Cleanup(httpServer.Close) + + wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "/ws/shell" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("ws dial error = %v", err) + } + defer conn.Close() + if err := conn.WriteJSON([]any{"resize", map[string]any{"width": 80, "height": 24}}); err != nil { + t.Fatalf("write resize: %v", err) + } + time.Sleep(20 * time.Millisecond) + + for i := 0; i < wsSendQueueMax+32; i++ { + if err := conn.WriteJSON([]any{"stdin", "x"}); err != nil { + break + } + } + + _ = conn.SetReadDeadline(time.Now().Add(6 * time.Second)) + for { + if _, _, err := conn.ReadMessage(); err != nil { + return + } + } +} + func TestScreenshotAndETag(t *testing.T) { server, httpServer, _ := newServerForTests(t, false) if _, err := server.sessionManager.NewSession("shell", "sid", "shell", 80, 24); err != nil {