fix: maintain pyte screen state in TerminalSession for accurate screenshots

Instead of trying to replay a truncated byte buffer through pyte, this
change maintains a pyte Screen object within TerminalSession that gets
updated as terminal data flows through. This provides accurate terminal
state for screenshots without issues from buffer truncation.

Key changes:
- Add pyte Screen and Stream to TerminalSession
- Update screen state as data arrives via _update_screen()
- Add get_screen_lines() to return current screen state
- Resize pyte screen when terminal size changes
- Update local_server to use get_screen_lines() directly
- Remove _apply_carriage_returns() workaround

This properly fixes the tmux status bar 'creeping up' issue by ensuring
the screenshot always reflects the actual terminal state.
This commit is contained in:
GitHub Copilot
2026-01-24 10:33:31 +00:00
parent a58c434eaf
commit 894fb2eaaf
4 changed files with 87 additions and 70 deletions
+6 -26
View File
@@ -14,7 +14,6 @@ from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import aiohttp import aiohttp
import pyte
from aiohttp import WSMsgType, web from aiohttp import WSMsgType, web
from rich.ansi import AnsiDecoder from rich.ansi import AnsiDecoder
from rich.console import Console from rich.console import Console
@@ -34,8 +33,6 @@ log = logging.getLogger("textual-web")
DISCONNECT_RESIZE = (132, 45) DISCONNECT_RESIZE = (132, 45)
# Avoid heavy screenshot rendering from processing unbounded output.
SCREENSHOT_MAX_BYTES = 65536
SCREENSHOT_CACHE_SECONDS = 1.0 SCREENSHOT_CACHE_SECONDS = 1.0
SCREENSHOT_MAX_CACHE_SECONDS = 60.0 SCREENSHOT_MAX_CACHE_SECONDS = 60.0
@@ -108,19 +105,6 @@ def _rewrite_svg_fonts(svg: str) -> str:
return svg return svg
def _apply_carriage_returns(text: str, width: int = 80, height: int = 24) -> list[str]:
"""Use pyte terminal emulator to properly interpret ANSI escape sequences.
This handles cursor positioning, screen clearing, and other terminal control
codes that cause issues like tmux status bars "creeping up" in screenshots.
"""
screen = pyte.Screen(width, height)
stream = pyte.Stream(screen)
stream.feed(text)
# Return lines from the display, stripping trailing whitespace
return [line.rstrip() for line in screen.display]
class LocalServer: class LocalServer:
def mark_route_activity(self, route_key: str) -> None: def mark_route_activity(self, route_key: str) -> None:
self._route_last_activity[route_key] = asyncio.get_event_loop().time() self._route_last_activity[route_key] = asyncio.get_event_loop().time()
@@ -473,7 +457,7 @@ class LocalServer:
) )
session_process = self.session_manager.get_session_by_route_key(RouteKey(route_key)) session_process = self.session_manager.get_session_by_route_key(RouteKey(route_key))
if session_process is None or not hasattr(session_process, "get_replay_buffer"): if session_process is None or not hasattr(session_process, "get_screen_lines"):
raise web.HTTPNotFound(text="Session not found") raise web.HTTPNotFound(text="Session not found")
# If nothing has changed since the last render, serve cached screenshot without # If nothing has changed since the last render, serve cached screenshot without
@@ -485,10 +469,10 @@ class LocalServer:
if cached_response is not None: if cached_response is not None:
return cached_response return cached_response
replay_data = await session_process.get_replay_buffer() # type: ignore[func-returns-value] # Get screen lines directly from the terminal session's pyte screen
if len(replay_data) > SCREENSHOT_MAX_BYTES: # This provides accurate terminal state without replay buffer truncation issues
replay_data = replay_data[-SCREENSHOT_MAX_BYTES:] lines = await session_process.get_screen_lines() # type: ignore[union-attr]
ansi_text = replay_data.decode("utf-8", errors="replace") screen_text = "\n".join(lines)
try: try:
width = int(request.query.get("width", "120")) width = int(request.query.get("width", "120"))
@@ -502,10 +486,6 @@ class LocalServer:
height = DISCONNECT_RESIZE[1] height = DISCONNECT_RESIZE[1]
height = max(5, min(200, height)) height = max(5, min(200, height))
# Use pyte terminal emulator to get clean screen state
lines = _apply_carriage_returns(ansi_text, width, height)
ansi_text = "\n".join(lines)
now = asyncio.get_event_loop().time() now = asyncio.get_event_loop().time()
ttl = self._get_screenshot_cache_ttl(route_key, now) ttl = self._get_screenshot_cache_ttl(route_key, now)
cached = self._screenshot_cache.get(route_key) cached = self._screenshot_cache.get(route_key)
@@ -539,7 +519,7 @@ class LocalServer:
def _render_svg() -> str: def _render_svg() -> str:
console = Console(record=True, width=width, height=height, file=io.StringIO()) console = Console(record=True, width=width, height=height, file=io.StringIO())
decoder = AnsiDecoder() decoder = AnsiDecoder()
for renderable in decoder.decode(ansi_text): for renderable in decoder.decode(screen_text):
console.print(renderable) console.print(renderable)
return console.export_svg( return console.export_svg(
+37
View File
@@ -13,6 +13,7 @@ import termios
from collections import deque from collections import deque
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import pyte
import rich.repr import rich.repr
from importlib_metadata import version from importlib_metadata import version
@@ -27,6 +28,10 @@ log = logging.getLogger("textual-web")
# Maximum bytes to keep in replay buffer for reconnection # Maximum bytes to keep in replay buffer for reconnection
REPLAY_BUFFER_SIZE = 64 * 1024 # 64KB REPLAY_BUFFER_SIZE = 64 * 1024 # 64KB
# Default screen size for pyte emulator
DEFAULT_SCREEN_WIDTH = 132
DEFAULT_SCREEN_HEIGHT = 45
@rich.repr.auto @rich.repr.auto
class TerminalSession(Session): class TerminalSession(Session):
@@ -47,6 +52,10 @@ class TerminalSession(Session):
self._replay_buffer: deque[bytes] = deque() self._replay_buffer: deque[bytes] = deque()
self._replay_buffer_size = 0 self._replay_buffer_size = 0
self._replay_lock = asyncio.Lock() self._replay_lock = asyncio.Lock()
# pyte screen for accurate terminal state tracking
self._screen = pyte.Screen(DEFAULT_SCREEN_WIDTH, DEFAULT_SCREEN_HEIGHT)
self._stream = pyte.Stream(self._screen)
self._screen_lock = asyncio.Lock()
super().__init__() super().__init__()
def __rich_repr__(self) -> rich.repr.Result: def __rich_repr__(self) -> rich.repr.Result:
@@ -55,6 +64,10 @@ class TerminalSession(Session):
async def open(self, width: int = 80, height: int = 24) -> None: async def open(self, width: int = 80, height: int = 24) -> None:
log.info("Opening terminal session %s with command: %s", self.session_id, self.command) log.info("Opening terminal session %s with command: %s", self.session_id, self.command)
# Initialize pyte screen with the requested size
self._screen = pyte.Screen(width, height)
self._stream = pyte.Stream(self._screen)
pid, master_fd = pty.fork() pid, master_fd = pty.fork()
self.pid = pid self.pid = pid
self.master_fd = master_fd self.master_fd = master_fd
@@ -88,6 +101,9 @@ class TerminalSession(Session):
async def set_terminal_size(self, width: int, height: int) -> None: async def set_terminal_size(self, width: int, height: int) -> None:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._set_terminal_size, width, height) await loop.run_in_executor(None, self._set_terminal_size, width, height)
# Resize pyte screen to match
async with self._screen_lock:
self._screen.resize(height, width)
async def _add_to_replay_buffer(self, data: bytes) -> None: async def _add_to_replay_buffer(self, data: bytes) -> None:
"""Add data to replay buffer, maintaining size limit.""" """Add data to replay buffer, maintaining size limit."""
@@ -98,11 +114,30 @@ class TerminalSession(Session):
old_data = self._replay_buffer.popleft() old_data = self._replay_buffer.popleft()
self._replay_buffer_size -= len(old_data) self._replay_buffer_size -= len(old_data)
async def _update_screen(self, data: bytes) -> None:
"""Update the pyte screen with new terminal data."""
async with self._screen_lock:
try:
text = data.decode("utf-8", errors="replace")
self._stream.feed(text)
except Exception:
# Don't let pyte errors crash the session
pass
async def get_replay_buffer(self) -> bytes: async def get_replay_buffer(self) -> bytes:
"""Get the contents of the replay buffer.""" """Get the contents of the replay buffer."""
async with self._replay_lock: async with self._replay_lock:
return b"".join(self._replay_buffer) return b"".join(self._replay_buffer)
async def get_screen_lines(self) -> list[str]:
"""Get the current screen state as a list of lines.
Returns properly rendered terminal content with all escape sequences
interpreted, suitable for screenshot generation.
"""
async with self._screen_lock:
return [line.rstrip() for line in self._screen.display]
def update_connector(self, connector: SessionConnector) -> None: def update_connector(self, connector: SessionConnector) -> None:
"""Update the connector for reconnection without restarting the session.""" """Update the connector for reconnection without restarting the session."""
self._connector = connector self._connector = connector
@@ -127,6 +162,8 @@ class TerminalSession(Session):
break break
# Store in replay buffer for reconnection # Store in replay buffer for reconnection
await self._add_to_replay_buffer(data) await self._add_to_replay_buffer(data)
# Update pyte screen state for screenshots
await self._update_screen(data)
# Send to current connector # Send to current connector
if self._connector: if self._connector:
await self._connector.on_data(data) await self._connector.on_data(data)
+10 -44
View File
@@ -9,7 +9,6 @@ from textual_webterm.config import App, Config
from textual_webterm.local_server import ( from textual_webterm.local_server import (
LocalClientConnector, LocalClientConnector,
LocalServer, LocalServer,
_apply_carriage_returns,
_rewrite_svg_fonts, _rewrite_svg_fonts,
) )
@@ -110,26 +109,6 @@ class TestLocalServer:
class TestLocalServerHelpers: class TestLocalServerHelpers:
"""Tests for LocalServer helper methods.""" """Tests for LocalServer helper methods."""
def test_apply_carriage_returns_overwrites_line(self):
text = "hello\rworld\r\nnext"
# pyte terminal emulator interprets CR properly - overwrites hello with world
lines = _apply_carriage_returns(text, width=80, height=24)
# First line should have "world" (overwritten), second line "next"
assert lines[0] == "world"
assert lines[1] == "next"
def test_apply_carriage_returns_handles_cursor_positioning(self):
# Simulate tmux-style cursor positioning to row 5, column 1 (\x1b[5;1H)
# Then clear to end of line (\x1b[K) and write new content
# Use \r\n for proper line endings
text = "line1\r\nline2\r\nline3\r\nline4\r\nline5\x1b[5;1H\x1b[Kupdated"
lines = _apply_carriage_returns(text, width=80, height=10)
# Line 5 (index 4) should be overwritten with "updated"
assert lines[4] == "updated"
# Previous lines should remain
assert lines[0] == "line1"
assert lines[1] == "line2"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_keyboard_interrupt_closes_sessions_and_websockets(self, server, monkeypatch): async def test_keyboard_interrupt_closes_sessions_and_websockets(self, server, monkeypatch):
ws1 = MagicMock() ws1 = MagicMock()
@@ -215,7 +194,7 @@ class TestLocalServerHelpers:
request.query = {"route_key": "rk", "width": "80"} request.query = {"route_key": "rk", "width": "80"}
session = MagicMock() session = MagicMock()
session.get_replay_buffer = AsyncMock(return_value=b"hello\r\n") session.get_screen_lines = AsyncMock(return_value=["hello", ""])
monkeypatch.setattr(server.session_manager, "get_session_by_route_key", lambda _rk: session) monkeypatch.setattr(server.session_manager, "get_session_by_route_key", lambda _rk: session)
@@ -233,7 +212,7 @@ class TestLocalServerHelpers:
request.query = {"route_key": "known", "width": "90"} request.query = {"route_key": "known", "width": "90"}
session = MagicMock() session = MagicMock()
session.get_replay_buffer = AsyncMock(return_value=b"world\r\n") session.get_screen_lines = AsyncMock(return_value=["world", ""])
# Pretend app exists for slug "known" # Pretend app exists for slug "known"
server.session_manager.apps_by_slug["known"] = App( server.session_manager.apps_by_slug["known"] = App(
@@ -565,7 +544,7 @@ class TestLocalServerMoreCoverage:
request.headers = {} request.headers = {}
session = MagicMock() session = MagicMock()
session.get_replay_buffer = AsyncMock(return_value=b"SHOULD_NOT_BE_READ") session.get_screen_lines = AsyncMock(return_value=["SHOULD_NOT_BE_READ"])
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session) monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session)
server_with_no_apps._screenshot_cache["rk"] = (0.0, "<svg>cached</svg>") server_with_no_apps._screenshot_cache["rk"] = (0.0, "<svg>cached</svg>")
@@ -575,7 +554,7 @@ class TestLocalServerMoreCoverage:
resp = await server_with_no_apps._handle_screenshot(request) resp = await server_with_no_apps._handle_screenshot(request)
assert "cached" in resp.text assert "cached" in resp.text
session.get_replay_buffer.assert_not_awaited() session.get_screen_lines.assert_not_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_screenshot_invalid_width_height_defaults(self, server_with_no_apps, monkeypatch): async def test_handle_screenshot_invalid_width_height_defaults(self, server_with_no_apps, monkeypatch):
@@ -584,7 +563,7 @@ class TestLocalServerMoreCoverage:
request.headers = {} request.headers = {}
session = MagicMock() session = MagicMock()
session.get_replay_buffer = AsyncMock(return_value=b"hello\n") session.get_screen_lines = AsyncMock(return_value=["hello", ""])
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session) monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session)
resp = await server_with_no_apps._handle_screenshot(request) resp = await server_with_no_apps._handle_screenshot(request)
@@ -756,32 +735,19 @@ class TestLocalServerMoreCoverage:
assert created is True assert created is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_screenshot_truncates_replay_buffer_before_decode(self, server_with_no_apps, monkeypatch): async def test_handle_screenshot_uses_get_screen_lines(self, server_with_no_apps, monkeypatch):
from textual_webterm.local_server import SCREENSHOT_MAX_BYTES """Test that screenshot uses get_screen_lines() from terminal session."""
request = MagicMock() request = MagicMock()
request.query = {"route_key": "rk"} request.query = {"route_key": "rk"}
request.headers = {} request.headers = {}
session = MagicMock() session = MagicMock()
session.get_replay_buffer = AsyncMock(return_value=b"x" * (SCREENSHOT_MAX_BYTES + 10)) session.get_screen_lines = AsyncMock(return_value=["line1", "line2", ""])
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session) monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session)
server_with_no_apps._route_last_activity["rk"] = 1.0 server_with_no_apps._route_last_activity["rk"] = 1.0
captured = {"len": None}
def apply_cr(text: str, width: int = 80, height: int = 24):
captured["len"] = len(text)
return ["x"]
async def fake_to_thread(_fn):
return "<svg></svg>"
monkeypatch.setattr("textual_webterm.local_server._apply_carriage_returns", apply_cr)
monkeypatch.setattr("textual_webterm.local_server.asyncio.to_thread", AsyncMock(side_effect=fake_to_thread))
monkeypatch.setattr("textual_webterm.local_server._rewrite_svg_fonts", lambda s: s)
resp = await server_with_no_apps._handle_screenshot(request) resp = await server_with_no_apps._handle_screenshot(request)
assert resp.content_type == "image/svg+xml" assert resp.content_type == "image/svg+xml"
assert captured["len"] == SCREENSHOT_MAX_BYTES assert "<svg" in resp.text
session.get_screen_lines.assert_awaited_once()
+34
View File
@@ -96,6 +96,40 @@ class TestTerminalSession:
# Buffer should be trimmed # Buffer should be trimmed
assert session._replay_buffer_size <= REPLAY_BUFFER_SIZE + chunk_size assert session._replay_buffer_size <= REPLAY_BUFFER_SIZE + chunk_size
@pytest.mark.asyncio
async def test_screen_state_updates_with_data(self):
"""Test that pyte screen updates when data is received."""
from textual_webterm.terminal_session import TerminalSession
mock_poller = MagicMock()
session = TerminalSession(mock_poller, "test-session", "bash")
# Feed some terminal data
await session._update_screen(b"Hello World\r\n")
lines = await session.get_screen_lines()
# First line should contain the text
assert "Hello World" in lines[0]
@pytest.mark.asyncio
async def test_screen_handles_cursor_positioning(self):
"""Test that pyte screen correctly handles cursor positioning (tmux-style)."""
from textual_webterm.terminal_session import TerminalSession
mock_poller = MagicMock()
session = TerminalSession(mock_poller, "test-session", "bash")
# Feed content then reposition cursor and overwrite
await session._update_screen(b"Line 1\r\nLine 2\r\nLine 3\r\n")
# Move cursor to line 2, column 1 and clear line, then write new content
await session._update_screen(b"\x1b[2;1H\x1b[KUpdated Line 2")
lines = await session.get_screen_lines()
assert lines[0] == "Line 1"
assert lines[1] == "Updated Line 2"
assert lines[2] == "Line 3"
def test_update_connector(self): def test_update_connector(self):
"""Test updating connector.""" """Test updating connector."""
from textual_webterm.terminal_session import TerminalSession from textual_webterm.terminal_session import TerminalSession