diff --git a/src/webterm/local_server.py b/src/webterm/local_server.py index 57112c3..cd5a2a5 100644 --- a/src/webterm/local_server.py +++ b/src/webterm/local_server.py @@ -42,6 +42,8 @@ DEFAULT_TERMINAL_SIZE = (132, 45) SCREENSHOT_CACHE_SECONDS = 0.3 SCREENSHOT_MAX_CACHE_SECONDS = 20.0 +WS_SEND_QUEUE_MAX = 256 +WS_SEND_TIMEOUT = 2.0 WEBTERM_STATIC_PATH = Path(__file__).parent / "static" @@ -182,6 +184,8 @@ class LocalServer: self._exit_poller = ExitPoller(self, idle_wait=exit_on_idle) self._websocket_connections: dict[RouteKey, web.WebSocketResponse] = {} + self._ws_send_queues: dict[RouteKey, asyncio.Queue[bytes | None]] = {} + self._ws_send_tasks: dict[RouteKey, asyncio.Task] = {} self._landing_apps = landing_apps or [] self._compose_mode = compose_mode self._compose_project = compose_project @@ -460,6 +464,9 @@ class LocalServer: log.info("WebSocket connection established for route %s", route_key) self._websocket_connections[route_key] = ws + queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=WS_SEND_QUEUE_MAX) + self._ws_send_queues[route_key] = queue + self._ws_send_tasks[route_key] = asyncio.create_task(self._ws_sender(route_key, ws, queue)) session_id = self.session_manager.routes.get(RouteKey(route_key)) session = None @@ -503,6 +510,7 @@ class LocalServer: finally: log.info("WebSocket connection closed for route %s", route_key) self._websocket_connections.pop(route_key, None) + await self._stop_ws_sender(route_key) return ws @@ -1248,20 +1256,15 @@ class LocalServer: async def handle_session_data(self, route_key: RouteKey, data: bytes) -> None: self.mark_route_activity(str(route_key)) - ws = self._websocket_connections.get(route_key) - if ws is None: - return - await ws.send_bytes(data) + self._enqueue_ws_data(route_key, data) async def handle_binary_message(self, route_key: RouteKey, payload: bytes) -> None: self.mark_route_activity(str(route_key)) - ws = self._websocket_connections.get(route_key) - if ws is None: - return - await ws.send_bytes(payload) + self._enqueue_ws_data(route_key, payload) async def handle_session_close(self, session_id: SessionID, route_key: RouteKey) -> None: self.session_manager.on_session_end(session_id) + await self._stop_ws_sender(route_key) ws = self._websocket_connections.get(route_key) if ws is not None: with contextlib.suppress(Exception): @@ -1269,3 +1272,57 @@ class LocalServer: def force_exit(self) -> None: self.exit_event.set() + + def _enqueue_ws_data(self, route_key: RouteKey, data: bytes) -> None: + queue = self._ws_send_queues.get(route_key) + if queue is None: + return + try: + queue.put_nowait(data) + except asyncio.QueueFull: + # Drop oldest data to avoid blocking terminal sessions on slow clients. + with contextlib.suppress(asyncio.QueueEmpty): + queue.get_nowait() + try: + queue.put_nowait(data) + except asyncio.QueueFull: + log.warning("WebSocket send queue full for route %s; dropping output", route_key) + + async def _ws_sender( + self, + route_key: RouteKey, + ws: web.WebSocketResponse, + queue: asyncio.Queue[bytes | None], + ) -> None: + try: + while True: + data = await queue.get() + if data is None: + break + try: + await asyncio.wait_for(ws.send_bytes(data), timeout=WS_SEND_TIMEOUT) + except asyncio.TimeoutError: + log.warning("WebSocket send timeout for route %s; closing", route_key) + break + except ( + ConnectionResetError, + ConnectionAbortedError, + aiohttp.ClientConnectionError, + ) as exc: + log.warning("WebSocket send failed for route %s: %s", route_key, exc) + break + finally: + if not ws.closed: + with contextlib.suppress(Exception): + await ws.close() + + async def _stop_ws_sender(self, route_key: RouteKey) -> None: + queue = self._ws_send_queues.pop(route_key, None) + if queue is not None: + with contextlib.suppress(asyncio.QueueFull): + queue.put_nowait(None) + task = self._ws_send_tasks.pop(route_key, None) + if task is not None: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task diff --git a/tests/test_local_server_unit.py b/tests/test_local_server_unit.py index 95a68a8..88e9d9f 100644 --- a/tests/test_local_server_unit.py +++ b/tests/test_local_server_unit.py @@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, MagicMock +import asyncio import pytest from aiohttp import web @@ -370,7 +371,6 @@ class TestLocalServerMoreCoverage: async def test_handle_session_data_no_ws_noop(self, server_with_no_apps): await server_with_no_apps.handle_session_data("rk", b"data") - @pytest.mark.asyncio @pytest.mark.asyncio @pytest.mark.parametrize( ("handler", "payload"), @@ -382,9 +382,11 @@ class TestLocalServerMoreCoverage: async def test_handle_message_sends_bytes(self, server_with_no_apps, handler, payload): ws = MagicMock() ws.send_bytes = AsyncMock() + queue = asyncio.Queue(maxsize=10) server_with_no_apps._websocket_connections["rk"] = ws + server_with_no_apps._ws_send_queues["rk"] = queue await getattr(server_with_no_apps, handler)("rk", payload) - ws.send_bytes.assert_awaited_once_with(payload) + assert await queue.get() == payload @pytest.mark.asyncio async def test_handle_session_close_ends_session_and_closes_ws( @@ -393,10 +395,13 @@ class TestLocalServerMoreCoverage: ws = MagicMock() ws.close = AsyncMock() server_with_no_apps._websocket_connections["rk"] = ws + server_with_no_apps._ws_send_queues["rk"] = asyncio.Queue(maxsize=10) + server_with_no_apps._ws_send_tasks["rk"] = asyncio.create_task(asyncio.sleep(10)) monkeypatch.setattr(server_with_no_apps.session_manager, "on_session_end", MagicMock()) await server_with_no_apps.handle_session_close("sid", "rk") server_with_no_apps.session_manager.on_session_end.assert_called_once_with("sid") ws.close.assert_awaited_once() + assert "rk" not in server_with_no_apps._ws_send_tasks def test_force_exit_sets_event(self, server_with_no_apps): assert not server_with_no_apps.exit_event.is_set() @@ -710,11 +715,12 @@ class TestLocalServerMoreCoverage: ws = MagicMock() ws.send_bytes = AsyncMock() server_with_no_apps._websocket_connections["rk"] = ws + server_with_no_apps._ws_send_queues["rk"] = asyncio.Queue(maxsize=10) server_with_no_apps._route_last_activity["rk"] = 0.0 await server_with_no_apps.handle_session_data("rk", b"data") assert server_with_no_apps._route_last_activity["rk"] > 0.0 - ws.send_bytes.assert_awaited_once_with(b"data") + assert await server_with_no_apps._ws_send_queues["rk"].get() == b"data" def test_mark_route_activity_triggers_notification(self, server_with_no_apps): """Test that mark_route_activity triggers SSE notification.""" diff --git a/tests/test_ws_sender.py b/tests/test_ws_sender.py new file mode 100644 index 0000000..8ce9e8c --- /dev/null +++ b/tests/test_ws_sender.py @@ -0,0 +1,61 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from webterm.config import Config +from webterm.local_server import LocalServer, WS_SEND_TIMEOUT + + +@pytest.mark.asyncio +async def test_ws_sender_flushes_queue(): + server = LocalServer(config_path="./", config=Config(apps=[]), host="localhost", port=8080) + ws = MagicMock() + ws.send_bytes = AsyncMock() + ws.closed = False + ws.close = AsyncMock() + + queue: asyncio.Queue[bytes | None] = asyncio.Queue() + sender_task = asyncio.create_task(server._ws_sender("rk", ws, queue)) + + await queue.put(b"hello") + await queue.put(b"world") + await queue.put(None) + + await sender_task + ws.send_bytes.assert_any_await(b"hello") + ws.send_bytes.assert_any_await(b"world") + + +@pytest.mark.asyncio +async def test_ws_sender_timeout_closes(): + server = LocalServer(config_path="./", config=Config(apps=[]), host="localhost", port=8080) + ws = MagicMock() + ws.closed = False + ws.close = AsyncMock() + + async def slow_send(_data): + await asyncio.sleep(WS_SEND_TIMEOUT * 2) + + ws.send_bytes = AsyncMock(side_effect=slow_send) + + queue: asyncio.Queue[bytes | None] = asyncio.Queue() + sender_task = asyncio.create_task(server._ws_sender("rk", ws, queue)) + + await queue.put(b"slow") + await sender_task + + ws.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_enqueue_ws_data_drops_oldest_when_full(): + server = LocalServer(config_path="./", config=Config(apps=[]), host="localhost", port=8080) + queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1) + server._ws_send_queues["rk"] = queue + + queue.put_nowait(b"first") + server._enqueue_ws_data("rk", b"second") + + assert queue.qsize() == 1 + assert await queue.get() == b"second"