Improve local_server and terminal_session coverage

This commit is contained in:
GitHub Copilot
2026-01-22 14:09:34 +00:00
parent 0cfb3b0a2f
commit d03f32bf69
4 changed files with 373 additions and 19 deletions
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "textual-webterm" name = "textual-webterm"
version = "0.1.10" version = "0.1.11"
description = "Serve terminal sessions over the web" description = "Serve terminal sessions over the web"
authors = ["Will McGugan <will@textualize.io>"] authors = ["Will McGugan <will@textualize.io>"]
license = "MIT" license = "MIT"
+69 -18
View File
@@ -8,6 +8,7 @@ import hashlib
import io import io
import json import json
import logging import logging
import re
import signal import signal
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -37,6 +38,12 @@ SCREENSHOT_MAX_BYTES = 65536
SCREENSHOT_CACHE_SECONDS = 1.0 SCREENSHOT_CACHE_SECONDS = 1.0
SCREENSHOT_MAX_CACHE_SECONDS = 60.0 SCREENSHOT_MAX_CACHE_SECONDS = 60.0
SVG_MONO_FONT_STACK = (
'ui-monospace, "SFMono-Regular", "FiraCode Nerd Font", "FiraMono Nerd Font", '
'"Fira Code", "Roboto Mono", Menlo, Monaco, Consolas, "Liberation Mono", '
'"DejaVu Sans Mono", "Courier New", monospace'
)
WEBTERM_STATIC_PATH = Path(__file__).parent / "static" WEBTERM_STATIC_PATH = Path(__file__).parent / "static"
@@ -84,6 +91,22 @@ class LocalClientConnector(SessionConnector):
await self.server.handle_session_close(self.session_id, self.route_key) await self.server.handle_session_close(self.session_id, self.route_key)
def _rewrite_svg_fonts(svg: str) -> str:
"""Make Rich SVG output self-contained and aligned with our monospace styling."""
# Rich export_svg embeds @font-face rules that reference external CDNs.
svg = re.sub(r"@font-face\s*\{.*?\}\s*", "", svg, flags=re.DOTALL)
# Force our local monospace stack even if Rich sets font-family to Fira Code.
override = f"\ntext {{ font-family: {SVG_MONO_FONT_STACK} !important; }}\n"
if "</style>" in svg:
svg = svg.replace("</style>", override + "</style>", 1)
else:
svg = svg.replace("<svg ", f"<svg><style>{override}</style> ", 1)
return svg
def _apply_carriage_returns(text: str) -> list[str]: def _apply_carriage_returns(text: str) -> list[str]:
"""Interpret \r as 'return to start of line' (overwrite), not a newline. """Interpret \r as 'return to start of line' (overwrite), not a newline.
@@ -110,6 +133,22 @@ class LocalServer:
def mark_route_activity(self, route_key: str) -> None: def mark_route_activity(self, route_key: str) -> None:
self._route_last_activity[route_key] = asyncio.get_event_loop().time() self._route_last_activity[route_key] = asyncio.get_event_loop().time()
def _get_cached_screenshot_response(
self, request: web.Request, route_key: str
) -> web.Response | None:
cached = self._screenshot_cache.get(route_key)
if cached is None:
return None
etag = self._screenshot_cache_etag.get(route_key)
if etag and request.headers.get("If-None-Match") == etag:
raise web.HTTPNotModified(headers={"ETag": etag, "Cache-Control": "no-cache"})
headers = {"Cache-Control": "no-cache"}
if etag:
headers["ETag"] = etag
return web.Response(text=cached[1], content_type="image/svg+xml", headers=headers)
def _get_screenshot_cache_ttl(self, route_key: str, now: float) -> float: def _get_screenshot_cache_ttl(self, route_key: str, now: float) -> float:
last_activity = self._route_last_activity.get(route_key, 0.0) last_activity = self._route_last_activity.get(route_key, 0.0)
idle_for = max(0.0, now - last_activity) idle_for = max(0.0, now - last_activity)
@@ -157,6 +196,7 @@ class LocalServer:
self._screenshot_cache_etag: dict[str, str] = {} self._screenshot_cache_etag: dict[str, str] = {}
self._screenshot_locks: dict[str, asyncio.Lock] = {} self._screenshot_locks: dict[str, asyncio.Lock] = {}
self._route_last_activity: dict[str, float] = {} self._route_last_activity: dict[str, float] = {}
self._screenshot_last_rendered_activity: dict[str, float] = {}
@property @property
def app_count(self) -> int: def app_count(self) -> int:
@@ -288,12 +328,14 @@ class LocalServer:
msg_type = envelope[0] msg_type = envelope[0]
if msg_type == "stdin": if msg_type == "stdin":
self.mark_route_activity(route_key)
data = envelope[1] if len(envelope) > 1 else "" data = envelope[1] if len(envelope) > 1 else ""
session_process = self.session_manager.get_session_by_route_key(RouteKey(route_key)) session_process = self.session_manager.get_session_by_route_key(RouteKey(route_key))
if session_process: if session_process:
await session_process.send_bytes(data.encode("utf-8")) await session_process.send_bytes(data.encode("utf-8"))
elif msg_type == "resize": elif msg_type == "resize":
self.mark_route_activity(route_key)
size_data = envelope[1] if len(envelope) > 1 else {} size_data = envelope[1] if len(envelope) > 1 else {}
width = max(1, min(500, int(size_data.get("width", 80)))) width = max(1, min(500, int(size_data.get("width", 80))))
height = max(1, min(500, int(size_data.get("height", 24)))) height = max(1, min(500, int(size_data.get("height", 24))))
@@ -422,6 +464,15 @@ class LocalServer:
if session_process is None or not hasattr(session_process, "get_replay_buffer"): if session_process is None or not hasattr(session_process, "get_replay_buffer"):
raise web.HTTPNotFound(text="Session not found") raise web.HTTPNotFound(text="Session not found")
# If nothing has changed since the last render, serve cached screenshot without
# touching the session replay buffer.
last_activity = self._route_last_activity.get(route_key, 0.0)
last_rendered_activity = self._screenshot_last_rendered_activity.get(route_key, -1.0)
if last_activity <= last_rendered_activity:
cached_response = self._get_cached_screenshot_response(request, route_key)
if cached_response is not None:
return cached_response
replay_data = await session_process.get_replay_buffer() # type: ignore[func-returns-value] replay_data = await session_process.get_replay_buffer() # type: ignore[func-returns-value]
if len(replay_data) > SCREENSHOT_MAX_BYTES: if len(replay_data) > SCREENSHOT_MAX_BYTES:
replay_data = replay_data[-SCREENSHOT_MAX_BYTES:] replay_data = replay_data[-SCREENSHOT_MAX_BYTES:]
@@ -447,16 +498,17 @@ class LocalServer:
ttl = self._get_screenshot_cache_ttl(route_key, now) ttl = self._get_screenshot_cache_ttl(route_key, now)
cached = self._screenshot_cache.get(route_key) cached = self._screenshot_cache.get(route_key)
if cached is not None: # If we have a cached screenshot and the session is idle, keep serving it until
etag = self._screenshot_cache_etag.get(route_key) # new activity occurs (no periodic re-render).
if etag and request.headers.get("If-None-Match") == etag: if cached is not None and self._route_last_activity.get(route_key, 0.0) == 0.0:
raise web.HTTPNotModified(headers={"ETag": etag, "Cache-Control": "no-cache"}) cached_response = self._get_cached_screenshot_response(request, route_key)
if cached_response is not None:
return cached_response
if (now - cached[0]) < ttl: if cached is not None and (now - cached[0]) < ttl:
headers = {"Cache-Control": "no-cache"} cached_response = self._get_cached_screenshot_response(request, route_key)
if etag: if cached_response is not None:
headers["ETag"] = etag return cached_response
return web.Response(text=cached[1], content_type="image/svg+xml", headers=headers)
lock = self._screenshot_locks.get(route_key) lock = self._screenshot_locks.get(route_key)
if lock is None: if lock is None:
@@ -467,15 +519,10 @@ class LocalServer:
# Another request may have refreshed the cache while we waited. # Another request may have refreshed the cache while we waited.
ttl = self._get_screenshot_cache_ttl(route_key, now) ttl = self._get_screenshot_cache_ttl(route_key, now)
cached = self._screenshot_cache.get(route_key) cached = self._screenshot_cache.get(route_key)
etag = self._screenshot_cache_etag.get(route_key) if cached is not None and (now - cached[0]) < ttl:
if cached is not None: cached_response = self._get_cached_screenshot_response(request, route_key)
if etag and request.headers.get("If-None-Match") == etag: if cached_response is not None:
raise web.HTTPNotModified(headers={"ETag": etag, "Cache-Control": "no-cache"}) return cached_response
if (now - cached[0]) < ttl:
headers = {"Cache-Control": "no-cache"}
if etag:
headers["ETag"] = etag
return web.Response(text=cached[1], content_type="image/svg+xml", headers=headers)
def _render_svg() -> str: def _render_svg() -> str:
console = Console(record=True, width=width, height=height, file=io.StringIO()) console = Console(record=True, width=width, height=height, file=io.StringIO())
@@ -505,9 +552,13 @@ class LocalServer:
) )
svg = await asyncio.to_thread(_render_svg) svg = await asyncio.to_thread(_render_svg)
svg = _rewrite_svg_fonts(svg)
etag = hashlib.sha1(svg.encode("utf-8"), usedforsecurity=False).hexdigest() etag = hashlib.sha1(svg.encode("utf-8"), usedforsecurity=False).hexdigest()
self._screenshot_cache[route_key] = (asyncio.get_event_loop().time(), svg) self._screenshot_cache[route_key] = (asyncio.get_event_loop().time(), svg)
self._screenshot_cache_etag[route_key] = etag self._screenshot_cache_etag[route_key] = etag
self._screenshot_last_rendered_activity[route_key] = self._route_last_activity.get(
route_key, 0.0
)
headers = {"Cache-Control": "no-cache", "ETag": etag} headers = {"Cache-Control": "no-cache", "ETag": etag}
return web.Response(text=svg, content_type="image/svg+xml", headers=headers) return web.Response(text=svg, content_type="image/svg+xml", headers=headers)
+146
View File
@@ -622,3 +622,149 @@ class TestLocalServerMoreCoverage:
await connector.on_close() await connector.on_close()
server.handle_session_close.assert_awaited_once_with("sid", "rk") server.handle_session_close.assert_awaited_once_with("sid", "rk")
@pytest.mark.asyncio
async def test_run_stops_exit_poller_and_exits_poller(self, server_with_no_apps, monkeypatch):
async def boom():
raise RuntimeError("boom")
monkeypatch.setattr(server_with_no_apps, "_run", boom)
server_with_no_apps._exit_poller.stop = MagicMock()
server_with_no_apps._poller.exit = MagicMock()
with pytest.raises(RuntimeError):
await server_with_no_apps.run()
server_with_no_apps._exit_poller.stop.assert_called_once()
server_with_no_apps._poller.exit.assert_called_once()
def test_on_keyboard_interrupt_sets_event_when_already_shutting_down(self, server_with_no_apps):
server_with_no_apps._shutdown_started = True
assert not server_with_no_apps.exit_event.is_set()
server_with_no_apps.on_keyboard_interrupt()
assert server_with_no_apps.exit_event.is_set()
@pytest.mark.asyncio
async def test_on_keyboard_interrupt_schedules_shutdown_in_running_loop(self, server_with_no_apps):
called = {"shutdown": False}
async def shutdown():
called["shutdown"] = True
server_with_no_apps.exit_event.set()
server_with_no_apps._shutdown = shutdown # type: ignore[method-assign]
server_with_no_apps.on_keyboard_interrupt()
assert server_with_no_apps._shutdown_task is not None
await server_with_no_apps._shutdown_task
assert called["shutdown"] is True
def test_on_keyboard_interrupt_uses_call_soon_threadsafe_when_loop_running(
self, server_with_no_apps, monkeypatch
):
async def shutdown():
return None
server_with_no_apps._shutdown = shutdown # type: ignore[method-assign]
fake_loop = MagicMock()
fake_loop.is_running = MagicMock(return_value=True)
server_with_no_apps._loop = fake_loop
created = {"called": False}
def fake_create_task(coro):
created["called"] = True
coro.close()
return MagicMock()
monkeypatch.setattr("textual_webterm.local_server.asyncio.create_task", fake_create_task)
server_with_no_apps.on_keyboard_interrupt()
assert fake_loop.call_soon_threadsafe.called
schedule = fake_loop.call_soon_threadsafe.call_args.args[0]
schedule()
assert created["called"] is True
def test_build_routes_logs_error_when_static_path_missing(self, server_with_no_apps, monkeypatch):
from pathlib import Path
from textual_webterm import local_server
class FakePath(Path):
_flavour = type(Path())._flavour
def exists(self) -> bool: # type: ignore[override]
return False
monkeypatch.setattr(local_server, "STATIC_PATH", FakePath("/definitely-missing"))
monkeypatch.setattr(local_server.log, "error", MagicMock())
server_with_no_apps._build_routes()
local_server.log.error.assert_called()
@pytest.mark.asyncio
async def test_dispatch_ws_message_stdin_without_payload_sends_empty(self, server_with_no_apps, monkeypatch):
session = MagicMock()
session.send_bytes = AsyncMock()
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session)
ws = MagicMock()
created = await server_with_no_apps._dispatch_ws_message(["stdin"], "rk", ws, False)
assert created is False
session.send_bytes.assert_awaited_once_with(b"")
@pytest.mark.asyncio
async def test_dispatch_ws_message_resize_updates_existing_session(self, server_with_no_apps, monkeypatch):
session = MagicMock()
session.set_terminal_size = AsyncMock()
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session)
ws = MagicMock()
created = await server_with_no_apps._dispatch_ws_message(
["resize", {"width": 100, "height": 50}], "rk", ws, True
)
assert created is True
session.set_terminal_size.assert_awaited_once_with(100, 50)
@pytest.mark.asyncio
async def test_dispatch_ws_message_resize_no_session_noop(self, server_with_no_apps, monkeypatch):
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: None)
ws = MagicMock()
created = await server_with_no_apps._dispatch_ws_message(
["resize", {"width": 100, "height": 50}], "rk", ws, True
)
assert created is True
@pytest.mark.asyncio
async def test_handle_screenshot_truncates_replay_buffer_before_decode(self, server_with_no_apps, monkeypatch):
from textual_webterm.local_server import SCREENSHOT_MAX_BYTES
request = MagicMock()
request.query = {"route_key": "rk"}
request.headers = {}
session = MagicMock()
session.get_replay_buffer = AsyncMock(return_value=b"x" * (SCREENSHOT_MAX_BYTES + 10))
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session)
server_with_no_apps._route_last_activity["rk"] = 1.0
captured = {"len": None}
def apply_cr(text: str):
captured["len"] = len(text)
return ["x"]
async def fake_to_thread(_fn):
return "<svg></svg>"
monkeypatch.setattr("textual_webterm.local_server._apply_carriage_returns", apply_cr)
monkeypatch.setattr("textual_webterm.local_server.asyncio.to_thread", AsyncMock(side_effect=fake_to_thread))
monkeypatch.setattr("textual_webterm.local_server._rewrite_svg_fonts", lambda s: s)
resp = await server_with_no_apps._handle_screenshot(request)
assert resp.content_type == "image/svg+xml"
assert captured["len"] == SCREENSHOT_MAX_BYTES
+157
View File
@@ -286,3 +286,160 @@ class TestTerminalSession:
assert await session.send_bytes(b"x") is True assert await session.send_bytes(b"x") is True
poller.write.assert_awaited_once_with(10, b"x") poller.write.assert_awaited_once_with(10, b"x")
@pytest.mark.asyncio
async def test_open_set_terminal_size_oserror_closes_fd_and_clears_master_fd(self):
from textual_webterm.terminal_session import TerminalSession
poller = MagicMock()
session = TerminalSession(poller, "sid", "bash")
with (
patch("textual_webterm.terminal_session.pty.fork", return_value=(1234, 99)),
patch.object(session, "_set_terminal_size", side_effect=OSError("bad")),
patch("textual_webterm.terminal_session.os.close") as mock_close,
pytest.raises(OSError),
):
await session.open(width=80, height=24)
mock_close.assert_called_once_with(99)
assert session.master_fd is None
@pytest.mark.asyncio
async def test_set_terminal_size_uses_executor(self):
from textual_webterm.terminal_session import TerminalSession
poller = MagicMock()
session = TerminalSession(poller, "sid", "bash")
session.master_fd = 10
loop = asyncio.get_running_loop()
with patch.object(loop, "run_in_executor", new=AsyncMock()) as run_in_executor:
await session.set_terminal_size(80, 24)
run_in_executor.assert_awaited_once_with(None, session._set_terminal_size, 80, 24)
def test__set_terminal_size_calls_ioctl(self):
from textual_webterm.terminal_session import TerminalSession
poller = MagicMock()
session = TerminalSession(poller, "sid", "bash")
session.master_fd = 10
with patch("textual_webterm.terminal_session.fcntl.ioctl") as mock_ioctl:
session._set_terminal_size(80, 24)
assert mock_ioctl.called
@pytest.mark.asyncio
async def test_start_creates_task_when_not_running(self):
from textual_webterm.terminal_session import TerminalSession
poller = MagicMock()
session = TerminalSession(poller, "sid", "bash")
session.master_fd = 10
session.run = AsyncMock() # type: ignore[method-assign]
connector = MagicMock()
task = await session.start(connector)
assert task is session._task
assert session._connector is connector
await task
session.run.assert_awaited_once()
@pytest.mark.asyncio
async def test_run_without_connector_still_closes(self):
from textual_webterm.terminal_session import TerminalSession
queue: asyncio.Queue[bytes | None] = asyncio.Queue()
await queue.put(b"hello")
await queue.put(None)
poller = MagicMock()
poller.add_file = MagicMock(return_value=queue)
poller.remove_file = MagicMock()
session = TerminalSession(poller, "sid", "bash")
session.master_fd = 10
session._connector = None
with patch("textual_webterm.terminal_session.os.close") as mock_close:
await session.run()
poller.remove_file.assert_called_once_with(10)
mock_close.assert_called_once_with(10)
@pytest.mark.asyncio
async def test_run_oserror_still_closes(self):
from textual_webterm.terminal_session import TerminalSession
queue = MagicMock()
queue.get = AsyncMock(side_effect=OSError("boom"))
poller = MagicMock()
poller.add_file = MagicMock(return_value=queue)
poller.remove_file = MagicMock()
session = TerminalSession(poller, "sid", "bash")
session.master_fd = 10
session._connector = None
with patch("textual_webterm.terminal_session.os.close") as mock_close:
await session.run()
poller.remove_file.assert_called_once_with(10)
mock_close.assert_called_once_with(10)
@pytest.mark.asyncio
async def test_close_process_lookup_error_is_ignored(self):
from textual_webterm.terminal_session import TerminalSession
poller = MagicMock()
session = TerminalSession(poller, "sid", "bash")
session.pid = 123
with patch("textual_webterm.terminal_session.os.kill", side_effect=ProcessLookupError()):
await session.close()
@pytest.mark.asyncio
async def test_close_logs_warning_on_unexpected_exception(self):
from textual_webterm.terminal_session import TerminalSession
poller = MagicMock()
session = TerminalSession(poller, "sid", "bash")
session.pid = 123
with (
patch("textual_webterm.terminal_session.os.kill", side_effect=RuntimeError("x")),
patch("textual_webterm.terminal_session.log.warning") as warn,
):
await session.close()
assert warn.called
@pytest.mark.asyncio
async def test_wait_suppresses_cancelled_error(self):
from textual_webterm.terminal_session import TerminalSession
poller = MagicMock()
session = TerminalSession(poller, "sid", "bash")
task = asyncio.create_task(asyncio.sleep(10))
task.cancel()
session._task = task
await session.wait()
def test_is_running_false_when_kill_fails(self):
from textual_webterm.terminal_session import TerminalSession
poller = MagicMock()
session = TerminalSession(poller, "sid", "bash")
session.master_fd = 10
session._task = MagicMock()
session.pid = 123
with patch("textual_webterm.terminal_session.os.kill", side_effect=OSError()):
assert session.is_running() is False