From e236cd87bdac6adf291260507c41439bcab70f8f Mon Sep 17 00:00:00 2001 From: banteg <4562643+banteg@users.noreply.github.com> Date: Mon, 29 Dec 2025 12:05:56 +0400 Subject: [PATCH] fix(exec-bridge): clean up subprocess on cancellation --- .../src/codex_telegram_bridge/exec_bridge.py | 73 ++++++++++++++----- .../tests/test_exec_bridge.py | 54 ++++++++++++++ 2 files changed, 107 insertions(+), 20 deletions(-) diff --git a/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py b/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py index 483fd88..b7ca4ac 100644 --- a/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py +++ b/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py @@ -24,7 +24,7 @@ from .exec_render import ExecProgressRenderer, render_event_cli from .rendering import render_markdown from .telegram_client import TelegramClient -logger = logging.getLogger("exec_bridge") +logger = logging.getLogger(__name__) UUID_PATTERN = re.compile( r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b" ) @@ -54,16 +54,18 @@ async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) - def setup_logging(log_file: str | None) -> None: - logger.setLevel(logging.DEBUG) - logger.handlers.clear() - logger.propagate = False + root_logger = logging.getLogger() + root_logger.setLevel(logging.DEBUG) + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + handler.close() - fmt = logging.Formatter("%(asctime)s %(message)s") + fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s") console = logging.StreamHandler(sys.stdout) console.setLevel(logging.INFO) console.setFormatter(fmt) - logger.addHandler(console) + root_logger.addHandler(console) if log_file: file_handler = RotatingFileHandler( @@ -74,7 +76,7 @@ def setup_logging(log_file: str | None) -> None: ) file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(fmt) - logger.addHandler(file_handler) + root_logger.addHandler(file_handler) logger.debug("[debug] file logger initialized path=%r", log_file) @@ -127,6 +129,16 @@ def truncate_for_telegram(text: str, limit: int) -> str: return (head + sep + tail)[:limit] +def render_for_telegram( + md: str, *, limit: int +) -> tuple[str, list[dict[str, Any]] | None]: + rendered, entities = render_markdown(md) + if len(rendered) > limit: + rendered = truncate_for_telegram(rendered, limit) + return rendered, None + return rendered, entities or None + + async def _send_markdown( bot: TelegramClient, *, @@ -135,15 +147,12 @@ async def _send_markdown( reply_to_message_id: int | None = None, disable_notification: bool = False, ) -> dict[str, Any]: - rendered, entities = render_markdown(text) - if len(rendered) > TELEGRAM_MARKDOWN_LIMIT: - rendered = truncate_for_telegram(rendered, TELEGRAM_MARKDOWN_LIMIT) - entities = [] + rendered, entities = render_for_telegram(text, limit=TELEGRAM_MARKDOWN_LIMIT) return await bot.send_message( chat_id=chat_id, text=rendered, - entities=entities or None, + entities=entities, reply_to_message_id=reply_to_message_id, disable_notification=disable_notification, ) @@ -207,10 +216,6 @@ class CodexExecRunner: assert proc.stdin and proc.stdout and proc.stderr logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args) - proc.stdin.write(prompt.encode()) - await proc.stdin.drain() - proc.stdin.close() - stderr_tail: deque[str] = deque(maxlen=200) stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail)) @@ -219,7 +224,14 @@ class CodexExecRunner: saw_agent_message = False cli_last_turn: int | None = None + cancelled = False + rc: int | None = None + try: + proc.stdin.write(prompt.encode()) + await proc.stdin.drain() + proc.stdin.close() + async for raw_line in proc.stdout: line = raw_line.decode(errors="replace").strip() if not line: @@ -252,11 +264,32 @@ class CodexExecRunner: last_agent_text = item["text"] saw_agent_message = True except asyncio.CancelledError: - proc.terminate() - raise + cancelled = True + if proc.returncode is None: + proc.terminate() finally: - rc = await proc.wait() - await stderr_task + if cancelled: + task = asyncio.current_task() + if task is not None: + while task.cancelling(): + task.uncancel() + + try: + rc = await asyncio.wait_for(proc.wait(), timeout=2.0) + except asyncio.TimeoutError: + logger.debug( + "[codex] terminate timed out pid=%s, sending kill", proc.pid + ) + if proc.returncode is None: + proc.kill() + rc = await proc.wait() + else: + rc = await proc.wait() + + await asyncio.gather(stderr_task, return_exceptions=True) + + if cancelled: + raise asyncio.CancelledError logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc) if rc != 0: diff --git a/codex_telegram_bridge/tests/test_exec_bridge.py b/codex_telegram_bridge/tests/test_exec_bridge.py index 9687ef7..4632012 100644 --- a/codex_telegram_bridge/tests/test_exec_bridge.py +++ b/codex_telegram_bridge/tests/test_exec_bridge.py @@ -1,4 +1,7 @@ import asyncio +import os + +import pytest from codex_telegram_bridge.exec_bridge import extract_session_id, truncate_for_telegram @@ -148,3 +151,54 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None: assert len(bot.send_calls) == 2 assert bot.send_calls[0]["disable_notification"] is True assert bot.send_calls[1]["disable_notification"] is False + + +def test_codex_runner_cancellation_terminates_subprocess(tmp_path, monkeypatch) -> None: + from codex_telegram_bridge.exec_bridge import CodexExecRunner + + pid_file = tmp_path / "pid" + codex_path = tmp_path / "codex" + codex_path.write_text( + "#!/usr/bin/env python3\n" + "import os\n" + "import time\n" + "\n" + "pid_file = os.environ.get('CODEX_FAKE_PID_FILE')\n" + "if pid_file:\n" + " with open(pid_file, 'w', encoding='utf-8') as f:\n" + " f.write(str(os.getpid()))\n" + " f.flush()\n" + "\n" + "time.sleep(60)\n", + encoding="utf-8", + ) + codex_path.chmod(0o755) + monkeypatch.setenv("CODEX_FAKE_PID_FILE", str(pid_file)) + + runner = CodexExecRunner(codex_cmd=str(codex_path), workspace=None, extra_args=[]) + + async def run_and_cancel() -> None: + task = asyncio.create_task(runner.run("hello", session_id=None)) + + for _ in range(100): + if pid_file.exists(): + break + await asyncio.sleep(0.01) + assert pid_file.exists() + + pid = int(pid_file.read_text(encoding="utf-8").strip()) + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + for _ in range(200): + try: + os.kill(pid, 0) + except ProcessLookupError: + return + await asyncio.sleep(0.01) + + raise AssertionError("cancelled codex subprocess is still running") + + asyncio.run(run_and_cancel())