diff --git a/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py b/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py index 4bf0fdd..550dead 100644 --- a/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py +++ b/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py @@ -11,6 +11,7 @@ import shutil import time from collections import deque from collections.abc import Awaitable, Callable +from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any from weakref import WeakValueDictionary @@ -56,6 +57,21 @@ async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) - logger.debug("[codex][stderr] drain error: %s", e) +@asynccontextmanager +async def manage_subprocess(*args, **kwargs): + proc = await asyncio.create_subprocess_exec(*args, **kwargs) + try: + yield proc + finally: + if proc.returncode is None: + proc.terminate() + try: + await asyncio.wait_for(proc.wait(), timeout=2.0) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + + TELEGRAM_TEXT_LIMIT = TELEGRAM_HARD_LIMIT TELEGRAM_MARKDOWN_LIMIT = 3500 @@ -187,110 +203,97 @@ class CodexExecRunner: else: args.append("-") - proc = await asyncio.create_subprocess_exec( + async with manage_subprocess( *args, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - ) - assert proc.stdin and proc.stdout and proc.stderr - logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args) + ) as proc: + assert proc.stdin and proc.stdout and proc.stderr + logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args) - stderr_tail: deque[str] = deque(maxlen=200) - stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail)) + stderr_tail: deque[str] = deque(maxlen=200) + stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail)) - found_session: str | None = session_id - last_agent_text: str | None = None - saw_agent_message = False - cli_last_item: int | None = None + found_session: str | None = session_id + last_agent_text: str | None = None + saw_agent_message = False + cli_last_item: int | None = None - cancelled = False - rc: int | None = None + cancelled = False + rc: int | None = None - try: - proc.stdin.write(prompt.encode()) - await proc.stdin.drain() - proc.stdin.close() + try: + proc.stdin.write(prompt.encode()) + await proc.stdin.drain() + proc.stdin.close() - async for raw_line in proc.stdout: - raw = raw_line.decode(errors="replace") - logger.debug("[codex][jsonl] %s", raw.rstrip("\n")) - line = raw.strip() - if not line: - continue - try: - evt = json.loads(line) - except json.JSONDecodeError: - logger.debug("[codex][jsonl] invalid line: %r", line) - continue - - cli_last_item, out_lines = render_event_cli(evt, cli_last_item) - for out in out_lines: - logger.info("[codex] %s", out) - - if on_event is not None: + async for raw_line in proc.stdout: + raw = raw_line.decode(errors="replace") + logger.debug("[codex][jsonl] %s", raw.rstrip("\n")) + line = raw.strip() + if not line: + continue try: - res = on_event(evt) - if inspect.isawaitable(res): - await res - except Exception as e: - logger.info("[codex][on_event] callback error: %s", e) + evt = json.loads(line) + except json.JSONDecodeError: + logger.debug("[codex][jsonl] invalid line: %r", line) + continue - if evt.get("type") == "thread.started": - found_session = evt.get("thread_id") or found_session + cli_last_item, out_lines = render_event_cli(evt, cli_last_item) + for out in out_lines: + logger.info("[codex] %s", out) - if evt.get("type") == "item.completed": - item = evt.get("item") or {} - if item.get("type") == "agent_message" and isinstance( - item.get("text"), str - ): - last_agent_text = item["text"] - saw_agent_message = True - except asyncio.CancelledError: - cancelled = True - if proc.returncode is None: - proc.terminate() - finally: - if cancelled: - task = asyncio.current_task() - if task is not None: - while task.cancelling(): - task.uncancel() + if on_event is not None: + try: + res = on_event(evt) + if inspect.isawaitable(res): + await res + except Exception as e: + logger.info("[codex][on_event] callback error: %s", e) - try: - rc = await asyncio.wait_for(proc.wait(), timeout=2.0) - except TimeoutError: - logger.debug( - "[codex] terminate timed out pid=%s, sending kill", proc.pid - ) - if proc.returncode is None: - proc.kill() + if evt.get("type") == "thread.started": + found_session = evt.get("thread_id") or found_session + + if evt.get("type") == "item.completed": + item = evt.get("item") or {} + if item.get("type") == "agent_message" and isinstance( + item.get("text"), str + ): + last_agent_text = item["text"] + saw_agent_message = True + except asyncio.CancelledError: + cancelled = True + finally: + if cancelled: + task = asyncio.current_task() + if task is not None: + while task.cancelling(): + task.uncancel() + if not cancelled: rc = await proc.wait() - else: - rc = await proc.wait() + await asyncio.gather(stderr_task, return_exceptions=True) - await asyncio.gather(stderr_task, return_exceptions=True) + if cancelled: + raise asyncio.CancelledError - if cancelled: - raise asyncio.CancelledError + logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc) + if rc != 0: + tail = "".join(stderr_tail) + raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}") - logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc) - if rc != 0: - tail = "".join(stderr_tail) - raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}") + if not found_session: + raise RuntimeError( + "codex exec finished but no session_id/thread_id was captured" + ) - if not found_session: - raise RuntimeError( - "codex exec finished but no session_id/thread_id was captured" + logger.info("[codex] done run session_id=%r", found_session) + return ( + found_session, + (last_agent_text or "(No agent_message captured from JSON stream.)"), + saw_agent_message, ) - logger.info("[codex] done run session_id=%r", found_session) - return ( - found_session, - (last_agent_text or "(No agent_message captured from JSON stream.)"), - saw_agent_message, - ) - async def run_serialized( self, prompt: str,