fix: prevent websocket backpressure from freezing sessions

Introduce per-route send queues and dedicated sender tasks so terminal output
does not await slow WebSocket clients. Output is buffered in a bounded
queue; when the queue is full, the oldest data is dropped to keep sessions responsive.

Sender tasks enforce a send timeout and close slow/broken sockets, preventing
terminal run loops from stalling indefinitely.

Tests updated/added to verify:
- queued output instead of direct ws.send_bytes
- sender timeout closes sockets
- queue overflow drops oldest
- session close stops sender task
This commit is contained in:
GitHub Copilot
2026-01-29 21:06:26 +00:00
parent 4b4451d3c8
commit 9734a8b43b
3 changed files with 135 additions and 11 deletions
+65 -8
View File
@@ -42,6 +42,8 @@ DEFAULT_TERMINAL_SIZE = (132, 45)
SCREENSHOT_CACHE_SECONDS = 0.3 SCREENSHOT_CACHE_SECONDS = 0.3
SCREENSHOT_MAX_CACHE_SECONDS = 20.0 SCREENSHOT_MAX_CACHE_SECONDS = 20.0
WS_SEND_QUEUE_MAX = 256
WS_SEND_TIMEOUT = 2.0
WEBTERM_STATIC_PATH = Path(__file__).parent / "static" WEBTERM_STATIC_PATH = Path(__file__).parent / "static"
@@ -182,6 +184,8 @@ class LocalServer:
self._exit_poller = ExitPoller(self, idle_wait=exit_on_idle) self._exit_poller = ExitPoller(self, idle_wait=exit_on_idle)
self._websocket_connections: dict[RouteKey, web.WebSocketResponse] = {} self._websocket_connections: dict[RouteKey, web.WebSocketResponse] = {}
self._ws_send_queues: dict[RouteKey, asyncio.Queue[bytes | None]] = {}
self._ws_send_tasks: dict[RouteKey, asyncio.Task] = {}
self._landing_apps = landing_apps or [] self._landing_apps = landing_apps or []
self._compose_mode = compose_mode self._compose_mode = compose_mode
self._compose_project = compose_project self._compose_project = compose_project
@@ -460,6 +464,9 @@ class LocalServer:
log.info("WebSocket connection established for route %s", route_key) log.info("WebSocket connection established for route %s", route_key)
self._websocket_connections[route_key] = ws self._websocket_connections[route_key] = ws
queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=WS_SEND_QUEUE_MAX)
self._ws_send_queues[route_key] = queue
self._ws_send_tasks[route_key] = asyncio.create_task(self._ws_sender(route_key, ws, queue))
session_id = self.session_manager.routes.get(RouteKey(route_key)) session_id = self.session_manager.routes.get(RouteKey(route_key))
session = None session = None
@@ -503,6 +510,7 @@ class LocalServer:
finally: finally:
log.info("WebSocket connection closed for route %s", route_key) log.info("WebSocket connection closed for route %s", route_key)
self._websocket_connections.pop(route_key, None) self._websocket_connections.pop(route_key, None)
await self._stop_ws_sender(route_key)
return ws return ws
@@ -1248,20 +1256,15 @@ class LocalServer:
async def handle_session_data(self, route_key: RouteKey, data: bytes) -> None: async def handle_session_data(self, route_key: RouteKey, data: bytes) -> None:
self.mark_route_activity(str(route_key)) self.mark_route_activity(str(route_key))
ws = self._websocket_connections.get(route_key) self._enqueue_ws_data(route_key, data)
if ws is None:
return
await ws.send_bytes(data)
async def handle_binary_message(self, route_key: RouteKey, payload: bytes) -> None: async def handle_binary_message(self, route_key: RouteKey, payload: bytes) -> None:
self.mark_route_activity(str(route_key)) self.mark_route_activity(str(route_key))
ws = self._websocket_connections.get(route_key) self._enqueue_ws_data(route_key, payload)
if ws is None:
return
await ws.send_bytes(payload)
async def handle_session_close(self, session_id: SessionID, route_key: RouteKey) -> None: async def handle_session_close(self, session_id: SessionID, route_key: RouteKey) -> None:
self.session_manager.on_session_end(session_id) self.session_manager.on_session_end(session_id)
await self._stop_ws_sender(route_key)
ws = self._websocket_connections.get(route_key) ws = self._websocket_connections.get(route_key)
if ws is not None: if ws is not None:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
@@ -1269,3 +1272,57 @@ class LocalServer:
def force_exit(self) -> None: def force_exit(self) -> None:
self.exit_event.set() self.exit_event.set()
def _enqueue_ws_data(self, route_key: RouteKey, data: bytes) -> None:
    """Hand *data* to the route's send queue without ever awaiting.

    Does nothing when no send queue is registered for *route_key*. If the
    queue is full, one item is removed from the front of the queue and the
    put is retried once; a warning is logged only when that retry fails too.
    """
    pending = self._ws_send_queues.get(route_key)
    if pending is None:
        return

    try:
        pending.put_nowait(data)
        return
    except asyncio.QueueFull:
        pass

    # Queue is full: drop the oldest data so a slow client never blocks the
    # terminal session, then retry exactly once.
    try:
        pending.get_nowait()
    except asyncio.QueueEmpty:
        pass

    try:
        pending.put_nowait(data)
    except asyncio.QueueFull:
        log.warning("WebSocket send queue full for route %s; dropping output", route_key)
