refactor: migrate exec bridge to anyio and harden cancellation (#6)

This commit is contained in:
banteg
2025-12-31 01:51:46 +04:00
committed by GitHub
parent 6687a435c9
commit 8eda3f5e84
9 changed files with 492 additions and 310 deletions
+212 -141
View File
@@ -1,21 +1,24 @@
from __future__ import annotations
import asyncio
import inspect
import json
import logging
import os
import re
import shutil
import subprocess
import time
from collections import deque
from collections.abc import Awaitable, Callable
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, cast
from typing import Any
from weakref import WeakValueDictionary
import anyio
import typer
from anyio.abc import ByteReceiveStream, Process
from anyio.streams.text import TextReceiveStream
from . import __version__
from .config import ConfigError, load_telegram_config
@@ -61,32 +64,59 @@ def resolve_resume_session(text: str | None, reply_text: str | None) -> str | No
return extract_session_id(text) or extract_session_id(reply_text)
async def _drain_stderr(stderr: asyncio.StreamReader, tail: deque[str]) -> None:
try:
async def _iter_text_lines(stream: ByteReceiveStream) -> AsyncIterator[str]:
text_stream = TextReceiveStream(stream, errors="replace")
buffer = ""
while True:
try:
chunk = await text_stream.receive()
except anyio.EndOfStream:
if buffer:
yield buffer
return
buffer += chunk
while True:
line = await stderr.readline()
if not line:
return
decoded = line.decode(errors="replace")
logger.info("[codex][stderr] %s", decoded.rstrip())
tail.append(decoded)
split_at = buffer.find("\n")
if split_at < 0:
break
line = buffer[: split_at + 1]
buffer = buffer[split_at + 1 :]
yield line
async def _drain_stderr(stderr: ByteReceiveStream, tail: deque[str]) -> None:
    """Consume ``stderr`` line by line, logging each line and recording it.

    Every line produced by ``_iter_text_lines`` is logged at INFO level with
    trailing whitespace removed, then appended unmodified to ``tail``.

    Draining is best-effort: any ``Exception`` raised while reading is logged
    at DEBUG level and swallowed rather than propagated to the caller.
    """
    try:
        async for text in _iter_text_lines(stderr):
            stripped = text.rstrip()
            logger.info("[codex][stderr] %s", stripped)
            tail.append(text)
    except Exception as exc:
        logger.debug("[codex][stderr] drain error: %s", exc)
async def _wait_for_process(proc: Process, timeout: float) -> bool:
    """Wait up to ``timeout`` seconds for ``proc`` to exit.

    Args:
        proc: The process whose ``wait()`` is awaited.
        timeout: Maximum number of seconds to wait.

    Returns:
        True if the wait was interrupted by the timeout, False if
        ``proc.wait()`` completed first.
    """
    with anyio.move_on_after(timeout) as scope:
        await proc.wait()
    # ``cancelled_caught`` is set only when the scope actually interrupted the
    # wait. ``cancel_called`` merely records that cancellation was requested
    # and can be True even though the process exited, which would misreport a
    # clean exit as a timeout.
    return scope.cancelled_caught
@asynccontextmanager
async def manage_subprocess(*args, **kwargs):
proc = await asyncio.create_subprocess_exec(*args, **kwargs)
async def manage_subprocess(*args, terminate_timeout: float = 2.0, **kwargs):
proc = await anyio.open_process(args, **kwargs)
try:
yield proc
finally:
if proc.returncode is None:
proc.terminate()
try:
await asyncio.wait_for(proc.wait(), timeout=2.0)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
with anyio.CancelScope(shield=True):
try:
proc.terminate()
except ProcessLookupError:
pass
timed_out = await _wait_for_process(proc, terminate_timeout)
if timed_out:
logger.debug(
"[codex] terminate timed out pid=%s; leaving process to exit",
proc.pid,
)
TELEGRAM_MARKDOWN_LIMIT = 3500
@@ -212,17 +242,17 @@ class ProgressEdits:
self.last_rendered = last_rendered
self._event_seq = 0
self._published_seq = 0
self.wakeup = asyncio.Event()
self.task: asyncio.Task[None] | None = (
asyncio.create_task(self.run()) if self.progress_id is not None else None
)
self.wakeup = anyio.Event()
async def _wait_for_wakeup(self) -> None:
    """Block until the current wakeup event is set, then install a fresh one."""
    await self.wakeup.wait()
    # Swap in a new, unset event so the next call blocks until a later
    # ``set()`` (``on_event`` sets ``self.wakeup`` when a new event arrives).
    self.wakeup = anyio.Event()
async def run(self) -> None:
if self.progress_id is None:
return
while True:
await self.wakeup.wait()
self.wakeup.clear()
await self._wait_for_wakeup()
while self._published_seq < self._event_seq:
await self.sleep(
max(
@@ -250,7 +280,6 @@ class ProgressEdits:
self.last_rendered = rendered
self._published_seq = seq_at_render
self.wakeup.clear()
async def on_event(self, evt: dict[str, Any]) -> None:
if not self.renderer.note_event(evt):
@@ -260,12 +289,6 @@ class ProgressEdits:
self._event_seq += 1
self.wakeup.set()
async def shutdown(self) -> None:
if self.task is None:
return
self.task.cancel()
await asyncio.gather(self.task, return_exceptions=True)
class CodexExecRunner:
def __init__(
@@ -277,14 +300,14 @@ class CodexExecRunner:
self.extra_args = extra_args
# Per-session locks to prevent concurrent resumes to the same session_id.
self._session_locks: WeakValueDictionary[str, asyncio.Lock] = (
self._session_locks: WeakValueDictionary[str, anyio.Lock] = (
WeakValueDictionary()
)
async def _lock_for(self, session_id: str) -> asyncio.Lock:
async def _lock_for(self, session_id: str) -> anyio.Lock:
lock = self._session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
lock = anyio.Lock()
self._session_locks[session_id] = lock
return lock
@@ -306,19 +329,23 @@ class CodexExecRunner:
else:
args.append("-")
cancelled_exc_type = anyio.get_cancelled_exc_class()
cancelled_exc: BaseException | None = None
async with manage_subprocess(
*args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
proc_stdin = cast(asyncio.StreamWriter, proc.stdin)
proc_stdout = cast(asyncio.StreamReader, proc.stdout)
proc_stderr = cast(asyncio.StreamReader, proc.stderr)
if proc.stdin is None or proc.stdout is None or proc.stderr is None:
raise RuntimeError("codex exec failed to open subprocess pipes")
proc_stdin = proc.stdin
proc_stdout = proc.stdout
proc_stderr = proc.stderr
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
stderr_tail: deque[str] = deque(maxlen=200)
stderr_task = asyncio.create_task(_drain_stderr(proc_stderr, stderr_tail))
rc: int | None = None
found_session: str | None = session_id
last_agent_text: str | None = None
@@ -326,62 +353,57 @@ class CodexExecRunner:
cli_last_item: int | None = None
cancelled = False
rc: int | None = None
async with anyio.create_task_group() as tg:
tg.start_soon(_drain_stderr, proc_stderr, stderr_tail)
try:
proc_stdin.write(prompt.encode())
await proc_stdin.drain()
proc_stdin.close()
try:
await proc_stdin.send(prompt.encode())
await proc_stdin.aclose()
async for raw_line in proc_stdout:
raw = raw_line.decode(errors="replace")
logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
line = raw.strip()
if not line:
continue
try:
evt = json.loads(line)
except json.JSONDecodeError:
logger.debug("[codex][jsonl] invalid line: %r", line)
continue
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
for out in out_lines:
logger.info("[codex] %s", out)
if on_event is not None:
async for raw_line in _iter_text_lines(proc_stdout):
raw = raw_line.rstrip("\n")
logger.debug("[codex][jsonl] %s", raw)
line = raw.strip()
if not line:
continue
try:
res = on_event(evt)
if inspect.isawaitable(res):
await res
except Exception as e:
logger.info("[codex][on_event] callback error: %s", e)
evt = json.loads(line)
except json.JSONDecodeError:
logger.debug("[codex][jsonl] invalid line: %r", line)
continue
if evt["type"] == "thread.started":
found_session = evt.get("thread_id") or found_session
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
for out in out_lines:
logger.info("[codex] %s", out)
if evt["type"] == "item.completed":
item = evt.get("item") or {}
if item.get("type") == "agent_message" and isinstance(
item.get("text"), str
):
last_agent_text = item["text"]
saw_agent_message = True
except asyncio.CancelledError:
cancelled = True
finally:
if cancelled:
if not stderr_task.done():
stderr_task.cancel()
task = cast(asyncio.Task, asyncio.current_task())
while task.cancelling():
task.uncancel()
if not cancelled:
rc = await proc.wait()
await asyncio.gather(stderr_task, return_exceptions=True)
if on_event is not None:
try:
res = on_event(evt)
if inspect.isawaitable(res):
await res
except Exception as e:
logger.info("[codex][on_event] callback error: %s", e)
if evt["type"] == "thread.started":
found_session = evt.get("thread_id") or found_session
if evt["type"] == "item.completed":
item = evt.get("item") or {}
if item.get("type") == "agent_message" and isinstance(
item.get("text"), str
):
last_agent_text = item["text"]
saw_agent_message = True
except cancelled_exc_type as exc:
cancelled = True
cancelled_exc = exc
tg.cancel_scope.cancel()
finally:
if not cancelled:
rc = await proc.wait()
if cancelled:
raise asyncio.CancelledError
raise cancelled_exc # type: ignore[misc]
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
if rc != 0:
@@ -406,11 +428,31 @@ class CodexExecRunner:
session_id: str | None,
on_event: EventCallback | None = None,
) -> tuple[str, str, bool]:
if not session_id:
return await self.run(prompt, session_id=None, on_event=on_event)
lock = await self._lock_for(session_id)
async with lock:
return await self.run(prompt, session_id=session_id, on_event=on_event)
if session_id:
lock = await self._lock_for(session_id)
async with lock:
return await self.run(prompt, session_id=session_id, on_event=on_event)
session_lock: anyio.Lock | None = None
async def on_event_with_lock(evt: dict[str, Any]) -> None:
nonlocal session_lock
if session_lock is None and evt.get("type") == "thread.started":
thread_id = evt.get("thread_id")
if isinstance(thread_id, str) and thread_id:
session_lock = await self._lock_for(thread_id)
await session_lock.acquire()
if on_event is None:
return
res = on_event(evt)
if inspect.isawaitable(res):
await res
try:
return await self.run(prompt, session_id=None, on_event=on_event_with_lock)
finally:
if session_lock is not None:
session_lock.release()
@dataclass(frozen=True)
@@ -423,6 +465,12 @@ class BridgeConfig:
max_concurrency: int
@dataclass
class RunningTask:
    """Bookkeeping entry for one in-flight run, stored in ``running_tasks``."""

    # Cancel scope for the run; ``_handle_cancel`` calls ``scope.cancel()``
    # on it to stop the run.
    scope: anyio.CancelScope
    # Session id of the run, or None until one is known (``handle_message``
    # fills it from ``resume_session`` or a ``thread.started`` event).
    session_id: str | None = None
def _parse_bridge_config(
*,
final_notify: bool,
@@ -508,9 +556,9 @@ async def handle_message(
user_msg_id: int,
text: str,
resume_session: str | None,
running_tasks: dict[str, asyncio.Task[Any]] | None = None,
running_tasks: dict[int, RunningTask] | None = None,
clock: Callable[[], float] = time.monotonic,
sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
sleep: Callable[[float], Awaitable[None]] = anyio.sleep,
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
) -> None:
logger.debug(
@@ -565,31 +613,54 @@ async def handle_message(
last_rendered=last_rendered,
)
exec_task: asyncio.Task[tuple[str, str, bool]] | None = None
exec_scope = anyio.CancelScope()
cancelled = False
error: Exception | None = None
session_id: str | None = None
answer: str | None = None
saw_agent_message: bool | None = None
running_task: RunningTask | None = None
if running_tasks is not None and progress_id is not None:
running_task = RunningTask(scope=exec_scope)
running_tasks[progress_id] = running_task
if resume_session is not None:
running_task.session_id = resume_session
async def on_event(evt: dict[str, Any]) -> None:
if (
evt["type"] == "thread.started"
and running_tasks is not None
and exec_task is not None
running_task is not None
and running_task.session_id is None
and evt.get("type") == "thread.started"
):
running_tasks[evt["thread_id"]] = exec_task
thread_id = evt.get("thread_id")
if isinstance(thread_id, str) and thread_id:
running_task.session_id = thread_id
await edits.on_event(evt)
exec_task = asyncio.create_task(
cfg.runner.run_serialized(text, resume_session, on_event=on_event)
)
async with anyio.create_task_group() as tg:
if progress_id is not None:
tg.start_soon(edits.run)
cancelled = False
try:
session_id, answer, saw_agent_message = await exec_task
except asyncio.CancelledError:
cancelled = True
session_id = progress_renderer.resume_session or resume_session
except Exception as e:
await edits.shutdown()
try:
with exec_scope:
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
text, resume_session, on_event=on_event
)
except Exception as e:
error = e
finally:
if running_task is not None:
if running_tasks is not None and progress_id is not None:
running_tasks.pop(progress_id, None)
if exec_scope.cancelled_caught and not cancelled and error is None:
cancelled = True
session_id = progress_renderer.resume_session or resume_session
if not cancelled and error is None:
await anyio.sleep(0)
tg.cancel_scope.cancel()
err = _clamp_tg_text(f"Error:\n{e}")
if error is not None:
err = _clamp_tg_text(f"Error:\n{error}")
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
await _send_or_edit_markdown(
cfg.bot,
@@ -601,16 +672,11 @@ async def handle_message(
limit=TELEGRAM_MARKDOWN_LIMIT,
)
return
finally:
if running_tasks is not None:
for sid, task in list(running_tasks.items()):
if task is exec_task:
running_tasks.pop(sid, None)
await edits.shutdown()
elapsed = clock() - started_at
if cancelled:
if session_id is None:
session_id = progress_renderer.resume_session or resume_session
logger.info(
"[handle] cancelled session_id=%s elapsed=%.1fs", session_id, elapsed
)
@@ -627,6 +693,9 @@ async def handle_message(
)
return
if session_id is None or answer is None or saw_agent_message is None:
raise RuntimeError("codex exec finished without a result")
status = "done" if saw_agent_message else "error"
progress_renderer.resume_session = session_id
final_md = progress_renderer.render_final(elapsed, answer, status=status)
@@ -679,7 +748,7 @@ async def poll_updates(cfg: BridgeConfig):
)
if updates is None:
logger.info("[loop] getUpdates failed")
await asyncio.sleep(2)
await anyio.sleep(2)
continue
logger.debug("[loop] updates: %s", updates)
@@ -696,7 +765,7 @@ async def poll_updates(cfg: BridgeConfig):
async def _handle_cancel(
cfg: BridgeConfig,
msg: dict[str, Any],
running_tasks: dict[str, asyncio.Task[Any]],
running_tasks: dict[int, RunningTask],
) -> None:
chat_id = msg["chat"]["id"]
user_msg_id = msg["message_id"]
@@ -710,8 +779,8 @@ async def _handle_cancel(
)
return
session_id = extract_session_id(reply.get("text"))
if not session_id:
progress_id = reply.get("message_id")
if progress_id is None:
await cfg.bot.send_message(
chat_id=chat_id,
text="nothing is currently running for that message.",
@@ -719,8 +788,8 @@ async def _handle_cancel(
)
return
task = running_tasks.get(session_id)
if not task or task.done():
running_task = running_tasks.get(int(progress_id))
if running_task is None:
await cfg.bot.send_message(
chat_id=chat_id,
text="nothing is currently running for that message.",
@@ -728,20 +797,20 @@ async def _handle_cancel(
)
return
logger.info("[cancel] cancelling session_id=%s", session_id)
task.cancel()
logger.info("[cancel] cancelling progress_message_id=%s", progress_id)
running_task.scope.cancel()
async def _run_main_loop(cfg: BridgeConfig) -> None:
worker_count = max(1, min(cfg.max_concurrency, 16))
queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue(
maxsize=worker_count * 2
send_stream, receive_stream = anyio.create_memory_object_stream(
max_buffer_size=worker_count * 2
)
running_tasks: dict[str, asyncio.Task[Any]] = {}
running_tasks: dict[int, RunningTask] = {}
async def worker() -> None:
while True:
chat_id, user_msg_id, text, resume_session = await queue.get()
chat_id, user_msg_id, text, resume_session = await receive_stream.receive()
try:
await handle_message(
cfg,
@@ -753,26 +822,28 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
)
except Exception:
logger.exception("[handle] worker failed")
finally:
queue.task_done()
try:
async with asyncio.TaskGroup() as tg:
async with anyio.create_task_group() as tg:
for _ in range(worker_count):
tg.create_task(worker())
tg.start_soon(worker)
async for msg in poll_updates(cfg):
text = msg["text"]
user_msg_id = msg["message_id"]
if text == "/cancel":
tg.create_task(_handle_cancel(cfg, msg, running_tasks))
tg.start_soon(_handle_cancel, cfg, msg, running_tasks)
continue
r = msg.get("reply_to_message") or {}
resume_session = resolve_resume_session(text, r.get("text"))
await queue.put((msg["chat"]["id"], user_msg_id, text, resume_session))
await send_stream.send(
(msg["chat"]["id"], user_msg_id, text, resume_session)
)
finally:
await send_stream.aclose()
await receive_stream.aclose()
await cfg.bot.close()
@@ -813,7 +884,7 @@ def run(
except ConfigError as e:
typer.echo(str(e), err=True)
raise typer.Exit(code=1)
asyncio.run(_run_main_loop(cfg))
anyio.run(_run_main_loop, cfg)
def main() -> None: