diff --git a/src/textual_webterm/local_server.py b/src/textual_webterm/local_server.py index 110c86d..6a2a58b 100644 --- a/src/textual_webterm/local_server.py +++ b/src/textual_webterm/local_server.py @@ -14,7 +14,6 @@ from pathlib import Path from typing import TYPE_CHECKING import aiohttp -import pyte from aiohttp import WSMsgType, web from rich.ansi import AnsiDecoder from rich.console import Console @@ -34,8 +33,6 @@ log = logging.getLogger("textual-web") DISCONNECT_RESIZE = (132, 45) -# Avoid heavy screenshot rendering from processing unbounded output. -SCREENSHOT_MAX_BYTES = 65536 SCREENSHOT_CACHE_SECONDS = 1.0 SCREENSHOT_MAX_CACHE_SECONDS = 60.0 @@ -108,19 +105,6 @@ def _rewrite_svg_fonts(svg: str) -> str: return svg -def _apply_carriage_returns(text: str, width: int = 80, height: int = 24) -> list[str]: - """Use pyte terminal emulator to properly interpret ANSI escape sequences. - - This handles cursor positioning, screen clearing, and other terminal control - codes that cause issues like tmux status bars "creeping up" in screenshots. - """ - screen = pyte.Screen(width, height) - stream = pyte.Stream(screen) - stream.feed(text) - # Return lines from the display, stripping trailing whitespace - return [line.rstrip() for line in screen.display] - - class LocalServer: def mark_route_activity(self, route_key: str) -> None: self._route_last_activity[route_key] = asyncio.get_event_loop().time() @@ -473,7 +457,7 @@ class LocalServer: ) session_process = self.session_manager.get_session_by_route_key(RouteKey(route_key)) - if session_process is None or not hasattr(session_process, "get_replay_buffer"): + if session_process is None or not hasattr(session_process, "get_screen_lines"): raise web.HTTPNotFound(text="Session not found") # If nothing has changed since the last render, serve cached screenshot without @@ -485,10 +469,10 @@ class LocalServer: if cached_response is not None: return cached_response - replay_data = await session_process.get_replay_buffer() # type: ignore[func-returns-value] - if len(replay_data) > SCREENSHOT_MAX_BYTES: - replay_data = replay_data[-SCREENSHOT_MAX_BYTES:] - ansi_text = replay_data.decode("utf-8", errors="replace") + # Get screen lines directly from the terminal session's pyte screen + # This provides accurate terminal state without replay buffer truncation issues + lines = await session_process.get_screen_lines() # type: ignore[union-attr] + screen_text = "\n".join(lines) try: width = int(request.query.get("width", "120")) @@ -502,10 +486,6 @@ class LocalServer: height = DISCONNECT_RESIZE[1] height = max(5, min(200, height)) - # Use pyte terminal emulator to get clean screen state - lines = _apply_carriage_returns(ansi_text, width, height) - ansi_text = "\n".join(lines) - now = asyncio.get_event_loop().time() ttl = self._get_screenshot_cache_ttl(route_key, now) cached = self._screenshot_cache.get(route_key) @@ -539,7 +519,7 @@ class LocalServer: def _render_svg() -> str: console = Console(record=True, width=width, height=height, file=io.StringIO()) decoder = AnsiDecoder() - for renderable in decoder.decode(ansi_text): + for renderable in decoder.decode(screen_text): console.print(renderable) return console.export_svg( diff --git a/src/textual_webterm/terminal_session.py b/src/textual_webterm/terminal_session.py index a5cac14..86ad039 100644 --- a/src/textual_webterm/terminal_session.py +++ b/src/textual_webterm/terminal_session.py @@ -13,6 +13,7 @@ import termios from collections import deque from typing import TYPE_CHECKING +import pyte import rich.repr from importlib_metadata import version @@ -27,6 +28,10 @@ log = logging.getLogger("textual-web") # Maximum bytes to keep in replay buffer for reconnection REPLAY_BUFFER_SIZE = 64 * 1024 # 64KB +# Default screen size for pyte emulator +DEFAULT_SCREEN_WIDTH = 132 +DEFAULT_SCREEN_HEIGHT = 45 + @rich.repr.auto class TerminalSession(Session): @@ -47,6 +52,10 @@ class TerminalSession(Session): self._replay_buffer: deque[bytes] = deque() self._replay_buffer_size = 0 self._replay_lock = asyncio.Lock() + # pyte screen for accurate terminal state tracking + self._screen = pyte.Screen(DEFAULT_SCREEN_WIDTH, DEFAULT_SCREEN_HEIGHT) + self._stream = pyte.Stream(self._screen) + self._screen_lock = asyncio.Lock() super().__init__() def __rich_repr__(self) -> rich.repr.Result: @@ -55,6 +64,10 @@ class TerminalSession(Session): async def open(self, width: int = 80, height: int = 24) -> None: log.info("Opening terminal session %s with command: %s", self.session_id, self.command) + # Initialize pyte screen with the requested size + self._screen = pyte.Screen(width, height) + self._stream = pyte.Stream(self._screen) + pid, master_fd = pty.fork() self.pid = pid self.master_fd = master_fd @@ -88,6 +101,9 @@ class TerminalSession(Session): async def set_terminal_size(self, width: int, height: int) -> None: loop = asyncio.get_running_loop() await loop.run_in_executor(None, self._set_terminal_size, width, height) + # Resize pyte screen to match + async with self._screen_lock: + self._screen.resize(height, width) async def _add_to_replay_buffer(self, data: bytes) -> None: """Add data to replay buffer, maintaining size limit.""" @@ -98,11 +114,30 @@ class TerminalSession(Session): old_data = self._replay_buffer.popleft() self._replay_buffer_size -= len(old_data) + async def _update_screen(self, data: bytes) -> None: + """Update the pyte screen with new terminal data.""" + async with self._screen_lock: + try: + text = data.decode("utf-8", errors="replace") + self._stream.feed(text) + except Exception: + # Don't let pyte errors crash the session + pass + async def get_replay_buffer(self) -> bytes: """Get the contents of the replay buffer.""" async with self._replay_lock: return b"".join(self._replay_buffer) + async def get_screen_lines(self) -> list[str]: + """Get the current screen state as a list of lines. + + Returns properly rendered terminal content with all escape sequences + interpreted, suitable for screenshot generation. + """ + async with self._screen_lock: + return [line.rstrip() for line in self._screen.display] + def update_connector(self, connector: SessionConnector) -> None: """Update the connector for reconnection without restarting the session.""" self._connector = connector @@ -127,6 +162,8 @@ class TerminalSession(Session): break # Store in replay buffer for reconnection await self._add_to_replay_buffer(data) + # Update pyte screen state for screenshots + await self._update_screen(data) # Send to current connector if self._connector: await self._connector.on_data(data) diff --git a/tests/test_local_server_unit.py b/tests/test_local_server_unit.py index 51aa57d..893504c 100644 --- a/tests/test_local_server_unit.py +++ b/tests/test_local_server_unit.py @@ -9,7 +9,6 @@ from textual_webterm.config import App, Config from textual_webterm.local_server import ( LocalClientConnector, LocalServer, - _apply_carriage_returns, _rewrite_svg_fonts, ) @@ -110,26 +109,6 @@ class TestLocalServer: class TestLocalServerHelpers: """Tests for LocalServer helper methods.""" - def test_apply_carriage_returns_overwrites_line(self): - text = "hello\rworld\r\nnext" - # pyte terminal emulator interprets CR properly - overwrites hello with world - lines = _apply_carriage_returns(text, width=80, height=24) - # First line should have "world" (overwritten), second line "next" - assert lines[0] == "world" - assert lines[1] == "next" - - def test_apply_carriage_returns_handles_cursor_positioning(self): - # Simulate tmux-style cursor positioning to row 5, column 1 (\x1b[5;1H) - # Then clear to end of line (\x1b[K) and write new content - # Use \r\n for proper line endings - text = "line1\r\nline2\r\nline3\r\nline4\r\nline5\x1b[5;1H\x1b[Kupdated" - lines = _apply_carriage_returns(text, width=80, height=10) - # Line 5 (index 4) should be overwritten with "updated" - assert lines[4] == "updated" - # Previous lines should remain - assert lines[0] == "line1" - assert lines[1] == "line2" - @pytest.mark.asyncio async def test_keyboard_interrupt_closes_sessions_and_websockets(self, server, monkeypatch): ws1 = MagicMock() @@ -215,7 +194,7 @@ class TestLocalServerHelpers: request.query = {"route_key": "rk", "width": "80"} session = MagicMock() - session.get_replay_buffer = AsyncMock(return_value=b"hello\r\n") + session.get_screen_lines = AsyncMock(return_value=["hello", ""]) monkeypatch.setattr(server.session_manager, "get_session_by_route_key", lambda _rk: session) @@ -233,7 +212,7 @@ class TestLocalServerHelpers: request.query = {"route_key": "known", "width": "90"} session = MagicMock() - session.get_replay_buffer = AsyncMock(return_value=b"world\r\n") + session.get_screen_lines = AsyncMock(return_value=["world", ""]) # Pretend app exists for slug "known" server.session_manager.apps_by_slug["known"] = App( @@ -565,7 +544,7 @@ class TestLocalServerMoreCoverage: request.headers = {} session = MagicMock() - session.get_replay_buffer = AsyncMock(return_value=b"SHOULD_NOT_BE_READ") + session.get_screen_lines = AsyncMock(return_value=["SHOULD_NOT_BE_READ"]) monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session) server_with_no_apps._screenshot_cache["rk"] = (0.0, "cached") @@ -575,7 +554,7 @@ class TestLocalServerMoreCoverage: resp = await server_with_no_apps._handle_screenshot(request) assert "cached" in resp.text - session.get_replay_buffer.assert_not_awaited() + session.get_screen_lines.assert_not_awaited() @pytest.mark.asyncio async def test_handle_screenshot_invalid_width_height_defaults(self, server_with_no_apps, monkeypatch): @@ -584,7 +563,7 @@ class TestLocalServerMoreCoverage: request.headers = {} session = MagicMock() - session.get_replay_buffer = AsyncMock(return_value=b"hello\n") + session.get_screen_lines = AsyncMock(return_value=["hello", ""]) monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session) resp = await server_with_no_apps._handle_screenshot(request) @@ -756,32 +735,19 @@ class TestLocalServerMoreCoverage: assert created is True @pytest.mark.asyncio - async def test_handle_screenshot_truncates_replay_buffer_before_decode(self, server_with_no_apps, monkeypatch): - from textual_webterm.local_server import SCREENSHOT_MAX_BYTES - + async def test_handle_screenshot_uses_get_screen_lines(self, server_with_no_apps, monkeypatch): + """Test that screenshot uses get_screen_lines() from terminal session.""" request = MagicMock() request.query = {"route_key": "rk"} request.headers = {} session = MagicMock() - session.get_replay_buffer = AsyncMock(return_value=b"x" * (SCREENSHOT_MAX_BYTES + 10)) + session.get_screen_lines = AsyncMock(return_value=["line1", "line2", ""]) monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session) server_with_no_apps._route_last_activity["rk"] = 1.0 - captured = {"len": None} - - def apply_cr(text: str, width: int = 80, height: int = 24): - captured["len"] = len(text) - return ["x"] - - async def fake_to_thread(_fn): - return "" - - monkeypatch.setattr("textual_webterm.local_server._apply_carriage_returns", apply_cr) - monkeypatch.setattr("textual_webterm.local_server.asyncio.to_thread", AsyncMock(side_effect=fake_to_thread)) - monkeypatch.setattr("textual_webterm.local_server._rewrite_svg_fonts", lambda s: s) - resp = await server_with_no_apps._handle_screenshot(request) assert resp.content_type == "image/svg+xml" - assert captured["len"] == SCREENSHOT_MAX_BYTES + assert "