Fix resize and poller races; add coverage
- Fix resize message handling when session already exists - Guard poller selector.modify against removed fds - Handle send_bytes race when master_fd closes - Add tests for resize edge case, poller write KeyError, send_bytes race
This commit is contained in:
@@ -346,9 +346,9 @@ class LocalServer:
|
|||||||
if msg_type == "stdin":
|
if msg_type == "stdin":
|
||||||
await self._handle_stdin(envelope, route_key, ws)
|
await self._handle_stdin(envelope, route_key, ws)
|
||||||
elif msg_type == "resize":
|
elif msg_type == "resize":
|
||||||
if not session_created and await self._handle_resize(envelope, route_key, ws):
|
if not session_created:
|
||||||
session_created = True
|
session_created = await self._handle_resize(envelope, route_key, ws)
|
||||||
elif session_created:
|
else:
|
||||||
await self._handle_resize(envelope, route_key, ws)
|
await self._handle_resize(envelope, route_key, ws)
|
||||||
elif msg_type == "ping":
|
elif msg_type == "ping":
|
||||||
await self._handle_ping(envelope, route_key, ws)
|
await self._handle_ping(envelope, route_key, ws)
|
||||||
|
|||||||
@@ -62,7 +62,17 @@ class Poller(Thread):
|
|||||||
self._write_queues[file_descriptor] = deque()
|
self._write_queues[file_descriptor] = deque()
|
||||||
new_write = Write(data)
|
new_write = Write(data)
|
||||||
self._write_queues[file_descriptor].append(new_write)
|
self._write_queues[file_descriptor].append(new_write)
|
||||||
|
try:
|
||||||
|
|
||||||
self._selector.modify(file_descriptor, selectors.EVENT_READ | selectors.EVENT_WRITE)
|
self._selector.modify(file_descriptor, selectors.EVENT_READ | selectors.EVENT_WRITE)
|
||||||
|
|
||||||
|
except KeyError:
|
||||||
|
|
||||||
|
# File descriptor removed concurrently
|
||||||
|
|
||||||
|
new_write.done_event.set()
|
||||||
|
|
||||||
|
return
|
||||||
await new_write.done_event.wait()
|
await new_write.done_event.wait()
|
||||||
|
|
||||||
def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
|
def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
|
||||||
|
|||||||
@@ -253,9 +253,13 @@ class TerminalSession(Session):
|
|||||||
os.close(fd)
|
os.close(fd)
|
||||||
|
|
||||||
async def send_bytes(self, data: bytes) -> bool:
|
async def send_bytes(self, data: bytes) -> bool:
|
||||||
if self.master_fd is None:
|
fd = self.master_fd
|
||||||
|
if fd is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
await self.poller.write(fd, data)
|
||||||
|
except (KeyError, OSError):
|
||||||
return False
|
return False
|
||||||
await self.poller.write(self.master_fd, data)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def send_meta(self, data: Meta) -> bool:
|
async def send_meta(self, data: Meta) -> bool:
|
||||||
|
|||||||
@@ -679,6 +679,19 @@ class TestLocalServerMoreCoverage:
|
|||||||
session.send_bytes.assert_awaited_once_with(b"")
|
session.send_bytes.assert_awaited_once_with(b"")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_ws_message_resize_existing_session_flag_false(self, server_with_no_apps, monkeypatch):
|
||||||
|
session = MagicMock()
|
||||||
|
session.set_terminal_size = AsyncMock()
|
||||||
|
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session)
|
||||||
|
|
||||||
|
ws = MagicMock()
|
||||||
|
created = await server_with_no_apps._dispatch_ws_message(
|
||||||
|
["resize", {"width": 100, "height": 50}], "rk", ws, False
|
||||||
|
)
|
||||||
|
assert created is False
|
||||||
|
session.set_terminal_size.assert_awaited_once_with(100, 50)
|
||||||
|
|
||||||
async def test_dispatch_ws_message_resize_updates_existing_session(self, server_with_no_apps, monkeypatch):
|
async def test_dispatch_ws_message_resize_updates_existing_session(self, server_with_no_apps, monkeypatch):
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.set_terminal_size = AsyncMock()
|
session.set_terminal_size = AsyncMock()
|
||||||
|
|||||||
@@ -71,6 +71,17 @@ class TestPoller:
|
|||||||
# Should not raise
|
# Should not raise
|
||||||
poller.remove_file(999)
|
poller.remove_file(999)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_write_handles_removed_fd(self):
|
||||||
|
poller = Poller()
|
||||||
|
poller._loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
with patch.object(poller._selector, "register"):
|
||||||
|
poller.add_file(42)
|
||||||
|
|
||||||
|
with patch.object(poller._selector, "modify", side_effect=KeyError()):
|
||||||
|
await poller.write(42, b"test")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_write_creates_queue(self):
|
async def test_write_creates_queue(self):
|
||||||
"""Test that write creates a write queue if needed."""
|
"""Test that write creates a write queue if needed."""
|
||||||
|
|||||||
@@ -288,6 +288,18 @@ class TestTerminalSession:
|
|||||||
|
|
||||||
mock_exit.assert_called_once_with(1)
|
mock_exit.assert_called_once_with(1)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_bytes_handles_closed_fd(self):
|
||||||
|
from textual_webterm.terminal_session import TerminalSession
|
||||||
|
|
||||||
|
poller = MagicMock()
|
||||||
|
poller.write = AsyncMock(side_effect=KeyError)
|
||||||
|
session = TerminalSession(poller, "sid", "bash")
|
||||||
|
session.master_fd = 10
|
||||||
|
|
||||||
|
ok = await session.send_bytes(b"test")
|
||||||
|
assert ok is False
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_reads_from_poller_and_closes(self):
|
async def test_run_reads_from_poller_and_closes(self):
|
||||||
from textual_webterm.terminal_session import TerminalSession
|
from textual_webterm.terminal_session import TerminalSession
|
||||||
|
|||||||
Reference in New Issue
Block a user