refactor(exec_bridge): manage subprocess context

This commit is contained in:
banteg
2025-12-29 14:10:48 +04:00
parent b8ecf044a1
commit fc9f33c24c
@@ -11,6 +11,7 @@ import shutil
import time import time
from collections import deque from collections import deque
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
@@ -56,6 +57,21 @@ async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -
logger.debug("[codex][stderr] drain error: %s", e) logger.debug("[codex][stderr] drain error: %s", e)
@asynccontextmanager
async def manage_subprocess(*args, **kwargs):
proc = await asyncio.create_subprocess_exec(*args, **kwargs)
try:
yield proc
finally:
if proc.returncode is None:
proc.terminate()
try:
await asyncio.wait_for(proc.wait(), timeout=2.0)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
TELEGRAM_TEXT_LIMIT = TELEGRAM_HARD_LIMIT TELEGRAM_TEXT_LIMIT = TELEGRAM_HARD_LIMIT
TELEGRAM_MARKDOWN_LIMIT = 3500 TELEGRAM_MARKDOWN_LIMIT = 3500
@@ -187,12 +203,12 @@ class CodexExecRunner:
else: else:
args.append("-") args.append("-")
proc = await asyncio.create_subprocess_exec( async with manage_subprocess(
*args, *args,
stdin=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
) ) as proc:
assert proc.stdin and proc.stdout and proc.stderr assert proc.stdin and proc.stdout and proc.stderr
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args) logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
@@ -248,27 +264,14 @@ class CodexExecRunner:
saw_agent_message = True saw_agent_message = True
except asyncio.CancelledError: except asyncio.CancelledError:
cancelled = True cancelled = True
if proc.returncode is None:
proc.terminate()
finally: finally:
if cancelled: if cancelled:
task = asyncio.current_task() task = asyncio.current_task()
if task is not None: if task is not None:
while task.cancelling(): while task.cancelling():
task.uncancel() task.uncancel()
if not cancelled:
try:
rc = await asyncio.wait_for(proc.wait(), timeout=2.0)
except TimeoutError:
logger.debug(
"[codex] terminate timed out pid=%s, sending kill", proc.pid
)
if proc.returncode is None:
proc.kill()
rc = await proc.wait() rc = await proc.wait()
else:
rc = await proc.wait()
await asyncio.gather(stderr_task, return_exceptions=True) await asyncio.gather(stderr_task, return_exceptions=True)
if cancelled: if cancelled: