diff --git a/src/textual_webterm/local_server.py b/src/textual_webterm/local_server.py index a4c721b..7e29b3d 100644 --- a/src/textual_webterm/local_server.py +++ b/src/textual_webterm/local_server.py @@ -31,6 +31,10 @@ log = logging.getLogger("textual-web") DISCONNECT_RESIZE = (132, 45) +# Avoid heavy screenshot rendering from processing unbounded output. +SCREENSHOT_MAX_BYTES = 65536 +SCREENSHOT_CACHE_SECONDS = 1.0 + WEBTERM_STATIC_PATH = Path(__file__).parent / "static" @@ -77,6 +81,28 @@ class LocalClientConnector(SessionConnector): await self.server.handle_session_close(self.session_id, self.route_key) +def _apply_carriage_returns(text: str) -> list[str]: + """Interpret \r as 'return to start of line' (overwrite), not a newline. + + This prevents terminals that redraw a single line (progress bars, prompts) from + expanding into many duplicate lines in screenshots. + """ + + lines: list[str] = [] + current: list[str] = [] + for ch in text: + if ch == "\r": + current.clear() + elif ch == "\n": + lines.append("".join(current)) + current.clear() + else: + current.append(ch) + if current: + lines.append("".join(current)) + return lines + + class LocalServer: """Manages local Textual apps and terminals without Ganglion server.""" @@ -108,6 +134,9 @@ class LocalServer: self._websocket_connections: dict[RouteKey, web.WebSocketResponse] = {} self._landing_apps = landing_apps or [] + self._screenshot_cache: dict[str, tuple[float, str]] = {} + self._screenshot_locks: dict[str, asyncio.Lock] = {} + @property def app_count(self) -> int: return len(self.session_manager.apps) @@ -373,6 +402,8 @@ class LocalServer: raise web.HTTPNotFound(text="Session not found") replay_data = await session_process.get_replay_buffer() # type: ignore[func-returns-value] + if len(replay_data) > SCREENSHOT_MAX_BYTES: + replay_data = replay_data[-SCREENSHOT_MAX_BYTES:] ansi_text = replay_data.decode("utf-8", errors="replace") try: @@ -387,36 +418,56 @@ class LocalServer: height = DISCONNECT_RESIZE[1] height = max(5, min(200, height)) - lines = ansi_text.splitlines() + lines = _apply_carriage_returns(ansi_text) if len(lines) > height: ansi_text = "\n".join(lines[-height:]) + "\n" - console = Console(record=True, width=width, height=height, file=io.StringIO()) - decoder = AnsiDecoder() - for renderable in decoder.decode(ansi_text): - console.print(renderable) + now = asyncio.get_event_loop().time() + cached = self._screenshot_cache.get(route_key) + if cached is not None and (now - cached[0]) < SCREENSHOT_CACHE_SECONDS: + return web.Response(text=cached[1], content_type="image/svg+xml") - svg = console.export_svg( - title="textual-webterm", - code_format=( - '' - '' - '' - '' - '' - '' - '{lines}' - '' - '' - '' - '{backgrounds}' - '{matrix}' - '' - '' - ), - ) - return web.Response(text=svg, content_type="image/svg+xml") + lock = self._screenshot_locks.get(route_key) + if lock is None: + lock = asyncio.Lock() + self._screenshot_locks[route_key] = lock + + async with lock: + # Another request may have refreshed the cache while we waited. + cached = self._screenshot_cache.get(route_key) + if cached is not None and (now - cached[0]) < SCREENSHOT_CACHE_SECONDS: + return web.Response(text=cached[1], content_type="image/svg+xml") + + def _render_svg() -> str: + console = Console(record=True, width=width, height=height, file=io.StringIO()) + decoder = AnsiDecoder() + for renderable in decoder.decode(ansi_text): + console.print(renderable) + + return console.export_svg( + title="textual-webterm", + code_format=( + '' + '' + '' + '' + '' + '' + '{lines}' + '' + '' + '' + '{backgrounds}' + '{matrix}' + '' + '' + ), + ) + + svg = await asyncio.to_thread(_render_svg) + self._screenshot_cache[route_key] = (asyncio.get_event_loop().time(), svg) + return web.Response(text=svg, content_type="image/svg+xml") async def _handle_health_check(self, _request: web.Request) -> web.Response: return web.Response(text="Local server is running") diff --git a/tests/test_local_server_unit.py b/tests/test_local_server_unit.py index 6d24b57..8677fd1 100644 --- a/tests/test_local_server_unit.py +++ b/tests/test_local_server_unit.py @@ -6,7 +6,7 @@ import pytest from aiohttp import web from textual_webterm.config import App, Config -from textual_webterm.local_server import LocalServer +from textual_webterm.local_server import LocalServer, _apply_carriage_returns class TestGetStaticPath: @@ -105,6 +105,10 @@ class TestLocalServer: class TestLocalServerHelpers: """Tests for LocalServer helper methods.""" + def test_apply_carriage_returns_overwrites_line(self): + text = "hello\rworld\nnext" + assert _apply_carriage_returns(text) == ["world", "next"] + @pytest.mark.asyncio async def test_keyboard_interrupt_closes_sessions_and_websockets(self, server, monkeypatch): ws1 = MagicMock()