async def _ws_sender(
    self,
    route_key: RouteKey,
    ws: web.WebSocketResponse,
    queue: asyncio.Queue[bytes | None],
) -> None:
    """Forward queued chunks to *ws* until told to stop or the send fails.

    Items are taken from *queue* one at a time and sent with
    ``ws.send_bytes`` under a ``WS_SEND_TIMEOUT`` deadline. The loop ends
    when a ``None`` item is dequeued, when a send times out, or when a send
    raises one of the connection errors listed below. On every exit path
    (including cancellation) the socket is closed if it is not closed yet.
    """
    connection_errors = (
        ConnectionResetError,
        ConnectionAbortedError,
        aiohttp.ClientConnectionError,
    )
    try:
        # ``None`` is the stop sentinel.
        while (chunk := await queue.get()) is not None:
            try:
                await asyncio.wait_for(ws.send_bytes(chunk), timeout=WS_SEND_TIMEOUT)
            except asyncio.TimeoutError:
                log.warning("WebSocket send timeout for route %s; closing", route_key)
                return
            except connection_errors as exc:
                log.warning("WebSocket send failed for route %s: %s", route_key, exc)
                return
    finally:
        if not ws.closed:
            # Closing is best-effort; errors while closing are ignored.
            try:
                await ws.close()
            except Exception:
                pass
async def _stop_ws_sender(self, route_key: RouteKey) -> None:
    """Unregister and stop the send queue and sender task for *route_key*.

    Safe to call more than once for the same route: both registries are
    popped, so a second call finds nothing and returns immediately.

    The sender task is cancelled and awaited. If the task had already
    finished with an exception (anything ``_ws_sender`` does not handle
    itself), that exception is logged here instead of being re-raised, so
    callers performing cleanup are not interrupted by a stale sender error.
    """
    queue = self._ws_send_queues.pop(route_key, None)
    if queue is not None:
        # Best-effort stop sentinel; a full queue is fine because the task
        # is cancelled right below anyway.
        with contextlib.suppress(asyncio.QueueFull):
            queue.put_nowait(None)

    task = self._ws_send_tasks.pop(route_key, None)
    if task is None:
        return

    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected outcome of the cancel() above.
        pass
    except Exception as exc:
        # cancel() is a no-op on a task that already failed; awaiting it
        # re-raises the stored exception, which must not abort teardown.
        log.warning("WebSocket sender for route %s ended with error: %s", route_key, exc)
