diff --git a/src/textual_webterm/local_server.py b/src/textual_webterm/local_server.py index 96db4a5..859983f 100644 --- a/src/textual_webterm/local_server.py +++ b/src/textual_webterm/local_server.py @@ -346,9 +346,9 @@ class LocalServer: if msg_type == "stdin": await self._handle_stdin(envelope, route_key, ws) elif msg_type == "resize": - if not session_created and await self._handle_resize(envelope, route_key, ws): - session_created = True - elif session_created: + if not session_created: + session_created = await self._handle_resize(envelope, route_key, ws) + else: await self._handle_resize(envelope, route_key, ws) elif msg_type == "ping": await self._handle_ping(envelope, route_key, ws) diff --git a/src/textual_webterm/poller.py b/src/textual_webterm/poller.py index 8462c95..80e4529 100644 --- a/src/textual_webterm/poller.py +++ b/src/textual_webterm/poller.py @@ -62,7 +62,17 @@ class Poller(Thread): self._write_queues[file_descriptor] = deque() new_write = Write(data) self._write_queues[file_descriptor].append(new_write) - self._selector.modify(file_descriptor, selectors.EVENT_READ | selectors.EVENT_WRITE) + try: + + self._selector.modify(file_descriptor, selectors.EVENT_READ | selectors.EVENT_WRITE) + + except KeyError: + + # File descriptor removed concurrently + + new_write.done_event.set() + + return await new_write.done_event.wait() def set_loop(self, loop: asyncio.AbstractEventLoop) -> None: diff --git a/src/textual_webterm/terminal_session.py b/src/textual_webterm/terminal_session.py index 739b7df..bd694a6 100644 --- a/src/textual_webterm/terminal_session.py +++ b/src/textual_webterm/terminal_session.py @@ -253,9 +253,13 @@ class TerminalSession(Session): os.close(fd) async def send_bytes(self, data: bytes) -> bool: - if self.master_fd is None: + fd = self.master_fd + if fd is None: + return False + try: + await self.poller.write(fd, data) + except (KeyError, OSError): return False - await self.poller.write(self.master_fd, data) return True async def send_meta(self, data: Meta) -> bool: diff --git a/tests/test_local_server_unit.py b/tests/test_local_server_unit.py index c2ff8e6..53d8d0a 100644 --- a/tests/test_local_server_unit.py +++ b/tests/test_local_server_unit.py @@ -679,6 +679,19 @@ class TestLocalServerMoreCoverage: session.send_bytes.assert_awaited_once_with(b"") @pytest.mark.asyncio + @pytest.mark.asyncio + async def test_dispatch_ws_message_resize_existing_session_flag_false(self, server_with_no_apps, monkeypatch): + session = MagicMock() + session.set_terminal_size = AsyncMock() + monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session) + + ws = MagicMock() + created = await server_with_no_apps._dispatch_ws_message( + ["resize", {"width": 100, "height": 50}], "rk", ws, False + ) + assert created is False + session.set_terminal_size.assert_awaited_once_with(100, 50) + async def test_dispatch_ws_message_resize_updates_existing_session(self, server_with_no_apps, monkeypatch): session = MagicMock() session.set_terminal_size = AsyncMock() diff --git a/tests/test_poller.py b/tests/test_poller.py index 65a29bb..1c5fd67 100644 --- a/tests/test_poller.py +++ b/tests/test_poller.py @@ -71,6 +71,17 @@ class TestPoller: # Should not raise poller.remove_file(999) + @pytest.mark.asyncio + async def test_write_handles_removed_fd(self): + poller = Poller() + poller._loop = asyncio.get_event_loop() + + with patch.object(poller._selector, "register"): + poller.add_file(42) + + with patch.object(poller._selector, "modify", side_effect=KeyError()): + await poller.write(42, b"test") + @pytest.mark.asyncio async def test_write_creates_queue(self): """Test that write creates a write queue if needed.""" diff --git a/tests/test_terminal_session.py b/tests/test_terminal_session.py index 938f01d..f827b1e 100644 --- a/tests/test_terminal_session.py +++ b/tests/test_terminal_session.py @@ -288,6 +288,18 @@ class TestTerminalSession: mock_exit.assert_called_once_with(1) + @pytest.mark.asyncio + async def test_send_bytes_handles_closed_fd(self): + from textual_webterm.terminal_session import TerminalSession + + poller = MagicMock() + poller.write = AsyncMock(side_effect=KeyError) + session = TerminalSession(poller, "sid", "bash") + session.master_fd = 10 + + ok = await session.send_bytes(b"test") + assert ok is False + @pytest.mark.asyncio async def test_run_reads_from_poller_and_closes(self): from textual_webterm.terminal_session import TerminalSession