Files
webterm/tests/test_local_server_websocket_integration.py
T
GitHub Copilot d5343117d3 Filter DA1 responses from replay buffer on WebSocket connect
The replay buffer can contain DA1/DA2 terminal attribute responses
(e.g., \x1b[?1;10;0c) that were captured before filtering was added
to the session classes. These responses appear as visible text like
'1;10;0c' when sent to the client on reconnect.

This adds an additional filter pass when sending the replay buffer,
ensuring no DA1 responses reach the client regardless of when they
were captured.
2026-01-29 19:13:40 +00:00

198 lines
5.8 KiB
Python

from __future__ import annotations
import json
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock
import pytest
from aiohttp import WSMsgType, web
from aiohttp.test_utils import TestClient, TestServer
from webterm.config import App, Config
from webterm.local_server import LocalServer
from webterm.types import RouteKey, SessionID
if TYPE_CHECKING:
from collections.abc import AsyncIterator
async def _make_client(server: LocalServer) -> TestClient:
app = web.Application()
app.add_routes(server._build_routes())
test_server = TestServer(app)
client = TestClient(test_server)
await client.start_server()
return client
@pytest.fixture
def server_factory(tmp_path):
counter = {"i": 0}
def _make(apps: list[App] | None = None) -> LocalServer:
counter["i"] += 1
config = Config(
apps=apps
or [App(name="Test", slug="test", path=".", command="echo test", terminal=True)]
)
config_file = tmp_path / f"config-{counter['i']}.toml"
config_file.write_text("")
return LocalServer(config_path=str(config_file), config=config)
return _make
@pytest.fixture
def client_factory():
@asynccontextmanager
async def _factory(server: LocalServer) -> AsyncIterator[TestClient]:
client = await _make_client(server)
try:
yield client
finally:
await client.close()
return _factory
@pytest.mark.asyncio
async def test_websocket_creates_session_on_resize(tmp_path):
config = Config(
apps=[App(name="Test", slug="test", path=".", command="echo test", terminal=True)]
)
config_file = tmp_path / "config.toml"
config_file.write_text("")
server = LocalServer(config_path=str(config_file), config=config)
# Avoid spawning any real processes.
created = {"args": None}
async def fake_create(route_key: str, width: int, height: int) -> None:
created["args"] = (route_key, width, height)
server._create_terminal_session = fake_create # type: ignore[method-assign]
client = await _make_client(server)
try:
ws = await client.ws_connect("/ws/test")
await ws.send_str(json.dumps(["resize", {"width": 90, "height": 25}]))
await ws.close()
finally:
await client.close()
assert created["args"] == ("test", 90, 25)
# Reconnect to an existing session should reuse it and send replay buffer
class DummySession:
def is_running(self):
return True
server.session_manager.routes["test"] = "sid"
server.session_manager.sessions["sid"] = DummySession()
# Replay buffer should be sent on reconnect
replay_session = server.session_manager.sessions["sid"]
replay_session.get_replay_buffer = AsyncMock(return_value=b"replay")
client = await _make_client(server)
try:
ws = await client.ws_connect("/ws/test")
msg = await ws.receive(timeout=1)
assert msg.type == WSMsgType.BINARY
assert msg.data == b"replay"
await ws.close()
finally:
await client.close()
@pytest.mark.asyncio
async def test_websocket_ping_pong(tmp_path):
config = Config(
apps=[App(name="Test", slug="test", path=".", command="echo test", terminal=True)]
)
config_file = tmp_path / "config.toml"
config_file.write_text("")
server = LocalServer(config_path=str(config_file), config=config)
client = await _make_client(server)
try:
ws = await client.ws_connect("/ws/test")
await ws.send_str(json.dumps(["ping", "123"]))
msg = await ws.receive(timeout=1)
assert msg.type == WSMsgType.TEXT
assert json.loads(msg.data) == ["pong", "123"]
await ws.close()
finally:
await client.close()
@pytest.mark.asyncio
async def test_websocket_ignores_invalid_envelopes(tmp_path):
config = Config(
apps=[App(name="Test", slug="test", path=".", command="echo test", terminal=True)]
)
config_file = tmp_path / "config.toml"
config_file.write_text("")
server = LocalServer(config_path=str(config_file), config=config)
client = await _make_client(server)
try:
ws = await client.ws_connect("/ws/test")
await ws.send_str("not json")
await ws.send_str(json.dumps({"not": "a list"}))
await ws.send_str(json.dumps([]))
await ws.close()
finally:
await client.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
("payload", "is_binary"),
[
("not json", False),
(json.dumps({"not": "a list"}), False),
(json.dumps([]), False),
(b"\x00\x01\x02", True),
],
)
async def test_websocket_invalid_payloads_keep_connection(
server_factory, client_factory, payload, is_binary
):
server = server_factory()
async with client_factory(server) as client:
ws = await client.ws_connect("/ws/test")
if is_binary:
await ws.send_bytes(payload)
else:
await ws.send_str(payload)
await ws.send_str(json.dumps(["ping", "ok"]))
msg = await ws.receive(timeout=1)
assert msg.type == WSMsgType.TEXT
assert json.loads(msg.data) == ["pong", "ok"]
await ws.close()
@pytest.mark.asyncio
async def test_websocket_clears_stale_session(server_factory, client_factory):
server = server_factory()
class DummySession:
def is_running(self):
return False
session_id = SessionID("sid")
route_key = RouteKey("test")
server.session_manager.routes[route_key] = session_id
server.session_manager.sessions[session_id] = DummySession()
async with client_factory(server) as client:
ws = await client.ws_connect("/ws/test")
assert server.session_manager.get_session(session_id) is None
assert server.session_manager.routes.get(route_key) is None
await ws.close()