Fix stdin stalls and test warnings
This commit is contained in:
@@ -45,6 +45,7 @@ SCREENSHOT_MAX_CACHE_SECONDS = 20.0
|
|||||||
SCREENSHOT_FORCE_REDRAW = constants.get_environ_bool(constants.SCREENSHOT_FORCE_REDRAW_ENV)
|
SCREENSHOT_FORCE_REDRAW = constants.get_environ_bool(constants.SCREENSHOT_FORCE_REDRAW_ENV)
|
||||||
WS_SEND_QUEUE_MAX = 256
|
WS_SEND_QUEUE_MAX = 256
|
||||||
WS_SEND_TIMEOUT = 2.0
|
WS_SEND_TIMEOUT = 2.0
|
||||||
|
STDIN_WRITE_TIMEOUT = 2.0
|
||||||
|
|
||||||
|
|
||||||
WEBTERM_STATIC_PATH = Path(__file__).parent / "static"
|
WEBTERM_STATIC_PATH = Path(__file__).parent / "static"
|
||||||
@@ -446,6 +447,9 @@ class LocalServer:
|
|||||||
# SSE subscribers for activity notifications
|
# SSE subscribers for activity notifications
|
||||||
self._sse_subscribers: list[asyncio.Queue[str]] = []
|
self._sse_subscribers: list[asyncio.Queue[str]] = []
|
||||||
|
|
||||||
|
# Background tasks for fire-and-forget stdin writes (prevent GC)
|
||||||
|
self._stdin_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
# Docker stats collector (only used in compose mode)
|
# Docker stats collector (only used in compose mode)
|
||||||
self._docker_stats: DockerStatsCollector | None = None
|
self._docker_stats: DockerStatsCollector | None = None
|
||||||
# Docker watcher (only used in docker watch mode)
|
# Docker watcher (only used in docker watch mode)
|
||||||
@@ -469,6 +473,18 @@ class LocalServer:
|
|||||||
slug = slug or generate().lower()
|
slug = slug or generate().lower()
|
||||||
self.session_manager.add_app(name, command, slug=slug, terminal=True, theme=theme)
|
self.session_manager.add_app(name, command, slug=slug, terminal=True, theme=theme)
|
||||||
|
|
||||||
|
def _track_stdin_task(self, task: asyncio.Task, route_key: str) -> None:
|
||||||
|
self._stdin_tasks.add(task)
|
||||||
|
task.add_done_callback(lambda done: self._finalize_stdin_task(done, route_key))
|
||||||
|
|
||||||
|
def _finalize_stdin_task(self, task: asyncio.Task, route_key: str) -> None:
|
||||||
|
self._stdin_tasks.discard(task)
|
||||||
|
if task.cancelled():
|
||||||
|
return
|
||||||
|
exc = task.exception()
|
||||||
|
if exc:
|
||||||
|
log.warning("Stdin write task failed for route %s: %s", route_key, exc)
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
try:
|
try:
|
||||||
await self._run()
|
await self._run()
|
||||||
@@ -655,7 +671,23 @@ class LocalServer:
|
|||||||
data = envelope[1] if len(envelope) > 1 else ""
|
data = envelope[1] if len(envelope) > 1 else ""
|
||||||
session_process = self.session_manager.get_session_by_route_key(RouteKey(route_key))
|
session_process = self.session_manager.get_session_by_route_key(RouteKey(route_key))
|
||||||
if session_process:
|
if session_process:
|
||||||
await session_process.send_bytes(data.encode("utf-8"))
|
# Fire-and-forget: don't block the WS receive loop waiting for the
|
||||||
|
# PTY fd to become writable. A slow or stalled child process must
|
||||||
|
# never prevent subsequent keystrokes from being dispatched.
|
||||||
|
task = asyncio.create_task(self._write_stdin(session_process, data, route_key))
|
||||||
|
self._track_stdin_task(task, route_key)
|
||||||
|
|
||||||
|
async def _write_stdin(self, session_process, data: str, route_key: str) -> None:
|
||||||
|
"""Write stdin data to session with a timeout to avoid indefinite stalls."""
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
session_process.send_bytes(data.encode("utf-8")),
|
||||||
|
timeout=STDIN_WRITE_TIMEOUT,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
log.warning("Stdin write timeout for route %s; dropping input", route_key)
|
||||||
|
except OSError as exc:
|
||||||
|
log.warning("Stdin write failed for route %s: %s", route_key, exc)
|
||||||
|
|
||||||
async def _handle_resize(
|
async def _handle_resize(
|
||||||
self, envelope: list, route_key: str, _ws: web.WebSocketResponse
|
self, envelope: list, route_key: str, _ws: web.WebSocketResponse
|
||||||
|
|||||||
+25
-1
@@ -73,6 +73,25 @@ class Poller(Thread):
|
|||||||
return
|
return
|
||||||
await new_write.done_event.wait()
|
await new_write.done_event.wait()
|
||||||
|
|
||||||
|
async def write_with_timeout(
|
||||||
|
self, file_descriptor: int, data: bytes, timeout: float = 2.0
|
||||||
|
) -> bool:
|
||||||
|
"""Write data to a file descriptor with a timeout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_descriptor: File descriptor.
|
||||||
|
data: Data to write.
|
||||||
|
timeout: Maximum seconds to wait for the write to complete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the write completed, False on timeout.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.write(file_descriptor, data), timeout=timeout)
|
||||||
|
return True
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return False
|
||||||
|
|
||||||
def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
|
def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
|
||||||
"""Set the asyncio loop.
|
"""Set the asyncio loop.
|
||||||
|
|
||||||
@@ -110,6 +129,8 @@ class Poller(Thread):
|
|||||||
if event_mask & writeable_events:
|
if event_mask & writeable_events:
|
||||||
write_queue = self._write_queues.get(file_descriptor, None)
|
write_queue = self._write_queues.get(file_descriptor, None)
|
||||||
if write_queue:
|
if write_queue:
|
||||||
|
# Process all pending writes while fd is writable
|
||||||
|
while write_queue:
|
||||||
write = write_queue[0]
|
write = write_queue[0]
|
||||||
remaining_data = write.data[write.position :]
|
remaining_data = write.data[write.position :]
|
||||||
try:
|
try:
|
||||||
@@ -118,12 +139,15 @@ class Poller(Thread):
|
|||||||
# Write failed; signal completion anyway to unblock waiters
|
# Write failed; signal completion anyway to unblock waiters
|
||||||
write_queue.popleft()
|
write_queue.popleft()
|
||||||
loop.call_soon_threadsafe(write.done_event.set)
|
loop.call_soon_threadsafe(write.done_event.set)
|
||||||
continue
|
break
|
||||||
write.position += bytes_written
|
write.position += bytes_written
|
||||||
# Check if all data has been written
|
# Check if all data has been written
|
||||||
if write.position >= len(write.data):
|
if write.position >= len(write.data):
|
||||||
write_queue.popleft()
|
write_queue.popleft()
|
||||||
loop.call_soon_threadsafe(write.done_event.set)
|
loop.call_soon_threadsafe(write.done_event.set)
|
||||||
|
else:
|
||||||
|
# Partial write — fd buffer full, try again next cycle
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
selector.modify(file_descriptor, readable_events)
|
selector.modify(file_descriptor, readable_events)
|
||||||
|
|
||||||
|
|||||||
+16
-4
@@ -5,6 +5,10 @@ from click.testing import CliRunner
|
|||||||
from webterm import cli
|
from webterm import cli
|
||||||
|
|
||||||
|
|
||||||
|
def _close_coroutine(coro) -> None:
|
||||||
|
coro.close()
|
||||||
|
|
||||||
|
|
||||||
class TestCLI:
|
class TestCLI:
|
||||||
"""Tests for CLI command."""
|
"""Tests for CLI command."""
|
||||||
|
|
||||||
@@ -31,7 +35,7 @@ class TestCLI:
|
|||||||
calls["run"] = True
|
calls["run"] = True
|
||||||
|
|
||||||
monkeypatch.setattr(cli, "LocalServer", FakeServer)
|
monkeypatch.setattr(cli, "LocalServer", FakeServer)
|
||||||
monkeypatch.setattr(cli.asyncio, "run", lambda _coro: None)
|
monkeypatch.setattr(cli.asyncio, "run", _close_coroutine)
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
result = runner.invoke(cli.app, ["htop"])
|
result = runner.invoke(cli.app, ["htop"])
|
||||||
@@ -55,7 +59,7 @@ class TestCLI:
|
|||||||
|
|
||||||
monkeypatch.setenv("SHELL", "/bin/zsh")
|
monkeypatch.setenv("SHELL", "/bin/zsh")
|
||||||
monkeypatch.setattr(cli, "LocalServer", FakeServer)
|
monkeypatch.setattr(cli, "LocalServer", FakeServer)
|
||||||
monkeypatch.setattr(cli.asyncio, "run", lambda _coro: None)
|
monkeypatch.setattr(cli.asyncio, "run", _close_coroutine)
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
result = runner.invoke(cli.app, [])
|
result = runner.invoke(cli.app, [])
|
||||||
@@ -131,7 +135,11 @@ def test_cli_docker_watch_mode(monkeypatch):
|
|||||||
calls["run"] = True
|
calls["run"] = True
|
||||||
|
|
||||||
monkeypatch.setattr(cli, "LocalServer", FakeServer)
|
monkeypatch.setattr(cli, "LocalServer", FakeServer)
|
||||||
monkeypatch.setattr(cli.asyncio, "run", lambda _coro: calls.setdefault("run", True))
|
def run_and_close(coro):
|
||||||
|
calls.setdefault("run", True)
|
||||||
|
coro.close()
|
||||||
|
|
||||||
|
monkeypatch.setattr(cli.asyncio, "run", run_and_close)
|
||||||
monkeypatch.setattr(cli.constants, "DEBUG", True)
|
monkeypatch.setattr(cli.constants, "DEBUG", True)
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
@@ -155,7 +163,11 @@ def test_cli_windows_branch(monkeypatch):
|
|||||||
|
|
||||||
monkeypatch.setattr(cli, "LocalServer", FakeServer)
|
monkeypatch.setattr(cli, "LocalServer", FakeServer)
|
||||||
monkeypatch.setattr(cli.constants, "WINDOWS", True)
|
monkeypatch.setattr(cli.constants, "WINDOWS", True)
|
||||||
monkeypatch.setattr(cli.asyncio, "run", lambda _coro: calls.setdefault("run", True))
|
def run_and_close(coro):
|
||||||
|
calls.setdefault("run", True)
|
||||||
|
coro.close()
|
||||||
|
|
||||||
|
monkeypatch.setattr(cli.asyncio, "run", run_and_close)
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
result = runner.invoke(cli.app, ["--docker-watch"])
|
result = runner.invoke(cli.app, ["--docker-watch"])
|
||||||
|
|||||||
@@ -12,6 +12,14 @@ from webterm.local_server import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for_asyncmock_call(mock: AsyncMock, timeout: float = 0.1) -> None:
|
||||||
|
async def _wait() -> None:
|
||||||
|
while mock.await_count == 0:
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
await asyncio.wait_for(_wait(), timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
class TestGetStaticPath:
|
class TestGetStaticPath:
|
||||||
"""Tests for static path."""
|
"""Tests for static path."""
|
||||||
|
|
||||||
@@ -599,6 +607,8 @@ class TestLocalServerMoreCoverage:
|
|||||||
ws = MagicMock()
|
ws = MagicMock()
|
||||||
created = await server_with_no_apps._dispatch_ws_message(["stdin"], "rk", ws, False)
|
created = await server_with_no_apps._dispatch_ws_message(["stdin"], "rk", ws, False)
|
||||||
assert created is False
|
assert created is False
|
||||||
|
# stdin writes are fire-and-forget; wait until send_bytes is awaited
|
||||||
|
await wait_for_asyncmock_call(session.send_bytes)
|
||||||
session.send_bytes.assert_awaited_once_with(b"")
|
session.send_bytes.assert_awaited_once_with(b"")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -744,3 +754,59 @@ class TestLocalServerMoreCoverage:
|
|||||||
|
|
||||||
assert not queue.empty()
|
assert not queue.empty()
|
||||||
assert queue.get_nowait() == "my-route"
|
assert queue.get_nowait() == "my-route"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_stdin_does_not_block_ws_loop(
|
||||||
|
self, server_with_no_apps, monkeypatch
|
||||||
|
):
|
||||||
|
"""Stdin writes should be fire-and-forget so the WS loop keeps processing."""
|
||||||
|
send_started = asyncio.Event()
|
||||||
|
send_gate = asyncio.Event()
|
||||||
|
|
||||||
|
async def slow_send(_data):
|
||||||
|
send_started.set()
|
||||||
|
await send_gate.wait()
|
||||||
|
return True
|
||||||
|
|
||||||
|
session = MagicMock()
|
||||||
|
session.send_bytes = AsyncMock(side_effect=slow_send)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session
|
||||||
|
)
|
||||||
|
|
||||||
|
ws = MagicMock()
|
||||||
|
# _dispatch_ws_message should return immediately even though send_bytes blocks
|
||||||
|
created = await server_with_no_apps._dispatch_ws_message(
|
||||||
|
["stdin", "hello"], "rk", ws, False
|
||||||
|
)
|
||||||
|
assert created is False
|
||||||
|
|
||||||
|
# The background task should have been created but not finished
|
||||||
|
await send_started.wait()
|
||||||
|
assert not send_gate.is_set()
|
||||||
|
|
||||||
|
# Unblock and let the task finish
|
||||||
|
send_gate.set()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_write_stdin_logs_timeout(
|
||||||
|
self, server_with_no_apps, monkeypatch, caplog
|
||||||
|
):
|
||||||
|
"""_write_stdin should log a warning and not raise on timeout."""
|
||||||
|
async def hang_forever(_data):
|
||||||
|
await asyncio.sleep(999)
|
||||||
|
return True
|
||||||
|
|
||||||
|
session = MagicMock()
|
||||||
|
session.send_bytes = AsyncMock(side_effect=hang_forever)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Use a very short timeout for testing
|
||||||
|
monkeypatch.setattr("webterm.local_server.STDIN_WRITE_TIMEOUT", 0.01)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING, logger="webterm"):
|
||||||
|
await server_with_no_apps._write_stdin(session, "x", "rk")
|
||||||
|
|
||||||
|
assert "Stdin write timeout" in caplog.text
|
||||||
|
|||||||
@@ -136,3 +136,36 @@ class TestPoller:
|
|||||||
# Queues should have None
|
# Queues should have None
|
||||||
assert q1.get_nowait() is None
|
assert q1.get_nowait() is None
|
||||||
assert q2.get_nowait() is None
|
assert q2.get_nowait() is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_write_with_timeout_returns_true_on_success(self):
|
||||||
|
"""write_with_timeout returns True when write completes."""
|
||||||
|
poller = Poller()
|
||||||
|
poller._loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
with patch.object(poller._selector, "register"):
|
||||||
|
poller.add_file(42)
|
||||||
|
|
||||||
|
async def instant_write(fd, data):
|
||||||
|
# Simulate immediate completion
|
||||||
|
pass
|
||||||
|
|
||||||
|
with patch.object(poller, "write", side_effect=instant_write):
|
||||||
|
result = await poller.write_with_timeout(42, b"test", timeout=1.0)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_write_with_timeout_returns_false_on_timeout(self):
|
||||||
|
"""write_with_timeout returns False when write times out."""
|
||||||
|
poller = Poller()
|
||||||
|
poller._loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
with patch.object(poller._selector, "register"):
|
||||||
|
poller.add_file(42)
|
||||||
|
|
||||||
|
async def slow_write(fd, data):
|
||||||
|
await asyncio.sleep(999)
|
||||||
|
|
||||||
|
with patch.object(poller, "write", side_effect=slow_write):
|
||||||
|
result = await poller.write_with_timeout(42, b"test", timeout=0.01)
|
||||||
|
assert result is False
|
||||||
|
|||||||
Reference in New Issue
Block a user