From 8eda3f5e84f960e6961ee1e05ae24a23752e16e7 Mon Sep 17 00:00:00 2001 From: banteg <4562643+banteg@users.noreply.github.com> Date: Wed, 31 Dec 2025 01:51:46 +0400 Subject: [PATCH] refactor: migrate exec bridge to anyio and harden cancellation (#6) --- developing.md | 6 +- pyproject.toml | 2 + src/takopi/exec_bridge.py | 353 ++++++++++++++++++++-------------- tests/conftest.py | 7 + tests/test_exec_bridge.py | 263 ++++++++++++++----------- tests/test_exec_runner.py | 79 ++++++-- tests/test_subprocess.py | 36 ++-- tests/test_telegram_client.py | 39 ++-- uv.lock | 17 ++ 9 files changed, 492 insertions(+), 310 deletions(-) diff --git a/developing.md b/developing.md index 95b04d7..a5c893b 100644 --- a/developing.md +++ b/developing.md @@ -44,8 +44,8 @@ The orchestrator module containing: **Key patterns:** - Per-session locks prevent concurrent resumes to the same `session_id` -- Worker pool with `asyncio.Queue` limits concurrency (default: 16 workers) -- `asyncio.TaskGroup` manages worker tasks +- Worker pool with an AnyIO memory stream limits concurrency (default: 16 workers) +- AnyIO task groups manage worker tasks - Progress edits are throttled to ~2s intervals - Subprocess stderr is drained to a bounded deque for error reporting - `poll_updates()` uses Telegram `getUpdates` long-polling with a single server-side updates @@ -154,5 +154,5 @@ Same as above, but: |----------|----------| | `codex exec` fails (rc≠0) | Shows stderr tail in error message | | Telegram API error | Logged, edit skipped (progress continues) | -| Cancellation | Subprocess terminated, CancelledError re-raised | +| Cancellation | Cancel scope triggers terminate; cancellation is detected via `cancelled_caught` | | No agent_message | Final shows "error" status | diff --git a/pyproject.toml b/pyproject.toml index 4c1de0f..d43b72a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ readme = "readme.md" license = { file = "LICENSE" } requires-python = ">=3.12" dependencies = [ + "anyio>=4.12.0", "httpx>=0.28.1", "markdown-it-py", "rich>=14.2.0", @@ -38,6 +39,7 @@ build-backend = "uv_build" [dependency-groups] dev = [ "pytest>=9.0.2", + "pytest-anyio>=0.0.0", "pytest-cov>=7.0.0", "ruff>=0.14.10", "ty>=0.0.8", diff --git a/src/takopi/exec_bridge.py b/src/takopi/exec_bridge.py index 9448c65..71cd60e 100644 --- a/src/takopi/exec_bridge.py +++ b/src/takopi/exec_bridge.py @@ -1,21 +1,24 @@ from __future__ import annotations -import asyncio import inspect import json import logging import os import re import shutil +import subprocess import time from collections import deque -from collections.abc import Awaitable, Callable +from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, cast +from typing import Any from weakref import WeakValueDictionary +import anyio import typer +from anyio.abc import ByteReceiveStream, Process +from anyio.streams.text import TextReceiveStream from . import __version__ from .config import ConfigError, load_telegram_config @@ -61,32 +64,59 @@ def resolve_resume_session(text: str | None, reply_text: str | None) -> str | No return extract_session_id(text) or extract_session_id(reply_text) -async def _drain_stderr(stderr: asyncio.StreamReader, tail: deque[str]) -> None: - try: +async def _iter_text_lines(stream: ByteReceiveStream) -> AsyncIterator[str]: + text_stream = TextReceiveStream(stream, errors="replace") + buffer = "" + while True: + try: + chunk = await text_stream.receive() + except anyio.EndOfStream: + if buffer: + yield buffer + return + buffer += chunk while True: - line = await stderr.readline() - if not line: - return - decoded = line.decode(errors="replace") - logger.info("[codex][stderr] %s", decoded.rstrip()) - tail.append(decoded) + split_at = buffer.find("\n") + if split_at < 0: + break + line = buffer[: split_at + 1] + buffer = buffer[split_at + 1 :] + yield line + + +async def _drain_stderr(stderr: ByteReceiveStream, tail: deque[str]) -> None: + try: + async for line in _iter_text_lines(stderr): + logger.info("[codex][stderr] %s", line.rstrip()) + tail.append(line) except Exception as e: logger.debug("[codex][stderr] drain error: %s", e) +async def _wait_for_process(proc: Process, timeout: float) -> bool: + with anyio.move_on_after(timeout) as scope: + await proc.wait() + return scope.cancel_called + + @asynccontextmanager -async def manage_subprocess(*args, **kwargs): - proc = await asyncio.create_subprocess_exec(*args, **kwargs) +async def manage_subprocess(*args, terminate_timeout: float = 2.0, **kwargs): + proc = await anyio.open_process(args, **kwargs) try: yield proc finally: if proc.returncode is None: - proc.terminate() - try: - await asyncio.wait_for(proc.wait(), timeout=2.0) - except asyncio.TimeoutError: - proc.kill() - await proc.wait() + with anyio.CancelScope(shield=True): + try: + proc.terminate() + except ProcessLookupError: + pass + timed_out = await _wait_for_process(proc, terminate_timeout) + if timed_out: + logger.debug( + "[codex] terminate timed out pid=%s; leaving process to exit", + proc.pid, + ) TELEGRAM_MARKDOWN_LIMIT = 3500 @@ -212,17 +242,17 @@ class ProgressEdits: self.last_rendered = last_rendered self._event_seq = 0 self._published_seq = 0 - self.wakeup = asyncio.Event() - self.task: asyncio.Task[None] | None = ( - asyncio.create_task(self.run()) if self.progress_id is not None else None - ) + self.wakeup = anyio.Event() + + async def _wait_for_wakeup(self) -> None: + await self.wakeup.wait() + self.wakeup = anyio.Event() async def run(self) -> None: if self.progress_id is None: return while True: - await self.wakeup.wait() - self.wakeup.clear() + await self._wait_for_wakeup() while self._published_seq < self._event_seq: await self.sleep( max( @@ -250,7 +280,6 @@ class ProgressEdits: self.last_rendered = rendered self._published_seq = seq_at_render - self.wakeup.clear() async def on_event(self, evt: dict[str, Any]) -> None: if not self.renderer.note_event(evt): @@ -260,12 +289,6 @@ class ProgressEdits: self._event_seq += 1 self.wakeup.set() - async def shutdown(self) -> None: - if self.task is None: - return - self.task.cancel() - await asyncio.gather(self.task, return_exceptions=True) - class CodexExecRunner: def __init__( @@ -277,14 +300,14 @@ class CodexExecRunner: self.extra_args = extra_args # Per-session locks to prevent concurrent resumes to the same session_id. - self._session_locks: WeakValueDictionary[str, asyncio.Lock] = ( + self._session_locks: WeakValueDictionary[str, anyio.Lock] = ( WeakValueDictionary() ) - async def _lock_for(self, session_id: str) -> asyncio.Lock: + async def _lock_for(self, session_id: str) -> anyio.Lock: lock = self._session_locks.get(session_id) if lock is None: - lock = asyncio.Lock() + lock = anyio.Lock() self._session_locks[session_id] = lock return lock @@ -306,19 +329,23 @@ class CodexExecRunner: else: args.append("-") + cancelled_exc_type = anyio.get_cancelled_exc_class() + cancelled_exc: BaseException | None = None async with manage_subprocess( *args, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, ) as proc: - proc_stdin = cast(asyncio.StreamWriter, proc.stdin) - proc_stdout = cast(asyncio.StreamReader, proc.stdout) - proc_stderr = cast(asyncio.StreamReader, proc.stderr) + if proc.stdin is None or proc.stdout is None or proc.stderr is None: + raise RuntimeError("codex exec failed to open subprocess pipes") + proc_stdin = proc.stdin + proc_stdout = proc.stdout + proc_stderr = proc.stderr logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args) stderr_tail: deque[str] = deque(maxlen=200) - stderr_task = asyncio.create_task(_drain_stderr(proc_stderr, stderr_tail)) + rc: int | None = None found_session: str | None = session_id last_agent_text: str | None = None @@ -326,62 +353,57 @@ class CodexExecRunner: cli_last_item: int | None = None cancelled = False - rc: int | None = None + async with anyio.create_task_group() as tg: + tg.start_soon(_drain_stderr, proc_stderr, stderr_tail) - try: - proc_stdin.write(prompt.encode()) - await proc_stdin.drain() - proc_stdin.close() + try: + await proc_stdin.send(prompt.encode()) + await proc_stdin.aclose() - async for raw_line in proc_stdout: - raw = raw_line.decode(errors="replace") - logger.debug("[codex][jsonl] %s", raw.rstrip("\n")) - line = raw.strip() - if not line: - continue - try: - evt = json.loads(line) - except json.JSONDecodeError: - logger.debug("[codex][jsonl] invalid line: %r", line) - continue - - cli_last_item, out_lines = render_event_cli(evt, cli_last_item) - for out in out_lines: - logger.info("[codex] %s", out) - - if on_event is not None: + async for raw_line in _iter_text_lines(proc_stdout): + raw = raw_line.rstrip("\n") + logger.debug("[codex][jsonl] %s", raw) + line = raw.strip() + if not line: + continue try: - res = on_event(evt) - if inspect.isawaitable(res): - await res - except Exception as e: - logger.info("[codex][on_event] callback error: %s", e) + evt = json.loads(line) + except json.JSONDecodeError: + logger.debug("[codex][jsonl] invalid line: %r", line) + continue - if evt["type"] == "thread.started": - found_session = evt.get("thread_id") or found_session + cli_last_item, out_lines = render_event_cli(evt, cli_last_item) + for out in out_lines: + logger.info("[codex] %s", out) - if evt["type"] == "item.completed": - item = evt.get("item") or {} - if item.get("type") == "agent_message" and isinstance( - item.get("text"), str - ): - last_agent_text = item["text"] - saw_agent_message = True - except asyncio.CancelledError: - cancelled = True - finally: - if cancelled: - if not stderr_task.done(): - stderr_task.cancel() - task = cast(asyncio.Task, asyncio.current_task()) - while task.cancelling(): - task.uncancel() - if not cancelled: - rc = await proc.wait() - await asyncio.gather(stderr_task, return_exceptions=True) + if on_event is not None: + try: + res = on_event(evt) + if inspect.isawaitable(res): + await res + except Exception as e: + logger.info("[codex][on_event] callback error: %s", e) + + if evt["type"] == "thread.started": + found_session = evt.get("thread_id") or found_session + + if evt["type"] == "item.completed": + item = evt.get("item") or {} + if item.get("type") == "agent_message" and isinstance( + item.get("text"), str + ): + last_agent_text = item["text"] + saw_agent_message = True + except cancelled_exc_type as exc: + cancelled = True + cancelled_exc = exc + tg.cancel_scope.cancel() + finally: + if not cancelled: + rc = await proc.wait() if cancelled: - raise asyncio.CancelledError + raise cancelled_exc # type: ignore[misc] logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc) if rc != 0: @@ -406,11 +428,31 @@ class CodexExecRunner: session_id: str | None, on_event: EventCallback | None = None, ) -> tuple[str, str, bool]: - if not session_id: - return await self.run(prompt, session_id=None, on_event=on_event) - lock = await self._lock_for(session_id) - async with lock: - return await self.run(prompt, session_id=session_id, on_event=on_event) + if session_id: + lock = await self._lock_for(session_id) + async with lock: + return await self.run(prompt, session_id=session_id, on_event=on_event) + + session_lock: anyio.Lock | None = None + + async def on_event_with_lock(evt: dict[str, Any]) -> None: + nonlocal session_lock + if session_lock is None and evt.get("type") == "thread.started": + thread_id = evt.get("thread_id") + if isinstance(thread_id, str) and thread_id: + session_lock = await self._lock_for(thread_id) + await session_lock.acquire() + if on_event is None: + return + res = on_event(evt) + if inspect.isawaitable(res): + await res + + try: + return await self.run(prompt, session_id=None, on_event=on_event_with_lock) + finally: + if session_lock is not None: + session_lock.release() @dataclass(frozen=True) @@ -423,6 +465,12 @@ class BridgeConfig: max_concurrency: int +@dataclass +class RunningTask: + scope: anyio.CancelScope + session_id: str | None = None + + def _parse_bridge_config( *, final_notify: bool, @@ -508,9 +556,9 @@ async def handle_message( user_msg_id: int, text: str, resume_session: str | None, - running_tasks: dict[str, asyncio.Task[Any]] | None = None, + running_tasks: dict[int, RunningTask] | None = None, clock: Callable[[], float] = time.monotonic, - sleep: Callable[[float], Awaitable[None]] = asyncio.sleep, + sleep: Callable[[float], Awaitable[None]] = anyio.sleep, progress_edit_every: float = PROGRESS_EDIT_EVERY_S, ) -> None: logger.debug( @@ -565,31 +613,54 @@ async def handle_message( last_rendered=last_rendered, ) - exec_task: asyncio.Task[tuple[str, str, bool]] | None = None + exec_scope = anyio.CancelScope() + cancelled = False + error: Exception | None = None + session_id: str | None = None + answer: str | None = None + saw_agent_message: bool | None = None + running_task: RunningTask | None = None + if running_tasks is not None and progress_id is not None: + running_task = RunningTask(scope=exec_scope) + running_tasks[progress_id] = running_task + if resume_session is not None: + running_task.session_id = resume_session async def on_event(evt: dict[str, Any]) -> None: if ( - evt["type"] == "thread.started" - and running_tasks is not None - and exec_task is not None + running_task is not None + and running_task.session_id is None + and evt.get("type") == "thread.started" ): - running_tasks[evt["thread_id"]] = exec_task + thread_id = evt.get("thread_id") + if isinstance(thread_id, str) and thread_id: + running_task.session_id = thread_id await edits.on_event(evt) - exec_task = asyncio.create_task( - cfg.runner.run_serialized(text, resume_session, on_event=on_event) - ) + async with anyio.create_task_group() as tg: + if progress_id is not None: + tg.start_soon(edits.run) - cancelled = False - try: - session_id, answer, saw_agent_message = await exec_task - except asyncio.CancelledError: - cancelled = True - session_id = progress_renderer.resume_session or resume_session - except Exception as e: - await edits.shutdown() + try: + with exec_scope: + session_id, answer, saw_agent_message = await cfg.runner.run_serialized( + text, resume_session, on_event=on_event + ) + except Exception as e: + error = e + finally: + if running_task is not None: + if running_tasks is not None and progress_id is not None: + running_tasks.pop(progress_id, None) + if exec_scope.cancelled_caught and not cancelled and error is None: + cancelled = True + session_id = progress_renderer.resume_session or resume_session + if not cancelled and error is None: + await anyio.sleep(0) + tg.cancel_scope.cancel() - err = _clamp_tg_text(f"Error:\n{e}") + if error is not None: + err = _clamp_tg_text(f"Error:\n{error}") logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err) await _send_or_edit_markdown( cfg.bot, @@ -601,16 +672,11 @@ async def handle_message( limit=TELEGRAM_MARKDOWN_LIMIT, ) return - finally: - if running_tasks is not None: - for sid, task in list(running_tasks.items()): - if task is exec_task: - running_tasks.pop(sid, None) - - await edits.shutdown() elapsed = clock() - started_at if cancelled: + if session_id is None: + session_id = progress_renderer.resume_session or resume_session logger.info( "[handle] cancelled session_id=%s elapsed=%.1fs", session_id, elapsed ) @@ -627,6 +693,9 @@ async def handle_message( ) return + if session_id is None or answer is None or saw_agent_message is None: + raise RuntimeError("codex exec finished without a result") + status = "done" if saw_agent_message else "error" progress_renderer.resume_session = session_id final_md = progress_renderer.render_final(elapsed, answer, status=status) @@ -679,7 +748,7 @@ async def poll_updates(cfg: BridgeConfig): ) if updates is None: logger.info("[loop] getUpdates failed") - await asyncio.sleep(2) + await anyio.sleep(2) continue logger.debug("[loop] updates: %s", updates) @@ -696,7 +765,7 @@ async def poll_updates(cfg: BridgeConfig): async def _handle_cancel( cfg: BridgeConfig, msg: dict[str, Any], - running_tasks: dict[str, asyncio.Task[Any]], + running_tasks: dict[int, RunningTask], ) -> None: chat_id = msg["chat"]["id"] user_msg_id = msg["message_id"] @@ -710,8 +779,8 @@ async def _handle_cancel( ) return - session_id = extract_session_id(reply.get("text")) - if not session_id: + progress_id = reply.get("message_id") + if progress_id is None: await cfg.bot.send_message( chat_id=chat_id, text="nothing is currently running for that message.", @@ -719,8 +788,8 @@ async def _handle_cancel( ) return - task = running_tasks.get(session_id) - if not task or task.done(): + running_task = running_tasks.get(int(progress_id)) + if running_task is None: await cfg.bot.send_message( chat_id=chat_id, text="nothing is currently running for that message.", @@ -728,20 +797,20 @@ async def _handle_cancel( ) return - logger.info("[cancel] cancelling session_id=%s", session_id) - task.cancel() + logger.info("[cancel] cancelling progress_message_id=%s", progress_id) + running_task.scope.cancel() async def _run_main_loop(cfg: BridgeConfig) -> None: worker_count = max(1, min(cfg.max_concurrency, 16)) - queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue( - maxsize=worker_count * 2 + send_stream, receive_stream = anyio.create_memory_object_stream( + max_buffer_size=worker_count * 2 ) - running_tasks: dict[str, asyncio.Task[Any]] = {} + running_tasks: dict[int, RunningTask] = {} async def worker() -> None: while True: - chat_id, user_msg_id, text, resume_session = await queue.get() + chat_id, user_msg_id, text, resume_session = await receive_stream.receive() try: await handle_message( cfg, @@ -753,26 +822,28 @@ async def _run_main_loop(cfg: BridgeConfig) -> None: ) except Exception: logger.exception("[handle] worker failed") - finally: - queue.task_done() try: - async with asyncio.TaskGroup() as tg: + async with anyio.create_task_group() as tg: for _ in range(worker_count): - tg.create_task(worker()) + tg.start_soon(worker) async for msg in poll_updates(cfg): text = msg["text"] user_msg_id = msg["message_id"] if text == "/cancel": - tg.create_task(_handle_cancel(cfg, msg, running_tasks)) + tg.start_soon(_handle_cancel, cfg, msg, running_tasks) continue r = msg.get("reply_to_message") or {} resume_session = resolve_resume_session(text, r.get("text")) - await queue.put((msg["chat"]["id"], user_msg_id, text, resume_session)) + await send_stream.send( + (msg["chat"]["id"], user_msg_id, text, resume_session) + ) finally: + await send_stream.aclose() + await receive_stream.aclose() await cfg.bot.close() @@ -813,7 +884,7 @@ def run( except ConfigError as e: typer.echo(str(e), err=True) raise typer.Exit(code=1) - asyncio.run(_run_main_loop(cfg)) + anyio.run(_run_main_loop, cfg) def main() -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 99f1af1..3fa616b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,11 @@ import sys from pathlib import Path +import pytest + sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" diff --git a/tests/test_exec_bridge.py b/tests/test_exec_bridge.py index 25d93aa..ad5b949 100644 --- a/tests/test_exec_bridge.py +++ b/tests/test_exec_bridge.py @@ -1,5 +1,4 @@ -import asyncio - +import anyio import pytest from takopi.exec_bridge import ( @@ -187,7 +186,7 @@ class _FakeClock: def __init__(self, start: float = 0.0) -> None: self._now = start self._sleep_until: float | None = None - self._sleep_event: asyncio.Event | None = None + self._sleep_event: anyio.Event | None = None self.sleep_calls = 0 def __call__(self) -> float: @@ -205,10 +204,10 @@ class _FakeClock: async def sleep(self, delay: float) -> None: self.sleep_calls += 1 if delay <= 0: - await asyncio.sleep(0) + await anyio.sleep(0) return self._sleep_until = self._now + delay - self._sleep_event = asyncio.Event() + self._sleep_event = anyio.Event() await self._sleep_event.wait() @@ -222,7 +221,7 @@ class _FakeRunnerWithEvents: answer: str = "ok", session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2", advance_after: float | None = None, - hold: asyncio.Event | None = None, + hold: anyio.Event | None = None, ) -> None: self._events = events self._times = times @@ -238,16 +237,17 @@ class _FakeRunnerWithEvents: for when, event in zip(self._times, self._events, strict=False): self._clock.set(when) await on_event(event) - await asyncio.sleep(0) + await anyio.sleep(0) if self._advance_after is not None: self._clock.set(self._advance_after) - await asyncio.sleep(0) + await anyio.sleep(0) if self._hold is not None: await self._hold.wait() return (self._session_id, self._answer, True) -def test_final_notify_sends_loud_final_message() -> None: +@pytest.mark.anyio +async def test_final_notify_sends_loud_final_message() -> None: from takopi.exec_bridge import BridgeConfig, handle_message bot = _FakeBot() @@ -261,14 +261,12 @@ def test_final_notify_sends_loud_final_message() -> None: max_concurrency=1, ) - asyncio.run( - handle_message( - cfg, - chat_id=123, - user_msg_id=10, - text="hi", - resume_session=None, - ) + await handle_message( + cfg, + chat_id=123, + user_msg_id=10, + text="hi", + resume_session=None, ) assert len(bot.send_calls) == 2 @@ -276,7 +274,8 @@ def test_final_notify_sends_loud_final_message() -> None: assert bot.send_calls[1]["disable_notification"] is False -def test_new_final_message_forces_notification_when_too_long_to_edit() -> None: +@pytest.mark.anyio +async def test_new_final_message_forces_notification_when_too_long_to_edit() -> None: from takopi.exec_bridge import BridgeConfig, handle_message bot = _FakeBot() @@ -290,14 +289,12 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None: max_concurrency=1, ) - asyncio.run( - handle_message( - cfg, - chat_id=123, - user_msg_id=10, - text="hi", - resume_session=None, - ) + await handle_message( + cfg, + chat_id=123, + user_msg_id=10, + text="hi", + resume_session=None, ) assert len(bot.send_calls) == 2 @@ -305,7 +302,8 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None: assert bot.send_calls[1]["disable_notification"] is False -def test_progress_edits_are_rate_limited() -> None: +@pytest.mark.anyio +async def test_progress_edits_are_rate_limited() -> None: from takopi.exec_bridge import BridgeConfig, handle_message bot = _FakeBot() @@ -345,29 +343,28 @@ def test_progress_edits_are_rate_limited() -> None: max_concurrency=1, ) - asyncio.run( - handle_message( - cfg, - chat_id=123, - user_msg_id=10, - text="hi", - resume_session=None, - clock=clock, - sleep=clock.sleep, - progress_edit_every=1.0, - ) + await handle_message( + cfg, + chat_id=123, + user_msg_id=10, + text="hi", + resume_session=None, + clock=clock, + sleep=clock.sleep, + progress_edit_every=1.0, ) assert len(bot.edit_calls) == 1 assert "echo 2" in bot.edit_calls[0]["text"] -def test_progress_edits_do_not_sleep_again_without_new_events() -> None: +@pytest.mark.anyio +async def test_progress_edits_do_not_sleep_again_without_new_events() -> None: from takopi.exec_bridge import BridgeConfig, handle_message bot = _FakeBot() clock = _FakeClock() - hold = asyncio.Event() + hold = anyio.Event() events = [ { "type": "item.started", @@ -404,24 +401,25 @@ def test_progress_edits_do_not_sleep_again_without_new_events() -> None: max_concurrency=1, ) - async def run_test() -> None: - task = asyncio.create_task( - handle_message( - cfg, - chat_id=123, - user_msg_id=10, - text="hi", - resume_session=None, - clock=clock, - sleep=clock.sleep, - progress_edit_every=1.0, - ) + async def run_handle_message() -> None: + await handle_message( + cfg, + chat_id=123, + user_msg_id=10, + text="hi", + resume_session=None, + clock=clock, + sleep=clock.sleep, + progress_edit_every=1.0, ) + async with anyio.create_task_group() as tg: + tg.start_soon(run_handle_message) + for _ in range(100): if clock._sleep_until is not None: break - await asyncio.sleep(0) + await anyio.sleep(0) assert clock._sleep_until == pytest.approx(1.0) @@ -430,23 +428,21 @@ def test_progress_edits_do_not_sleep_again_without_new_events() -> None: for _ in range(100): if bot.edit_calls: break - await asyncio.sleep(0) + await anyio.sleep(0) assert len(bot.edit_calls) == 1 for _ in range(5): - await asyncio.sleep(0) + await anyio.sleep(0) assert clock.sleep_calls == 1 assert clock._sleep_until is None hold.set() - await task - - asyncio.run(run_test()) -def test_bridge_flow_sends_progress_edits_and_final_resume() -> None: +@pytest.mark.anyio +async def test_bridge_flow_sends_progress_edits_and_final_resume() -> None: from takopi.exec_bridge import BridgeConfig, handle_message bot = _FakeBot() @@ -489,17 +485,15 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None: max_concurrency=1, ) - asyncio.run( - handle_message( - cfg, - chat_id=123, - user_msg_id=42, - text="do it", - resume_session=None, - clock=clock, - sleep=clock.sleep, - progress_edit_every=1.0, - ) + await handle_message( + cfg, + chat_id=123, + user_msg_id=42, + text="do it", + resume_session=None, + clock=clock, + sleep=clock.sleep, + progress_edit_every=1.0, ) assert bot.send_calls[0]["reply_to_message_id"] == 42 @@ -510,7 +504,8 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None: assert len(bot.delete_calls) == 1 -def test_handle_cancel_without_reply_prompts_user() -> None: +@pytest.mark.anyio +async def test_handle_cancel_without_reply_prompts_user() -> None: from takopi.exec_bridge import BridgeConfig, _handle_cancel bot = _FakeBot() @@ -526,13 +521,14 @@ def test_handle_cancel_without_reply_prompts_user() -> None: msg = {"chat": {"id": 123}, "message_id": 10} running_tasks: dict = {} - asyncio.run(_handle_cancel(cfg, msg, running_tasks)) + await _handle_cancel(cfg, msg, running_tasks) assert len(bot.send_calls) == 1 assert "reply to the progress message" in bot.send_calls[0]["text"] -def test_handle_cancel_with_no_session_id_says_nothing_running() -> None: +@pytest.mark.anyio +async def test_handle_cancel_with_no_progress_message_says_nothing_running() -> None: from takopi.exec_bridge import BridgeConfig, _handle_cancel bot = _FakeBot() @@ -548,17 +544,18 @@ def test_handle_cancel_with_no_session_id_says_nothing_running() -> None: msg = { "chat": {"id": 123}, "message_id": 10, - "reply_to_message": {"text": "no uuid here"}, + "reply_to_message": {"text": "no message id"}, } running_tasks: dict = {} - asyncio.run(_handle_cancel(cfg, msg, running_tasks)) + await _handle_cancel(cfg, msg, running_tasks) assert len(bot.send_calls) == 1 assert "nothing is currently running" in bot.send_calls[0]["text"] -def test_handle_cancel_with_finished_task_says_nothing_running() -> None: +@pytest.mark.anyio +async def test_handle_cancel_with_finished_task_says_nothing_running() -> None: from takopi.exec_bridge import BridgeConfig, _handle_cancel bot = _FakeBot() @@ -571,21 +568,22 @@ def test_handle_cancel_with_finished_task_says_nothing_running() -> None: startup_msg="", max_concurrency=1, ) - session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2" + progress_id = 99 msg = { "chat": {"id": 123}, "message_id": 10, - "reply_to_message": {"text": f"resume: `{session_id}`"}, + "reply_to_message": {"message_id": progress_id}, } - running_tasks: dict = {} # Session not in running_tasks + running_tasks: dict = {} # Progress message not in running_tasks - asyncio.run(_handle_cancel(cfg, msg, running_tasks)) + await _handle_cancel(cfg, msg, running_tasks) assert len(bot.send_calls) == 1 assert "nothing is currently running" in bot.send_calls[0]["text"] -def test_handle_cancel_cancels_running_task() -> None: +@pytest.mark.anyio +async def test_handle_cancel_cancels_running_task() -> None: from takopi.exec_bridge import BridgeConfig, _handle_cancel bot = _FakeBot() @@ -598,29 +596,70 @@ def test_handle_cancel_cancels_running_task() -> None: startup_msg="", max_concurrency=1, ) - session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2" + progress_id = 42 msg = { "chat": {"id": 123}, "message_id": 10, - "reply_to_message": {"text": f"resume: `{session_id}`"}, + "reply_to_message": {"message_id": progress_id}, } - async def run_test(): - task = asyncio.create_task(asyncio.sleep(10)) - running_tasks = {session_id: task} + from takopi.exec_bridge import RunningTask + + cancelled_event = anyio.Event() + cancel_scope = anyio.CancelScope() + running_task = RunningTask(scope=cancel_scope) + + async def sleeper() -> None: + with cancel_scope: + try: + await anyio.sleep(10) + except anyio.get_cancelled_exc_class(): + cancelled_event.set() + return + + async with anyio.create_task_group() as tg: + tg.start_soon(sleeper) + running_tasks = {progress_id: running_task} await _handle_cancel(cfg, msg, running_tasks) - try: - await task - except asyncio.CancelledError: - return True - return False + await cancelled_event.wait() - cancelled = asyncio.run(run_test()) - - assert cancelled is True assert len(bot.send_calls) == 0 # No error message sent +@pytest.mark.anyio +async def test_handle_cancel_only_cancels_matching_progress_message() -> None: + from takopi.exec_bridge import BridgeConfig, _handle_cancel + + bot = _FakeBot() + runner = _FakeRunner(answer="ok") + cfg = BridgeConfig( + bot=bot, # type: ignore[arg-type] + runner=runner, # type: ignore[arg-type] + chat_id=123, + final_notify=True, + startup_msg="", + max_concurrency=1, + ) + from takopi.exec_bridge import RunningTask + + scope_first = anyio.CancelScope() + scope_second = anyio.CancelScope() + task_first = RunningTask(scope=scope_first) + task_second = RunningTask(scope=scope_second) + msg = { + "chat": {"id": 123}, + "message_id": 10, + "reply_to_message": {"message_id": 1}, + } + running_tasks = {1: task_first, 2: task_second} + + await _handle_cancel(cfg, msg, running_tasks) + + assert scope_first.cancel_called is True + assert scope_second.cancel_called is False + assert len(bot.send_calls) == 0 + + class _FakeRunnerCancellable: def __init__(self, session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2"): self._session_id = session_id @@ -629,11 +668,12 @@ class _FakeRunnerCancellable: on_event = kwargs.get("on_event") if on_event: await on_event({"type": "thread.started", "thread_id": self._session_id}) - await asyncio.sleep(10) # Will be cancelled + await anyio.sleep(10) # Will be cancelled return (self._session_id, "ok", True) -def test_handle_message_cancelled_renders_cancelled_state() -> None: +@pytest.mark.anyio +async def test_handle_message_cancelled_renders_cancelled_state() -> None: from takopi.exec_bridge import BridgeConfig, handle_message bot = _FakeBot() @@ -649,23 +689,24 @@ def test_handle_message_cancelled_renders_cancelled_state() -> None: ) running_tasks: dict = {} - async def run_test(): - task = asyncio.create_task( - handle_message( - cfg, - chat_id=123, - user_msg_id=10, - text="do something", - resume_session=None, - running_tasks=running_tasks, - ) + async def run_handle_message() -> None: + await handle_message( + cfg, + chat_id=123, + user_msg_id=10, + text="do something", + resume_session=None, + running_tasks=running_tasks, ) - await asyncio.sleep(0.01) # Let task start and register - assert session_id in running_tasks - running_tasks[session_id].cancel() - await task - asyncio.run(run_test()) + async with anyio.create_task_group() as tg: + tg.start_soon(run_handle_message) + for _ in range(100): + if running_tasks: + break + await anyio.sleep(0) + assert running_tasks + running_tasks[next(iter(running_tasks))].scope.cancel() assert len(bot.send_calls) == 1 # Progress message assert len(bot.edit_calls) >= 1 diff --git a/tests/test_exec_runner.py b/tests/test_exec_runner.py index cfa1158..c7f872e 100644 --- a/tests/test_exec_runner.py +++ b/tests/test_exec_runner.py @@ -1,11 +1,13 @@ -import asyncio +import anyio +import pytest -from takopi.exec_bridge import CodexExecRunner +from takopi.exec_bridge import CodexExecRunner, EventCallback -def test_run_serialized_serializes_same_session() -> None: +@pytest.mark.anyio +async def test_run_serialized_serializes_same_session() -> None: runner = CodexExecRunner(codex_cmd="codex", extra_args=[]) - gate = asyncio.Event() + gate = anyio.Event() in_flight = 0 max_in_flight = 0 @@ -19,13 +21,68 @@ def test_run_serialized_serializes_same_session() -> None: runner.run = run_stub # type: ignore[assignment] - async def run_test() -> None: - t1 = asyncio.create_task(runner.run_serialized("a", "sid")) - t2 = asyncio.create_task(runner.run_serialized("b", "sid")) - await asyncio.sleep(0) + async with anyio.create_task_group() as tg: + tg.start_soon(runner.run_serialized, "a", "sid") + tg.start_soon(runner.run_serialized, "b", "sid") + await anyio.sleep(0) gate.set() - await asyncio.gather(t1, t2) - - asyncio.run(run_test()) assert max_in_flight == 1 + + +@pytest.mark.anyio +async def test_run_serialized_allows_parallel_new_sessions() -> None: + runner = CodexExecRunner(codex_cmd="codex", extra_args=[]) + gate = anyio.Event() + in_flight = 0 + max_in_flight = 0 + + async def run_stub(*_args, **_kwargs): + nonlocal in_flight, max_in_flight + in_flight += 1 + max_in_flight = max(max_in_flight, in_flight) + await gate.wait() + in_flight -= 1 + return ("sid", "ok", True) + + runner.run = run_stub # type: ignore[assignment] + + async with anyio.create_task_group() as tg: + tg.start_soon(runner.run_serialized, "a", None) + tg.start_soon(runner.run_serialized, "b", None) + with anyio.move_on_after(1): + while max_in_flight < 2: + await anyio.sleep(0) + gate.set() + + assert max_in_flight == 2 + + +@pytest.mark.anyio +async def test_new_session_holds_lock_for_resumes() -> None: + runner = CodexExecRunner(codex_cmd="codex", extra_args=[]) + finish = anyio.Event() + resume_started = anyio.Event() + + async def run_stub( + _prompt: str, + session_id: str | None, + on_event: EventCallback | None = None, + ) -> tuple[str, str, bool]: + if session_id is None: + if on_event: + await on_event({"type": "thread.started", "thread_id": "sid"}) + await finish.wait() + return ("sid", "ok", True) + resume_started.set() + return ("sid", "ok", True) + + runner.run = run_stub # type: ignore[assignment] + + async with anyio.create_task_group() as tg: + tg.start_soon(runner.run_serialized, "first", None) + await anyio.sleep(0) + tg.start_soon(runner.run_serialized, "resume", "sid") + await anyio.sleep(0) + assert not resume_started.is_set() + finish.set() diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 9cf21d0..d37f8b5 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -1,29 +1,19 @@ -import asyncio import sys +import pytest + from takopi import exec_bridge -def test_manage_subprocess_kills_when_terminate_times_out(monkeypatch) -> None: - async def fake_wait_for(awaitable, *args, **kwargs): - if hasattr(awaitable, "close"): - awaitable.close() - elif hasattr(awaitable, "cancel"): - awaitable.cancel() - raise asyncio.TimeoutError +@pytest.mark.anyio +async def test_manage_subprocess_kills_when_terminate_times_out() -> None: + async with exec_bridge.manage_subprocess( + sys.executable, + "-c", + "import signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(10)", + terminate_timeout=0.01, + ) as proc: + assert proc.returncode is None - monkeypatch.setattr(exec_bridge.asyncio, "wait_for", fake_wait_for) - - async def run() -> int | None: - async with exec_bridge.manage_subprocess( - sys.executable, - "-c", - "import signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(10)", - ) as proc: - assert proc.returncode is None - return proc.returncode - - rc = asyncio.run(run()) - - assert rc is not None - assert rc != 0 + assert proc.returncode is not None + assert proc.returncode != 0 diff --git a/tests/test_telegram_client.py b/tests/test_telegram_client.py index b117cbb..161cffe 100644 --- a/tests/test_telegram_client.py +++ b/tests/test_telegram_client.py @@ -1,4 +1,3 @@ -import asyncio import logging import httpx @@ -8,7 +7,8 @@ from takopi.logging import RedactTokenFilter from takopi.telegram import TelegramClient -def test_telegram_429_no_retry() -> None: +@pytest.mark.anyio +async def test_telegram_429_no_retry() -> None: calls: list[int] = [] def handler(request: httpx.Request) -> httpx.Response: @@ -25,21 +25,21 @@ def test_telegram_429_no_retry() -> None: transport = httpx.MockTransport(handler) - async def run() -> dict | None: - client = httpx.AsyncClient(transport=transport) - try: - tg = TelegramClient("123:abcDEF_ghij", client=client) - return await tg._post("sendMessage", {"chat_id": 1, "text": "hi"}) - finally: - await client.aclose() - - result = asyncio.run(run()) + client = httpx.AsyncClient(transport=transport) + try: + tg = TelegramClient("123:abcDEF_ghij", client=client) + result = await tg._post("sendMessage", {"chat_id": 1, "text": "hi"}) + finally: + await client.aclose() assert result is None assert len(calls) == 1 -def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> None: +@pytest.mark.anyio +async def test_no_token_in_logs_on_http_error( + caplog: pytest.LogCaptureFixture, +) -> None: token = "123:abcDEF_ghij" redactor = RedactTokenFilter() root_logger = logging.getLogger() @@ -50,16 +50,13 @@ def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> Non transport = httpx.MockTransport(handler) - async def run() -> None: - client = httpx.AsyncClient(transport=transport) - try: - tg = TelegramClient(token, client=client) - await tg._post("getUpdates", {"timeout": 1}) - finally: - await client.aclose() - caplog.set_level(logging.ERROR) - asyncio.run(run()) + client = httpx.AsyncClient(transport=transport) + try: + tg = TelegramClient(token, client=client) + await tg._post("getUpdates", {"timeout": 1}) + finally: + await client.aclose() root_logger.removeFilter(redactor) diff --git a/uv.lock b/uv.lock index 519e011..e525f4f 100644 --- a/uv.lock +++ b/uv.lock @@ -331,6 +331,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-anyio" +version = "0.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/44/a02e5877a671b0940f21a7a0d9704c22097b123ed5cdbcca9cab39f17acc/pytest-anyio-0.0.0.tar.gz", hash = "sha256:b41234e9e9ad7ea1dbfefcc1d6891b23d5ef7c9f07ccf804c13a9cc338571fd3", size = 1560, upload-time = "2021-06-29T22:57:30.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/25/bd6493ae85d0a281b6a0f248d0fdb1d9aa2b31f18bcd4a8800cf397d8209/pytest_anyio-0.0.0-py2.py3-none-any.whl", hash = "sha256:dc8b5c4741cb16ff90be37fddd585ca943ed12bbeb563de7ace6cd94441d8746", size = 1999, upload-time = "2021-06-29T22:57:29.158Z" }, +] + [[package]] name = "pytest-cov" version = "7.0.0" @@ -420,6 +433,7 @@ name = "takopi" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "anyio" }, { name = "httpx" }, { name = "markdown-it-py" }, { name = "rich" }, @@ -430,6 +444,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "pytest" }, + { name = "pytest-anyio" }, { name = "pytest-cov" }, { name = "ruff" }, { name = "ty" }, @@ -437,6 +452,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "anyio", specifier = ">=4.12.0" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "markdown-it-py" }, { name = "rich", specifier = ">=14.2.0" }, @@ -447,6 +463,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "pytest", specifier = ">=9.0.2" }, + { name = "pytest-anyio", specifier = ">=0.0.0" }, { name = "pytest-cov", specifier = ">=7.0.0" }, { name = "ruff", specifier = ">=0.14.10" }, { name = "ty", specifier = ">=0.0.8" },