+9 -3
View File
@@ -2,6 +2,7 @@
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import asyncio
import pytest import pytest
from aiohttp import web from aiohttp import web
@@ -370,7 +371,6 @@ class TestLocalServerMoreCoverage:
async def test_handle_session_data_no_ws_noop(self, server_with_no_apps): async def test_handle_session_data_no_ws_noop(self, server_with_no_apps):
await server_with_no_apps.handle_session_data("rk", b"data") await server_with_no_apps.handle_session_data("rk", b"data")
@pytest.mark.asyncio
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("handler", "payload"), ("handler", "payload"),
@@ -382,9 +382,11 @@ class TestLocalServerMoreCoverage:
async def test_handle_message_sends_bytes(self, server_with_no_apps, handler, payload): async def test_handle_message_sends_bytes(self, server_with_no_apps, handler, payload):
ws = MagicMock() ws = MagicMock()
ws.send_bytes = AsyncMock() ws.send_bytes = AsyncMock()
queue = asyncio.Queue(maxsize=10)
server_with_no_apps._websocket_connections["rk"] = ws server_with_no_apps._websocket_connections["rk"] = ws
server_with_no_apps._ws_send_queues["rk"] = queue
await getattr(server_with_no_apps, handler)("rk", payload) await getattr(server_with_no_apps, handler)("rk", payload)
ws.send_bytes.assert_awaited_once_with(payload) assert await queue.get() == payload
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_session_close_ends_session_and_closes_ws( async def test_handle_session_close_ends_session_and_closes_ws(
@@ -393,10 +395,13 @@ class TestLocalServerMoreCoverage:
ws = MagicMock() ws = MagicMock()
ws.close = AsyncMock() ws.close = AsyncMock()
server_with_no_apps._websocket_connections["rk"] = ws server_with_no_apps._websocket_connections["rk"] = ws
server_with_no_apps._ws_send_queues["rk"] = asyncio.Queue(maxsize=10)
server_with_no_apps._ws_send_tasks["rk"] = asyncio.create_task(asyncio.sleep(10))
monkeypatch.setattr(server_with_no_apps.session_manager, "on_session_end", MagicMock()) monkeypatch.setattr(server_with_no_apps.session_manager, "on_session_end", MagicMock())
await server_with_no_apps.handle_session_close("sid", "rk") await server_with_no_apps.handle_session_close("sid", "rk")
server_with_no_apps.session_manager.on_session_end.assert_called_once_with("sid") server_with_no_apps.session_manager.on_session_end.assert_called_once_with("sid")
ws.close.assert_awaited_once() ws.close.assert_awaited_once()
assert "rk" not in server_with_no_apps._ws_send_tasks
def test_force_exit_sets_event(self, server_with_no_apps): def test_force_exit_sets_event(self, server_with_no_apps):
assert not server_with_no_apps.exit_event.is_set() assert not server_with_no_apps.exit_event.is_set()
@@ -710,11 +715,12 @@ class TestLocalServerMoreCoverage:
ws = MagicMock() ws = MagicMock()
ws.send_bytes = AsyncMock() ws.send_bytes = AsyncMock()
server_with_no_apps._websocket_connections["rk"] = ws server_with_no_apps._websocket_connections["rk"] = ws
server_with_no_apps._ws_send_queues["rk"] = asyncio.Queue(maxsize=10)
server_with_no_apps._route_last_activity["rk"] = 0.0 server_with_no_apps._route_last_activity["rk"] = 0.0
await server_with_no_apps.handle_session_data("rk", b"data") await server_with_no_apps.handle_session_data("rk", b"data")
assert server_with_no_apps._route_last_activity["rk"] > 0.0 assert server_with_no_apps._route_last_activity["rk"] > 0.0
ws.send_bytes.assert_awaited_once_with(b"data") assert await server_with_no_apps._ws_send_queues["rk"].get() == b"data"
def test_mark_route_activity_triggers_notification(self, server_with_no_apps): def test_mark_route_activity_triggers_notification(self, server_with_no_apps):
"""Test that mark_route_activity triggers SSE notification.""" """Test that mark_route_activity triggers SSE notification."""
+61
View File
@@ -0,0 +1,61 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from webterm.config import Config
from webterm.local_server import LocalServer, WS_SEND_TIMEOUT
@pytest.mark.asyncio
async def test_ws_sender_flushes_queue():
    """The sender forwards each queued chunk and exits on the ``None`` item."""
    ws = MagicMock()
    ws.closed = False
    ws.send_bytes = AsyncMock()
    ws.close = AsyncMock()
    server = LocalServer(config_path="./", config=Config(apps=[]), host="localhost", port=8080)

    queue: asyncio.Queue[bytes | None] = asyncio.Queue()
    sender_task = asyncio.create_task(server._ws_sender("rk", ws, queue))

    # Two payloads followed by the stop sentinel.
    for item in (b"hello", b"world", None):
        await queue.put(item)
    await sender_task

    for expected in (b"hello", b"world"):
        ws.send_bytes.assert_any_await(expected)
@pytest.mark.asyncio
async def test_ws_sender_timeout_closes(monkeypatch):
    """A send that outlasts the send timeout makes the sender close the socket."""
    # Shrink the module-level timeout so the test does not really wait the
    # production timeout; _ws_sender reads the constant at call time.
    monkeypatch.setattr("webterm.local_server.WS_SEND_TIMEOUT", 0.01)

    server = LocalServer(config_path="./", config=Config(apps=[]), host="localhost", port=8080)
    ws = MagicMock()
    ws.closed = False
    ws.close = AsyncMock()

    async def slow_send(_data):
        # Far longer than the patched timeout; wait_for cancels this sleep.
        await asyncio.sleep(WS_SEND_TIMEOUT * 2)

    ws.send_bytes = AsyncMock(side_effect=slow_send)
    queue: asyncio.Queue[bytes | None] = asyncio.Queue()

    sender_task = asyncio.create_task(server._ws_sender("rk", ws, queue))
    await queue.put(b"slow")
    await sender_task

    ws.close.assert_awaited_once()
@pytest.mark.asyncio
async def test_enqueue_ws_data_drops_oldest_when_full():
    """Enqueueing into a full queue evicts the oldest item and keeps the new one."""
    server = LocalServer(config_path="./", config=Config(apps=[]), host="localhost", port=8080)

    # A one-slot queue that is already full.
    full_queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
    full_queue.put_nowait(b"first")
    server._ws_send_queues["rk"] = full_queue

    server._enqueue_ws_data("rk", b"second")

    assert full_queue.qsize() == 1
    remaining = await full_queue.get()
    assert remaining == b"second"