From 57a93a7cc6f5562535c04c4d305e7773db7a3b82 Mon Sep 17 00:00:00 2001 From: GitHub Copilot Date: Wed, 11 Feb 2026 09:21:29 +0000 Subject: [PATCH] Fix stdin stalls and test warnings --- src/webterm/local_server.py | 34 ++++++++++++++++- src/webterm/poller.py | 52 +++++++++++++++++++------- tests/test_cli.py | 20 ++++++++-- tests/test_local_server_unit.py | 66 +++++++++++++++++++++++++++++++++ tests/test_poller.py | 33 +++++++++++++++++ 5 files changed, 186 insertions(+), 19 deletions(-) diff --git a/src/webterm/local_server.py b/src/webterm/local_server.py index 322e618..eb37ae2 100644 --- a/src/webterm/local_server.py +++ b/src/webterm/local_server.py @@ -45,6 +45,7 @@ SCREENSHOT_MAX_CACHE_SECONDS = 20.0 SCREENSHOT_FORCE_REDRAW = constants.get_environ_bool(constants.SCREENSHOT_FORCE_REDRAW_ENV) WS_SEND_QUEUE_MAX = 256 WS_SEND_TIMEOUT = 2.0 +STDIN_WRITE_TIMEOUT = 2.0 WEBTERM_STATIC_PATH = Path(__file__).parent / "static" @@ -446,6 +447,9 @@ class LocalServer: # SSE subscribers for activity notifications self._sse_subscribers: list[asyncio.Queue[str]] = [] + # Background tasks for fire-and-forget stdin writes (prevent GC) + self._stdin_tasks: set[asyncio.Task] = set() + # Docker stats collector (only used in compose mode) self._docker_stats: DockerStatsCollector | None = None # Docker watcher (only used in docker watch mode) @@ -469,6 +473,18 @@ class LocalServer: slug = slug or generate().lower() self.session_manager.add_app(name, command, slug=slug, terminal=True, theme=theme) + def _track_stdin_task(self, task: asyncio.Task, route_key: str) -> None: + self._stdin_tasks.add(task) + task.add_done_callback(lambda done: self._finalize_stdin_task(done, route_key)) + + def _finalize_stdin_task(self, task: asyncio.Task, route_key: str) -> None: + self._stdin_tasks.discard(task) + if task.cancelled(): + return + exc = task.exception() + if exc: + log.warning("Stdin write task failed for route %s: %s", route_key, exc) + async def run(self) -> None: try: await self._run() @@ -655,7 +671,23 @@ class LocalServer: data = envelope[1] if len(envelope) > 1 else "" session_process = self.session_manager.get_session_by_route_key(RouteKey(route_key)) if session_process: - await session_process.send_bytes(data.encode("utf-8")) + # Fire-and-forget: don't block the WS receive loop waiting for the + # PTY fd to become writable. A slow or stalled child process must + # never prevent subsequent keystrokes from being dispatched. + task = asyncio.create_task(self._write_stdin(session_process, data, route_key)) + self._track_stdin_task(task, route_key) + + async def _write_stdin(self, session_process, data: str, route_key: str) -> None: + """Write stdin data to session with a timeout to avoid indefinite stalls.""" + try: + await asyncio.wait_for( + session_process.send_bytes(data.encode("utf-8")), + timeout=STDIN_WRITE_TIMEOUT, + ) + except asyncio.TimeoutError: + log.warning("Stdin write timeout for route %s; dropping input", route_key) + except OSError as exc: + log.warning("Stdin write failed for route %s: %s", route_key, exc) async def _handle_resize( self, envelope: list, route_key: str, _ws: web.WebSocketResponse diff --git a/src/webterm/poller.py b/src/webterm/poller.py index ceec61f..efcb9dc 100644 --- a/src/webterm/poller.py +++ b/src/webterm/poller.py @@ -73,6 +73,25 @@ class Poller(Thread): return await new_write.done_event.wait() + async def write_with_timeout( + self, file_descriptor: int, data: bytes, timeout: float = 2.0 + ) -> bool: + """Write data to a file descriptor with a timeout. + + Args: + file_descriptor: File descriptor. + data: Data to write. + timeout: Maximum seconds to wait for the write to complete. + + Returns: + True if the write completed, False on timeout. + """ + try: + await asyncio.wait_for(self.write(file_descriptor, data), timeout=timeout) + return True + except asyncio.TimeoutError: + return False + def set_loop(self, loop: asyncio.AbstractEventLoop) -> None: """Set the asyncio loop. @@ -110,20 +129,25 @@ class Poller(Thread): if event_mask & writeable_events: write_queue = self._write_queues.get(file_descriptor, None) if write_queue: - write = write_queue[0] - remaining_data = write.data[write.position :] - try: - bytes_written = os.write(file_descriptor, remaining_data) - except OSError: - # Write failed; signal completion anyway to unblock waiters - write_queue.popleft() - loop.call_soon_threadsafe(write.done_event.set) - continue - write.position += bytes_written - # Check if all data has been written - if write.position >= len(write.data): - write_queue.popleft() - loop.call_soon_threadsafe(write.done_event.set) + # Process all pending writes while fd is writable + while write_queue: + write = write_queue[0] + remaining_data = write.data[write.position :] + try: + bytes_written = os.write(file_descriptor, remaining_data) + except OSError: + # Write failed; signal completion anyway to unblock waiters + write_queue.popleft() + loop.call_soon_threadsafe(write.done_event.set) + break + write.position += bytes_written + # Check if all data has been written + if write.position >= len(write.data): + write_queue.popleft() + loop.call_soon_threadsafe(write.done_event.set) + else: + # Partial write — fd buffer full, try again next cycle + break else: selector.modify(file_descriptor, readable_events) diff --git a/tests/test_cli.py b/tests/test_cli.py index f2d17a7..4c0157a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,6 +5,10 @@ from click.testing import CliRunner from webterm import cli +def _close_coroutine(coro) -> None: + coro.close() + + class TestCLI: """Tests for CLI command.""" @@ -31,7 +35,7 @@ class TestCLI: calls["run"] = True monkeypatch.setattr(cli, "LocalServer", FakeServer) - monkeypatch.setattr(cli.asyncio, "run", lambda _coro: None) + monkeypatch.setattr(cli.asyncio, "run", _close_coroutine) runner = CliRunner() result = runner.invoke(cli.app, ["htop"]) @@ -55,7 +59,7 @@ class TestCLI: monkeypatch.setenv("SHELL", "/bin/zsh") monkeypatch.setattr(cli, "LocalServer", FakeServer) - monkeypatch.setattr(cli.asyncio, "run", lambda _coro: None) + monkeypatch.setattr(cli.asyncio, "run", _close_coroutine) runner = CliRunner() result = runner.invoke(cli.app, []) @@ -131,7 +135,11 @@ def test_cli_docker_watch_mode(monkeypatch): calls["run"] = True monkeypatch.setattr(cli, "LocalServer", FakeServer) - monkeypatch.setattr(cli.asyncio, "run", lambda _coro: calls.setdefault("run", True)) + def run_and_close(coro): + calls.setdefault("run", True) + coro.close() + + monkeypatch.setattr(cli.asyncio, "run", run_and_close) monkeypatch.setattr(cli.constants, "DEBUG", True) runner = CliRunner() @@ -155,7 +163,11 @@ def test_cli_windows_branch(monkeypatch): monkeypatch.setattr(cli, "LocalServer", FakeServer) monkeypatch.setattr(cli.constants, "WINDOWS", True) - monkeypatch.setattr(cli.asyncio, "run", lambda _coro: calls.setdefault("run", True)) + def run_and_close(coro): + calls.setdefault("run", True) + coro.close() + + monkeypatch.setattr(cli.asyncio, "run", run_and_close) runner = CliRunner() result = runner.invoke(cli.app, ["--docker-watch"]) diff --git a/tests/test_local_server_unit.py b/tests/test_local_server_unit.py index 20af3f6..d859251 100644 --- a/tests/test_local_server_unit.py +++ b/tests/test_local_server_unit.py @@ -12,6 +12,14 @@ from webterm.local_server import ( ) +async def wait_for_asyncmock_call(mock: AsyncMock, timeout: float = 0.1) -> None: + async def _wait() -> None: + while mock.await_count == 0: + await asyncio.sleep(0) + + await asyncio.wait_for(_wait(), timeout=timeout) + + class TestGetStaticPath: """Tests for static path.""" @@ -599,6 +607,8 @@ class TestLocalServerMoreCoverage: ws = MagicMock() created = await server_with_no_apps._dispatch_ws_message(["stdin"], "rk", ws, False) assert created is False + # stdin writes are fire-and-forget; wait until send_bytes is awaited + await wait_for_asyncmock_call(session.send_bytes) session.send_bytes.assert_awaited_once_with(b"") @pytest.mark.asyncio @@ -744,3 +754,59 @@ class TestLocalServerMoreCoverage: assert not queue.empty() assert queue.get_nowait() == "my-route" + + @pytest.mark.asyncio + async def test_handle_stdin_does_not_block_ws_loop( + self, server_with_no_apps, monkeypatch + ): + """Stdin writes should be fire-and-forget so the WS loop keeps processing.""" + send_started = asyncio.Event() + send_gate = asyncio.Event() + + async def slow_send(_data): + send_started.set() + await send_gate.wait() + return True + + session = MagicMock() + session.send_bytes = AsyncMock(side_effect=slow_send) + monkeypatch.setattr( + server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session + ) + + ws = MagicMock() + # _dispatch_ws_message should return immediately even though send_bytes blocks + created = await server_with_no_apps._dispatch_ws_message( + ["stdin", "hello"], "rk", ws, False + ) + assert created is False + + # The background task should have been created but not finished + await send_started.wait() + assert not send_gate.is_set() + + # Unblock and let the task finish + send_gate.set() + await asyncio.sleep(0) + + @pytest.mark.asyncio + async def test_write_stdin_logs_timeout( + self, server_with_no_apps, monkeypatch, caplog + ): + """_write_stdin should log a warning and not raise on timeout.""" + async def hang_forever(_data): + await asyncio.sleep(999) + return True + + session = MagicMock() + session.send_bytes = AsyncMock(side_effect=hang_forever) + + import logging + + # Use a very short timeout for testing + monkeypatch.setattr("webterm.local_server.STDIN_WRITE_TIMEOUT", 0.01) + + with caplog.at_level(logging.WARNING, logger="webterm"): + await server_with_no_apps._write_stdin(session, "x", "rk") + + assert "Stdin write timeout" in caplog.text diff --git a/tests/test_poller.py b/tests/test_poller.py index 5a3e0fc..81e67e3 100644 --- a/tests/test_poller.py +++ b/tests/test_poller.py @@ -136,3 +136,36 @@ class TestPoller: # Queues should have None assert q1.get_nowait() is None assert q2.get_nowait() is None + + @pytest.mark.asyncio + async def test_write_with_timeout_returns_true_on_success(self): + """write_with_timeout returns True when write completes.""" + poller = Poller() + poller._loop = asyncio.get_event_loop() + + with patch.object(poller._selector, "register"): + poller.add_file(42) + + async def instant_write(fd, data): + # Simulate immediate completion + pass + + with patch.object(poller, "write", side_effect=instant_write): + result = await poller.write_with_timeout(42, b"test", timeout=1.0) + assert result is True + + @pytest.mark.asyncio + async def test_write_with_timeout_returns_false_on_timeout(self): + """write_with_timeout returns False when write times out.""" + poller = Poller() + poller._loop = asyncio.get_event_loop() + + with patch.object(poller._selector, "register"): + poller.add_file(42) + + async def slow_write(fd, data): + await asyncio.sleep(999) + + with patch.object(poller, "write", side_effect=slow_write): + result = await poller.write_with_timeout(42, b"test", timeout=0.01) + assert result is False