refactor(exec_bridge): manage subprocess context

This commit is contained in:
banteg
2025-12-29 14:10:48 +04:00
parent b8ecf044a1
commit fc9f33c24c
@@ -11,6 +11,7 @@ import shutil
import time import time
from collections import deque from collections import deque
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
@@ -56,6 +57,21 @@ async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -
logger.debug("[codex][stderr] drain error: %s", e) logger.debug("[codex][stderr] drain error: %s", e)
@asynccontextmanager
async def manage_subprocess(*args, **kwargs):
proc = await asyncio.create_subprocess_exec(*args, **kwargs)
try:
yield proc
finally:
if proc.returncode is None:
proc.terminate()
try:
await asyncio.wait_for(proc.wait(), timeout=2.0)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
TELEGRAM_TEXT_LIMIT = TELEGRAM_HARD_LIMIT TELEGRAM_TEXT_LIMIT = TELEGRAM_HARD_LIMIT
TELEGRAM_MARKDOWN_LIMIT = 3500 TELEGRAM_MARKDOWN_LIMIT = 3500
@@ -187,110 +203,97 @@ class CodexExecRunner:
else: else:
args.append("-") args.append("-")
proc = await asyncio.create_subprocess_exec( async with manage_subprocess(
*args, *args,
stdin=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
) ) as proc:
assert proc.stdin and proc.stdout and proc.stderr assert proc.stdin and proc.stdout and proc.stderr
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args) logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
stderr_tail: deque[str] = deque(maxlen=200) stderr_tail: deque[str] = deque(maxlen=200)
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail)) stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
found_session: str | None = session_id found_session: str | None = session_id
last_agent_text: str | None = None last_agent_text: str | None = None
saw_agent_message = False saw_agent_message = False
cli_last_item: int | None = None cli_last_item: int | None = None
cancelled = False cancelled = False
rc: int | None = None rc: int | None = None
try: try:
proc.stdin.write(prompt.encode()) proc.stdin.write(prompt.encode())
await proc.stdin.drain() await proc.stdin.drain()
proc.stdin.close() proc.stdin.close()
async for raw_line in proc.stdout: async for raw_line in proc.stdout:
raw = raw_line.decode(errors="replace") raw = raw_line.decode(errors="replace")
logger.debug("[codex][jsonl] %s", raw.rstrip("\n")) logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
line = raw.strip() line = raw.strip()
if not line: if not line:
continue continue
try:
evt = json.loads(line)
except json.JSONDecodeError:
logger.debug("[codex][jsonl] invalid line: %r", line)
continue
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
for out in out_lines:
logger.info("[codex] %s", out)
if on_event is not None:
try: try:
res = on_event(evt) evt = json.loads(line)
if inspect.isawaitable(res): except json.JSONDecodeError:
await res logger.debug("[codex][jsonl] invalid line: %r", line)
except Exception as e: continue
logger.info("[codex][on_event] callback error: %s", e)
if evt.get("type") == "thread.started": cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
found_session = evt.get("thread_id") or found_session for out in out_lines:
logger.info("[codex] %s", out)
if evt.get("type") == "item.completed": if on_event is not None:
item = evt.get("item") or {} try:
if item.get("type") == "agent_message" and isinstance( res = on_event(evt)
item.get("text"), str if inspect.isawaitable(res):
): await res
last_agent_text = item["text"] except Exception as e:
saw_agent_message = True logger.info("[codex][on_event] callback error: %s", e)
except asyncio.CancelledError:
cancelled = True
if proc.returncode is None:
proc.terminate()
finally:
if cancelled:
task = asyncio.current_task()
if task is not None:
while task.cancelling():
task.uncancel()
try: if evt.get("type") == "thread.started":
rc = await asyncio.wait_for(proc.wait(), timeout=2.0) found_session = evt.get("thread_id") or found_session
except TimeoutError:
logger.debug( if evt.get("type") == "item.completed":
"[codex] terminate timed out pid=%s, sending kill", proc.pid item = evt.get("item") or {}
) if item.get("type") == "agent_message" and isinstance(
if proc.returncode is None: item.get("text"), str
proc.kill() ):
last_agent_text = item["text"]
saw_agent_message = True
except asyncio.CancelledError:
cancelled = True
finally:
if cancelled:
task = asyncio.current_task()
if task is not None:
while task.cancelling():
task.uncancel()
if not cancelled:
rc = await proc.wait() rc = await proc.wait()
else: await asyncio.gather(stderr_task, return_exceptions=True)
rc = await proc.wait()
await asyncio.gather(stderr_task, return_exceptions=True) if cancelled:
raise asyncio.CancelledError
if cancelled: logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
raise asyncio.CancelledError if rc != 0:
tail = "".join(stderr_tail)
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}")
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc) if not found_session:
if rc != 0: raise RuntimeError(
tail = "".join(stderr_tail) "codex exec finished but no session_id/thread_id was captured"
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}") )
if not found_session: logger.info("[codex] done run session_id=%r", found_session)
raise RuntimeError( return (
"codex exec finished but no session_id/thread_id was captured" found_session,
(last_agent_text or "(No agent_message captured from JSON stream.)"),
saw_agent_message,
) )
logger.info("[codex] done run session_id=%r", found_session)
return (
found_session,
(last_agent_text or "(No agent_message captured from JSON stream.)"),
saw_agent_message,
)
async def run_serialized( async def run_serialized(
self, self,
prompt: str, prompt: str,