refactor(exec_bridge): manage subprocess context

This commit is contained in:
banteg
2025-12-29 14:10:48 +04:00
parent b8ecf044a1
commit fc9f33c24c
@@ -11,6 +11,7 @@ import shutil
import time
from collections import deque
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any
from weakref import WeakValueDictionary
@@ -56,6 +57,21 @@ async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -
logger.debug("[codex][stderr] drain error: %s", e)
@asynccontextmanager
async def manage_subprocess(*args, **kwargs):
    """Spawn a subprocess and make sure it is stopped when the block exits.

    All positional and keyword arguments are forwarded unchanged to
    ``asyncio.create_subprocess_exec``; the resulting process object is
    yielded to the ``async with`` body.

    On exit (normal completion, exception, or cancellation) a process that
    has not finished yet is sent ``terminate()`` and given 2 seconds to
    exit. If it is still running after that, it is sent ``kill()`` and
    awaited. A process that already exited is left untouched.
    """
    proc = await asyncio.create_subprocess_exec(*args, **kwargs)
    try:
        yield proc
    finally:
        if proc.returncode is None:
            try:
                proc.terminate()
            except ProcessLookupError:
                # The child exited between the returncode check and the
                # signal; there is nothing left to terminate.
                pass
            try:
                await asyncio.wait_for(proc.wait(), timeout=2.0)
            except asyncio.TimeoutError:
                # Grace period elapsed: escalate to a hard kill and reap.
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass
                await proc.wait()
            except asyncio.CancelledError:
                # Cancelled while waiting for a graceful exit: do not leave
                # a child that ignores terminate() running; kill it without
                # blocking further and let the cancellation propagate.
                if proc.returncode is None:
                    try:
                        proc.kill()
                    except ProcessLookupError:
                        pass
                raise
# Alias of TELEGRAM_HARD_LIMIT (defined outside this view).
TELEGRAM_TEXT_LIMIT = TELEGRAM_HARD_LIMIT
# NOTE(review): presumably a smaller budget for markdown-formatted messages;
# confirm the unit and why 3500 was chosen.
TELEGRAM_MARKDOWN_LIMIT = 3500
@@ -187,110 +203,97 @@ class CodexExecRunner:
else:
args.append("-")
proc = await asyncio.create_subprocess_exec(
async with manage_subprocess(
*args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
assert proc.stdin and proc.stdout and proc.stderr
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
) as proc:
assert proc.stdin and proc.stdout and proc.stderr
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
stderr_tail: deque[str] = deque(maxlen=200)
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
stderr_tail: deque[str] = deque(maxlen=200)
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
found_session: str | None = session_id
last_agent_text: str | None = None
saw_agent_message = False
cli_last_item: int | None = None
found_session: str | None = session_id
last_agent_text: str | None = None
saw_agent_message = False
cli_last_item: int | None = None
cancelled = False
rc: int | None = None
cancelled = False
rc: int | None = None
try:
proc.stdin.write(prompt.encode())
await proc.stdin.drain()
proc.stdin.close()
try:
proc.stdin.write(prompt.encode())
await proc.stdin.drain()
proc.stdin.close()
async for raw_line in proc.stdout:
raw = raw_line.decode(errors="replace")
logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
line = raw.strip()
if not line:
continue
try:
evt = json.loads(line)
except json.JSONDecodeError:
logger.debug("[codex][jsonl] invalid line: %r", line)
continue
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
for out in out_lines:
logger.info("[codex] %s", out)
if on_event is not None:
async for raw_line in proc.stdout:
raw = raw_line.decode(errors="replace")
logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
line = raw.strip()
if not line:
continue
try:
res = on_event(evt)
if inspect.isawaitable(res):
await res
except Exception as e:
logger.info("[codex][on_event] callback error: %s", e)
evt = json.loads(line)
except json.JSONDecodeError:
logger.debug("[codex][jsonl] invalid line: %r", line)
continue
if evt.get("type") == "thread.started":
found_session = evt.get("thread_id") or found_session
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
for out in out_lines:
logger.info("[codex] %s", out)
if evt.get("type") == "item.completed":
item = evt.get("item") or {}
if item.get("type") == "agent_message" and isinstance(
item.get("text"), str
):
last_agent_text = item["text"]
saw_agent_message = True
except asyncio.CancelledError:
cancelled = True
if proc.returncode is None:
proc.terminate()
finally:
if cancelled:
task = asyncio.current_task()
if task is not None:
while task.cancelling():
task.uncancel()
if on_event is not None:
try:
res = on_event(evt)
if inspect.isawaitable(res):
await res
except Exception as e:
logger.info("[codex][on_event] callback error: %s", e)
try:
rc = await asyncio.wait_for(proc.wait(), timeout=2.0)
except TimeoutError:
logger.debug(
"[codex] terminate timed out pid=%s, sending kill", proc.pid
)
if proc.returncode is None:
proc.kill()
if evt.get("type") == "thread.started":
found_session = evt.get("thread_id") or found_session
if evt.get("type") == "item.completed":
item = evt.get("item") or {}
if item.get("type") == "agent_message" and isinstance(
item.get("text"), str
):
last_agent_text = item["text"]
saw_agent_message = True
except asyncio.CancelledError:
cancelled = True
finally:
if cancelled:
task = asyncio.current_task()
if task is not None:
while task.cancelling():
task.uncancel()
if not cancelled:
rc = await proc.wait()
else:
rc = await proc.wait()
await asyncio.gather(stderr_task, return_exceptions=True)
await asyncio.gather(stderr_task, return_exceptions=True)
if cancelled:
raise asyncio.CancelledError
if cancelled:
raise asyncio.CancelledError
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
if rc != 0:
tail = "".join(stderr_tail)
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}")
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
if rc != 0:
tail = "".join(stderr_tail)
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}")
if not found_session:
raise RuntimeError(
"codex exec finished but no session_id/thread_id was captured"
)
if not found_session:
raise RuntimeError(
"codex exec finished but no session_id/thread_id was captured"
logger.info("[codex] done run session_id=%r", found_session)
return (
found_session,
(last_agent_text or "(No agent_message captured from JSON stream.)"),
saw_agent_message,
)
logger.info("[codex] done run session_id=%r", found_session)
return (
found_session,
(last_agent_text or "(No agent_message captured from JSON stream.)"),
saw_agent_message,
)
async def run_serialized(
self,
prompt: str,