refactor: migrate exec bridge to anyio and harden cancellation (#6)
This commit is contained in:
+212
-141
@@ -1,21 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import anyio
|
||||
import typer
|
||||
from anyio.abc import ByteReceiveStream, Process
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
|
||||
from . import __version__
|
||||
from .config import ConfigError, load_telegram_config
|
||||
@@ -61,32 +64,59 @@ def resolve_resume_session(text: str | None, reply_text: str | None) -> str | No
|
||||
return extract_session_id(text) or extract_session_id(reply_text)
|
||||
|
||||
|
||||
async def _drain_stderr(stderr: asyncio.StreamReader, tail: deque[str]) -> None:
|
||||
try:
|
||||
async def _iter_text_lines(stream: ByteReceiveStream) -> AsyncIterator[str]:
|
||||
text_stream = TextReceiveStream(stream, errors="replace")
|
||||
buffer = ""
|
||||
while True:
|
||||
try:
|
||||
chunk = await text_stream.receive()
|
||||
except anyio.EndOfStream:
|
||||
if buffer:
|
||||
yield buffer
|
||||
return
|
||||
buffer += chunk
|
||||
while True:
|
||||
line = await stderr.readline()
|
||||
if not line:
|
||||
return
|
||||
decoded = line.decode(errors="replace")
|
||||
logger.info("[codex][stderr] %s", decoded.rstrip())
|
||||
tail.append(decoded)
|
||||
split_at = buffer.find("\n")
|
||||
if split_at < 0:
|
||||
break
|
||||
line = buffer[: split_at + 1]
|
||||
buffer = buffer[split_at + 1 :]
|
||||
yield line
|
||||
|
||||
|
||||
async def _drain_stderr(stderr: ByteReceiveStream, tail: deque[str]) -> None:
|
||||
try:
|
||||
async for line in _iter_text_lines(stderr):
|
||||
logger.info("[codex][stderr] %s", line.rstrip())
|
||||
tail.append(line)
|
||||
except Exception as e:
|
||||
logger.debug("[codex][stderr] drain error: %s", e)
|
||||
|
||||
|
||||
async def _wait_for_process(proc: Process, timeout: float) -> bool:
|
||||
with anyio.move_on_after(timeout) as scope:
|
||||
await proc.wait()
|
||||
return scope.cancel_called
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def manage_subprocess(*args, **kwargs):
|
||||
proc = await asyncio.create_subprocess_exec(*args, **kwargs)
|
||||
async def manage_subprocess(*args, terminate_timeout: float = 2.0, **kwargs):
|
||||
proc = await anyio.open_process(args, **kwargs)
|
||||
try:
|
||||
yield proc
|
||||
finally:
|
||||
if proc.returncode is None:
|
||||
proc.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(proc.wait(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
with anyio.CancelScope(shield=True):
|
||||
try:
|
||||
proc.terminate()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
timed_out = await _wait_for_process(proc, terminate_timeout)
|
||||
if timed_out:
|
||||
logger.debug(
|
||||
"[codex] terminate timed out pid=%s; leaving process to exit",
|
||||
proc.pid,
|
||||
)
|
||||
|
||||
|
||||
TELEGRAM_MARKDOWN_LIMIT = 3500
|
||||
@@ -212,17 +242,17 @@ class ProgressEdits:
|
||||
self.last_rendered = last_rendered
|
||||
self._event_seq = 0
|
||||
self._published_seq = 0
|
||||
self.wakeup = asyncio.Event()
|
||||
self.task: asyncio.Task[None] | None = (
|
||||
asyncio.create_task(self.run()) if self.progress_id is not None else None
|
||||
)
|
||||
self.wakeup = anyio.Event()
|
||||
|
||||
async def _wait_for_wakeup(self) -> None:
|
||||
await self.wakeup.wait()
|
||||
self.wakeup = anyio.Event()
|
||||
|
||||
async def run(self) -> None:
|
||||
if self.progress_id is None:
|
||||
return
|
||||
while True:
|
||||
await self.wakeup.wait()
|
||||
self.wakeup.clear()
|
||||
await self._wait_for_wakeup()
|
||||
while self._published_seq < self._event_seq:
|
||||
await self.sleep(
|
||||
max(
|
||||
@@ -250,7 +280,6 @@ class ProgressEdits:
|
||||
self.last_rendered = rendered
|
||||
|
||||
self._published_seq = seq_at_render
|
||||
self.wakeup.clear()
|
||||
|
||||
async def on_event(self, evt: dict[str, Any]) -> None:
|
||||
if not self.renderer.note_event(evt):
|
||||
@@ -260,12 +289,6 @@ class ProgressEdits:
|
||||
self._event_seq += 1
|
||||
self.wakeup.set()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.task is None:
|
||||
return
|
||||
self.task.cancel()
|
||||
await asyncio.gather(self.task, return_exceptions=True)
|
||||
|
||||
|
||||
class CodexExecRunner:
|
||||
def __init__(
|
||||
@@ -277,14 +300,14 @@ class CodexExecRunner:
|
||||
self.extra_args = extra_args
|
||||
|
||||
# Per-session locks to prevent concurrent resumes to the same session_id.
|
||||
self._session_locks: WeakValueDictionary[str, asyncio.Lock] = (
|
||||
self._session_locks: WeakValueDictionary[str, anyio.Lock] = (
|
||||
WeakValueDictionary()
|
||||
)
|
||||
|
||||
async def _lock_for(self, session_id: str) -> asyncio.Lock:
|
||||
async def _lock_for(self, session_id: str) -> anyio.Lock:
|
||||
lock = self._session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
lock = anyio.Lock()
|
||||
self._session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
@@ -306,19 +329,23 @@ class CodexExecRunner:
|
||||
else:
|
||||
args.append("-")
|
||||
|
||||
cancelled_exc_type = anyio.get_cancelled_exc_class()
|
||||
cancelled_exc: BaseException | None = None
|
||||
async with manage_subprocess(
|
||||
*args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as proc:
|
||||
proc_stdin = cast(asyncio.StreamWriter, proc.stdin)
|
||||
proc_stdout = cast(asyncio.StreamReader, proc.stdout)
|
||||
proc_stderr = cast(asyncio.StreamReader, proc.stderr)
|
||||
if proc.stdin is None or proc.stdout is None or proc.stderr is None:
|
||||
raise RuntimeError("codex exec failed to open subprocess pipes")
|
||||
proc_stdin = proc.stdin
|
||||
proc_stdout = proc.stdout
|
||||
proc_stderr = proc.stderr
|
||||
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
||||
|
||||
stderr_tail: deque[str] = deque(maxlen=200)
|
||||
stderr_task = asyncio.create_task(_drain_stderr(proc_stderr, stderr_tail))
|
||||
rc: int | None = None
|
||||
|
||||
found_session: str | None = session_id
|
||||
last_agent_text: str | None = None
|
||||
@@ -326,62 +353,57 @@ class CodexExecRunner:
|
||||
cli_last_item: int | None = None
|
||||
|
||||
cancelled = False
|
||||
rc: int | None = None
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(_drain_stderr, proc_stderr, stderr_tail)
|
||||
|
||||
try:
|
||||
proc_stdin.write(prompt.encode())
|
||||
await proc_stdin.drain()
|
||||
proc_stdin.close()
|
||||
try:
|
||||
await proc_stdin.send(prompt.encode())
|
||||
await proc_stdin.aclose()
|
||||
|
||||
async for raw_line in proc_stdout:
|
||||
raw = raw_line.decode(errors="replace")
|
||||
logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
evt = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("[codex][jsonl] invalid line: %r", line)
|
||||
continue
|
||||
|
||||
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
|
||||
for out in out_lines:
|
||||
logger.info("[codex] %s", out)
|
||||
|
||||
if on_event is not None:
|
||||
async for raw_line in _iter_text_lines(proc_stdout):
|
||||
raw = raw_line.rstrip("\n")
|
||||
logger.debug("[codex][jsonl] %s", raw)
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
res = on_event(evt)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
except Exception as e:
|
||||
logger.info("[codex][on_event] callback error: %s", e)
|
||||
evt = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("[codex][jsonl] invalid line: %r", line)
|
||||
continue
|
||||
|
||||
if evt["type"] == "thread.started":
|
||||
found_session = evt.get("thread_id") or found_session
|
||||
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
|
||||
for out in out_lines:
|
||||
logger.info("[codex] %s", out)
|
||||
|
||||
if evt["type"] == "item.completed":
|
||||
item = evt.get("item") or {}
|
||||
if item.get("type") == "agent_message" and isinstance(
|
||||
item.get("text"), str
|
||||
):
|
||||
last_agent_text = item["text"]
|
||||
saw_agent_message = True
|
||||
except asyncio.CancelledError:
|
||||
cancelled = True
|
||||
finally:
|
||||
if cancelled:
|
||||
if not stderr_task.done():
|
||||
stderr_task.cancel()
|
||||
task = cast(asyncio.Task, asyncio.current_task())
|
||||
while task.cancelling():
|
||||
task.uncancel()
|
||||
if not cancelled:
|
||||
rc = await proc.wait()
|
||||
await asyncio.gather(stderr_task, return_exceptions=True)
|
||||
if on_event is not None:
|
||||
try:
|
||||
res = on_event(evt)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
except Exception as e:
|
||||
logger.info("[codex][on_event] callback error: %s", e)
|
||||
|
||||
if evt["type"] == "thread.started":
|
||||
found_session = evt.get("thread_id") or found_session
|
||||
|
||||
if evt["type"] == "item.completed":
|
||||
item = evt.get("item") or {}
|
||||
if item.get("type") == "agent_message" and isinstance(
|
||||
item.get("text"), str
|
||||
):
|
||||
last_agent_text = item["text"]
|
||||
saw_agent_message = True
|
||||
except cancelled_exc_type as exc:
|
||||
cancelled = True
|
||||
cancelled_exc = exc
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
if not cancelled:
|
||||
rc = await proc.wait()
|
||||
|
||||
if cancelled:
|
||||
raise asyncio.CancelledError
|
||||
raise cancelled_exc # type: ignore[misc]
|
||||
|
||||
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
||||
if rc != 0:
|
||||
@@ -406,11 +428,31 @@ class CodexExecRunner:
|
||||
session_id: str | None,
|
||||
on_event: EventCallback | None = None,
|
||||
) -> tuple[str, str, bool]:
|
||||
if not session_id:
|
||||
return await self.run(prompt, session_id=None, on_event=on_event)
|
||||
lock = await self._lock_for(session_id)
|
||||
async with lock:
|
||||
return await self.run(prompt, session_id=session_id, on_event=on_event)
|
||||
if session_id:
|
||||
lock = await self._lock_for(session_id)
|
||||
async with lock:
|
||||
return await self.run(prompt, session_id=session_id, on_event=on_event)
|
||||
|
||||
session_lock: anyio.Lock | None = None
|
||||
|
||||
async def on_event_with_lock(evt: dict[str, Any]) -> None:
|
||||
nonlocal session_lock
|
||||
if session_lock is None and evt.get("type") == "thread.started":
|
||||
thread_id = evt.get("thread_id")
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
session_lock = await self._lock_for(thread_id)
|
||||
await session_lock.acquire()
|
||||
if on_event is None:
|
||||
return
|
||||
res = on_event(evt)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
try:
|
||||
return await self.run(prompt, session_id=None, on_event=on_event_with_lock)
|
||||
finally:
|
||||
if session_lock is not None:
|
||||
session_lock.release()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -423,6 +465,12 @@ class BridgeConfig:
|
||||
max_concurrency: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunningTask:
|
||||
scope: anyio.CancelScope
|
||||
session_id: str | None = None
|
||||
|
||||
|
||||
def _parse_bridge_config(
|
||||
*,
|
||||
final_notify: bool,
|
||||
@@ -508,9 +556,9 @@ async def handle_message(
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
resume_session: str | None,
|
||||
running_tasks: dict[str, asyncio.Task[Any]] | None = None,
|
||||
running_tasks: dict[int, RunningTask] | None = None,
|
||||
clock: Callable[[], float] = time.monotonic,
|
||||
sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
|
||||
sleep: Callable[[float], Awaitable[None]] = anyio.sleep,
|
||||
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
@@ -565,31 +613,54 @@ async def handle_message(
|
||||
last_rendered=last_rendered,
|
||||
)
|
||||
|
||||
exec_task: asyncio.Task[tuple[str, str, bool]] | None = None
|
||||
exec_scope = anyio.CancelScope()
|
||||
cancelled = False
|
||||
error: Exception | None = None
|
||||
session_id: str | None = None
|
||||
answer: str | None = None
|
||||
saw_agent_message: bool | None = None
|
||||
running_task: RunningTask | None = None
|
||||
if running_tasks is not None and progress_id is not None:
|
||||
running_task = RunningTask(scope=exec_scope)
|
||||
running_tasks[progress_id] = running_task
|
||||
if resume_session is not None:
|
||||
running_task.session_id = resume_session
|
||||
|
||||
async def on_event(evt: dict[str, Any]) -> None:
|
||||
if (
|
||||
evt["type"] == "thread.started"
|
||||
and running_tasks is not None
|
||||
and exec_task is not None
|
||||
running_task is not None
|
||||
and running_task.session_id is None
|
||||
and evt.get("type") == "thread.started"
|
||||
):
|
||||
running_tasks[evt["thread_id"]] = exec_task
|
||||
thread_id = evt.get("thread_id")
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
running_task.session_id = thread_id
|
||||
await edits.on_event(evt)
|
||||
|
||||
exec_task = asyncio.create_task(
|
||||
cfg.runner.run_serialized(text, resume_session, on_event=on_event)
|
||||
)
|
||||
async with anyio.create_task_group() as tg:
|
||||
if progress_id is not None:
|
||||
tg.start_soon(edits.run)
|
||||
|
||||
cancelled = False
|
||||
try:
|
||||
session_id, answer, saw_agent_message = await exec_task
|
||||
except asyncio.CancelledError:
|
||||
cancelled = True
|
||||
session_id = progress_renderer.resume_session or resume_session
|
||||
except Exception as e:
|
||||
await edits.shutdown()
|
||||
try:
|
||||
with exec_scope:
|
||||
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
|
||||
text, resume_session, on_event=on_event
|
||||
)
|
||||
except Exception as e:
|
||||
error = e
|
||||
finally:
|
||||
if running_task is not None:
|
||||
if running_tasks is not None and progress_id is not None:
|
||||
running_tasks.pop(progress_id, None)
|
||||
if exec_scope.cancelled_caught and not cancelled and error is None:
|
||||
cancelled = True
|
||||
session_id = progress_renderer.resume_session or resume_session
|
||||
if not cancelled and error is None:
|
||||
await anyio.sleep(0)
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
err = _clamp_tg_text(f"Error:\n{e}")
|
||||
if error is not None:
|
||||
err = _clamp_tg_text(f"Error:\n{error}")
|
||||
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
|
||||
await _send_or_edit_markdown(
|
||||
cfg.bot,
|
||||
@@ -601,16 +672,11 @@ async def handle_message(
|
||||
limit=TELEGRAM_MARKDOWN_LIMIT,
|
||||
)
|
||||
return
|
||||
finally:
|
||||
if running_tasks is not None:
|
||||
for sid, task in list(running_tasks.items()):
|
||||
if task is exec_task:
|
||||
running_tasks.pop(sid, None)
|
||||
|
||||
await edits.shutdown()
|
||||
|
||||
elapsed = clock() - started_at
|
||||
if cancelled:
|
||||
if session_id is None:
|
||||
session_id = progress_renderer.resume_session or resume_session
|
||||
logger.info(
|
||||
"[handle] cancelled session_id=%s elapsed=%.1fs", session_id, elapsed
|
||||
)
|
||||
@@ -627,6 +693,9 @@ async def handle_message(
|
||||
)
|
||||
return
|
||||
|
||||
if session_id is None or answer is None or saw_agent_message is None:
|
||||
raise RuntimeError("codex exec finished without a result")
|
||||
|
||||
status = "done" if saw_agent_message else "error"
|
||||
progress_renderer.resume_session = session_id
|
||||
final_md = progress_renderer.render_final(elapsed, answer, status=status)
|
||||
@@ -679,7 +748,7 @@ async def poll_updates(cfg: BridgeConfig):
|
||||
)
|
||||
if updates is None:
|
||||
logger.info("[loop] getUpdates failed")
|
||||
await asyncio.sleep(2)
|
||||
await anyio.sleep(2)
|
||||
continue
|
||||
logger.debug("[loop] updates: %s", updates)
|
||||
|
||||
@@ -696,7 +765,7 @@ async def poll_updates(cfg: BridgeConfig):
|
||||
async def _handle_cancel(
|
||||
cfg: BridgeConfig,
|
||||
msg: dict[str, Any],
|
||||
running_tasks: dict[str, asyncio.Task[Any]],
|
||||
running_tasks: dict[int, RunningTask],
|
||||
) -> None:
|
||||
chat_id = msg["chat"]["id"]
|
||||
user_msg_id = msg["message_id"]
|
||||
@@ -710,8 +779,8 @@ async def _handle_cancel(
|
||||
)
|
||||
return
|
||||
|
||||
session_id = extract_session_id(reply.get("text"))
|
||||
if not session_id:
|
||||
progress_id = reply.get("message_id")
|
||||
if progress_id is None:
|
||||
await cfg.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text="nothing is currently running for that message.",
|
||||
@@ -719,8 +788,8 @@ async def _handle_cancel(
|
||||
)
|
||||
return
|
||||
|
||||
task = running_tasks.get(session_id)
|
||||
if not task or task.done():
|
||||
running_task = running_tasks.get(int(progress_id))
|
||||
if running_task is None:
|
||||
await cfg.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text="nothing is currently running for that message.",
|
||||
@@ -728,20 +797,20 @@ async def _handle_cancel(
|
||||
)
|
||||
return
|
||||
|
||||
logger.info("[cancel] cancelling session_id=%s", session_id)
|
||||
task.cancel()
|
||||
logger.info("[cancel] cancelling progress_message_id=%s", progress_id)
|
||||
running_task.scope.cancel()
|
||||
|
||||
|
||||
async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
worker_count = max(1, min(cfg.max_concurrency, 16))
|
||||
queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue(
|
||||
maxsize=worker_count * 2
|
||||
send_stream, receive_stream = anyio.create_memory_object_stream(
|
||||
max_buffer_size=worker_count * 2
|
||||
)
|
||||
running_tasks: dict[str, asyncio.Task[Any]] = {}
|
||||
running_tasks: dict[int, RunningTask] = {}
|
||||
|
||||
async def worker() -> None:
|
||||
while True:
|
||||
chat_id, user_msg_id, text, resume_session = await queue.get()
|
||||
chat_id, user_msg_id, text, resume_session = await receive_stream.receive()
|
||||
try:
|
||||
await handle_message(
|
||||
cfg,
|
||||
@@ -753,26 +822,28 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("[handle] worker failed")
|
||||
finally:
|
||||
queue.task_done()
|
||||
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
async with anyio.create_task_group() as tg:
|
||||
for _ in range(worker_count):
|
||||
tg.create_task(worker())
|
||||
tg.start_soon(worker)
|
||||
async for msg in poll_updates(cfg):
|
||||
text = msg["text"]
|
||||
user_msg_id = msg["message_id"]
|
||||
|
||||
if text == "/cancel":
|
||||
tg.create_task(_handle_cancel(cfg, msg, running_tasks))
|
||||
tg.start_soon(_handle_cancel, cfg, msg, running_tasks)
|
||||
continue
|
||||
|
||||
r = msg.get("reply_to_message") or {}
|
||||
resume_session = resolve_resume_session(text, r.get("text"))
|
||||
|
||||
await queue.put((msg["chat"]["id"], user_msg_id, text, resume_session))
|
||||
await send_stream.send(
|
||||
(msg["chat"]["id"], user_msg_id, text, resume_session)
|
||||
)
|
||||
finally:
|
||||
await send_stream.aclose()
|
||||
await receive_stream.aclose()
|
||||
await cfg.bot.close()
|
||||
|
||||
|
||||
@@ -813,7 +884,7 @@ def run(
|
||||
except ConfigError as e:
|
||||
typer.echo(str(e), err=True)
|
||||
raise typer.Exit(code=1)
|
||||
asyncio.run(_run_main_loop(cfg))
|
||||
anyio.run(_run_main_loop, cfg)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
Reference in New Issue
Block a user