Fix resize and poller races; add coverage

- Fix resize message handling when session already exists
- Guard poller selector.modify against removed fds
- Handle send_bytes race when master_fd closes
- Add tests for resize edge case, poller write KeyError, send_bytes race
This commit is contained in:
GitHub Copilot
2026-01-26 20:07:40 +00:00
parent 245849ba9f
commit 63e8cba0ac
6 changed files with 56 additions and 6 deletions
+3 -3
View File
@@ -346,9 +346,9 @@ class LocalServer:
if msg_type == "stdin": if msg_type == "stdin":
await self._handle_stdin(envelope, route_key, ws) await self._handle_stdin(envelope, route_key, ws)
elif msg_type == "resize": elif msg_type == "resize":
if not session_created and await self._handle_resize(envelope, route_key, ws): if not session_created:
session_created = True session_created = await self._handle_resize(envelope, route_key, ws)
elif session_created: else:
await self._handle_resize(envelope, route_key, ws) await self._handle_resize(envelope, route_key, ws)
elif msg_type == "ping": elif msg_type == "ping":
await self._handle_ping(envelope, route_key, ws) await self._handle_ping(envelope, route_key, ws)
+11 -1
View File
@@ -62,7 +62,17 @@ class Poller(Thread):
self._write_queues[file_descriptor] = deque() self._write_queues[file_descriptor] = deque()
new_write = Write(data) new_write = Write(data)
self._write_queues[file_descriptor].append(new_write) self._write_queues[file_descriptor].append(new_write)
self._selector.modify(file_descriptor, selectors.EVENT_READ | selectors.EVENT_WRITE) try:
self._selector.modify(file_descriptor, selectors.EVENT_READ | selectors.EVENT_WRITE)
except KeyError:
# File descriptor removed concurrently
new_write.done_event.set()
return
await new_write.done_event.wait() await new_write.done_event.wait()
def set_loop(self, loop: asyncio.AbstractEventLoop) -> None: def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
+6 -2
View File
@@ -253,9 +253,13 @@ class TerminalSession(Session):
os.close(fd) os.close(fd)
async def send_bytes(self, data: bytes) -> bool: async def send_bytes(self, data: bytes) -> bool:
if self.master_fd is None: fd = self.master_fd
if fd is None:
return False
try:
await self.poller.write(fd, data)
except (KeyError, OSError):
return False return False
await self.poller.write(self.master_fd, data)
return True return True
async def send_meta(self, data: Meta) -> bool: async def send_meta(self, data: Meta) -> bool:
+13
View File
@@ -679,6 +679,19 @@ class TestLocalServerMoreCoverage:
session.send_bytes.assert_awaited_once_with(b"") session.send_bytes.assert_awaited_once_with(b"")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.asyncio
async def test_dispatch_ws_message_resize_existing_session_flag_false(self, server_with_no_apps, monkeypatch):
session = MagicMock()
session.set_terminal_size = AsyncMock()
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session)
ws = MagicMock()
created = await server_with_no_apps._dispatch_ws_message(
["resize", {"width": 100, "height": 50}], "rk", ws, False
)
assert created is False
session.set_terminal_size.assert_awaited_once_with(100, 50)
async def test_dispatch_ws_message_resize_updates_existing_session(self, server_with_no_apps, monkeypatch): async def test_dispatch_ws_message_resize_updates_existing_session(self, server_with_no_apps, monkeypatch):
session = MagicMock() session = MagicMock()
session.set_terminal_size = AsyncMock() session.set_terminal_size = AsyncMock()
+11
View File
@@ -71,6 +71,17 @@ class TestPoller:
# Should not raise # Should not raise
poller.remove_file(999) poller.remove_file(999)
@pytest.mark.asyncio
async def test_write_handles_removed_fd(self):
poller = Poller()
poller._loop = asyncio.get_event_loop()
with patch.object(poller._selector, "register"):
poller.add_file(42)
with patch.object(poller._selector, "modify", side_effect=KeyError()):
await poller.write(42, b"test")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_write_creates_queue(self): async def test_write_creates_queue(self):
"""Test that write creates a write queue if needed.""" """Test that write creates a write queue if needed."""
+12
View File
@@ -288,6 +288,18 @@ class TestTerminalSession:
mock_exit.assert_called_once_with(1) mock_exit.assert_called_once_with(1)
@pytest.mark.asyncio
async def test_send_bytes_handles_closed_fd(self):
from textual_webterm.terminal_session import TerminalSession
poller = MagicMock()
poller.write = AsyncMock(side_effect=KeyError)
session = TerminalSession(poller, "sid", "bash")
session.master_fd = 10
ok = await session.send_bytes(b"test")
assert ok is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_reads_from_poller_and_closes(self): async def test_run_reads_from_poller_and_closes(self):
from textual_webterm.terminal_session import TerminalSession from textual_webterm.terminal_session import TerminalSession