refactor(exec_bridge): manage subprocess context
This commit is contained in:
@@ -11,6 +11,7 @@ import shutil
|
|||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
@@ -56,6 +57,21 @@ async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -
|
|||||||
logger.debug("[codex][stderr] drain error: %s", e)
|
logger.debug("[codex][stderr] drain error: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def manage_subprocess(*args, **kwargs):
|
||||||
|
proc = await asyncio.create_subprocess_exec(*args, **kwargs)
|
||||||
|
try:
|
||||||
|
yield proc
|
||||||
|
finally:
|
||||||
|
if proc.returncode is None:
|
||||||
|
proc.terminate()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(proc.wait(), timeout=2.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
proc.kill()
|
||||||
|
await proc.wait()
|
||||||
|
|
||||||
|
|
||||||
TELEGRAM_TEXT_LIMIT = TELEGRAM_HARD_LIMIT
|
TELEGRAM_TEXT_LIMIT = TELEGRAM_HARD_LIMIT
|
||||||
TELEGRAM_MARKDOWN_LIMIT = 3500
|
TELEGRAM_MARKDOWN_LIMIT = 3500
|
||||||
|
|
||||||
@@ -187,110 +203,97 @@ class CodexExecRunner:
|
|||||||
else:
|
else:
|
||||||
args.append("-")
|
args.append("-")
|
||||||
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
async with manage_subprocess(
|
||||||
*args,
|
*args,
|
||||||
stdin=asyncio.subprocess.PIPE,
|
stdin=asyncio.subprocess.PIPE,
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
) as proc:
|
||||||
assert proc.stdin and proc.stdout and proc.stderr
|
assert proc.stdin and proc.stdout and proc.stderr
|
||||||
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
||||||
|
|
||||||
stderr_tail: deque[str] = deque(maxlen=200)
|
stderr_tail: deque[str] = deque(maxlen=200)
|
||||||
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
|
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
|
||||||
|
|
||||||
found_session: str | None = session_id
|
found_session: str | None = session_id
|
||||||
last_agent_text: str | None = None
|
last_agent_text: str | None = None
|
||||||
saw_agent_message = False
|
saw_agent_message = False
|
||||||
cli_last_item: int | None = None
|
cli_last_item: int | None = None
|
||||||
|
|
||||||
cancelled = False
|
cancelled = False
|
||||||
rc: int | None = None
|
rc: int | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
proc.stdin.write(prompt.encode())
|
proc.stdin.write(prompt.encode())
|
||||||
await proc.stdin.drain()
|
await proc.stdin.drain()
|
||||||
proc.stdin.close()
|
proc.stdin.close()
|
||||||
|
|
||||||
async for raw_line in proc.stdout:
|
async for raw_line in proc.stdout:
|
||||||
raw = raw_line.decode(errors="replace")
|
raw = raw_line.decode(errors="replace")
|
||||||
logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
|
logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
|
||||||
line = raw.strip()
|
line = raw.strip()
|
||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
try:
|
|
||||||
evt = json.loads(line)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.debug("[codex][jsonl] invalid line: %r", line)
|
|
||||||
continue
|
|
||||||
|
|
||||||
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
|
|
||||||
for out in out_lines:
|
|
||||||
logger.info("[codex] %s", out)
|
|
||||||
|
|
||||||
if on_event is not None:
|
|
||||||
try:
|
try:
|
||||||
res = on_event(evt)
|
evt = json.loads(line)
|
||||||
if inspect.isawaitable(res):
|
except json.JSONDecodeError:
|
||||||
await res
|
logger.debug("[codex][jsonl] invalid line: %r", line)
|
||||||
except Exception as e:
|
continue
|
||||||
logger.info("[codex][on_event] callback error: %s", e)
|
|
||||||
|
|
||||||
if evt.get("type") == "thread.started":
|
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
|
||||||
found_session = evt.get("thread_id") or found_session
|
for out in out_lines:
|
||||||
|
logger.info("[codex] %s", out)
|
||||||
|
|
||||||
if evt.get("type") == "item.completed":
|
if on_event is not None:
|
||||||
item = evt.get("item") or {}
|
try:
|
||||||
if item.get("type") == "agent_message" and isinstance(
|
res = on_event(evt)
|
||||||
item.get("text"), str
|
if inspect.isawaitable(res):
|
||||||
):
|
await res
|
||||||
last_agent_text = item["text"]
|
except Exception as e:
|
||||||
saw_agent_message = True
|
logger.info("[codex][on_event] callback error: %s", e)
|
||||||
except asyncio.CancelledError:
|
|
||||||
cancelled = True
|
|
||||||
if proc.returncode is None:
|
|
||||||
proc.terminate()
|
|
||||||
finally:
|
|
||||||
if cancelled:
|
|
||||||
task = asyncio.current_task()
|
|
||||||
if task is not None:
|
|
||||||
while task.cancelling():
|
|
||||||
task.uncancel()
|
|
||||||
|
|
||||||
try:
|
if evt.get("type") == "thread.started":
|
||||||
rc = await asyncio.wait_for(proc.wait(), timeout=2.0)
|
found_session = evt.get("thread_id") or found_session
|
||||||
except TimeoutError:
|
|
||||||
logger.debug(
|
if evt.get("type") == "item.completed":
|
||||||
"[codex] terminate timed out pid=%s, sending kill", proc.pid
|
item = evt.get("item") or {}
|
||||||
)
|
if item.get("type") == "agent_message" and isinstance(
|
||||||
if proc.returncode is None:
|
item.get("text"), str
|
||||||
proc.kill()
|
):
|
||||||
|
last_agent_text = item["text"]
|
||||||
|
saw_agent_message = True
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
cancelled = True
|
||||||
|
finally:
|
||||||
|
if cancelled:
|
||||||
|
task = asyncio.current_task()
|
||||||
|
if task is not None:
|
||||||
|
while task.cancelling():
|
||||||
|
task.uncancel()
|
||||||
|
if not cancelled:
|
||||||
rc = await proc.wait()
|
rc = await proc.wait()
|
||||||
else:
|
await asyncio.gather(stderr_task, return_exceptions=True)
|
||||||
rc = await proc.wait()
|
|
||||||
|
|
||||||
await asyncio.gather(stderr_task, return_exceptions=True)
|
if cancelled:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
if cancelled:
|
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
||||||
raise asyncio.CancelledError
|
if rc != 0:
|
||||||
|
tail = "".join(stderr_tail)
|
||||||
|
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}")
|
||||||
|
|
||||||
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
if not found_session:
|
||||||
if rc != 0:
|
raise RuntimeError(
|
||||||
tail = "".join(stderr_tail)
|
"codex exec finished but no session_id/thread_id was captured"
|
||||||
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}")
|
)
|
||||||
|
|
||||||
if not found_session:
|
logger.info("[codex] done run session_id=%r", found_session)
|
||||||
raise RuntimeError(
|
return (
|
||||||
"codex exec finished but no session_id/thread_id was captured"
|
found_session,
|
||||||
|
(last_agent_text or "(No agent_message captured from JSON stream.)"),
|
||||||
|
saw_agent_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("[codex] done run session_id=%r", found_session)
|
|
||||||
return (
|
|
||||||
found_session,
|
|
||||||
(last_agent_text or "(No agent_message captured from JSON stream.)"),
|
|
||||||
saw_agent_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run_serialized(
|
async def run_serialized(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|||||||
Reference in New Issue
Block a user