refactor: migrate exec bridge to anyio and harden cancellation (#6)

This commit is contained in:
banteg
2025-12-31 01:51:46 +04:00
committed by GitHub
parent 6687a435c9
commit 8eda3f5e84
9 changed files with 492 additions and 310 deletions
+3 -3
View File
@@ -44,8 +44,8 @@ The orchestrator module containing:
**Key patterns:**
- Per-session locks prevent concurrent resumes to the same `session_id`
- Worker pool with `asyncio.Queue` limits concurrency (default: 16 workers)
- `asyncio.TaskGroup` manages worker tasks
- Worker pool with an AnyIO memory stream limits concurrency (default: 16 workers)
- AnyIO task groups manage worker tasks
- Progress edits are throttled to ~2s intervals
- Subprocess stderr is drained to a bounded deque for error reporting
- `poll_updates()` uses Telegram `getUpdates` long-polling with a single server-side updates
@@ -154,5 +154,5 @@ Same as above, but:
|----------|----------|
| `codex exec` fails (rc≠0) | Shows stderr tail in error message |
| Telegram API error | Logged, edit skipped (progress continues) |
| Cancellation | Subprocess terminated, CancelledError re-raised |
| Cancellation | Cancel scope triggers terminate; cancellation is detected via `cancelled_caught` |
| No agent_message | Final shows "error" status |
+2
View File
@@ -7,6 +7,7 @@ readme = "readme.md"
license = { file = "LICENSE" }
requires-python = ">=3.12"
dependencies = [
"anyio>=4.12.0",
"httpx>=0.28.1",
"markdown-it-py",
"rich>=14.2.0",
@@ -38,6 +39,7 @@ build-backend = "uv_build"
[dependency-groups]
dev = [
"pytest>=9.0.2",
"pytest-anyio>=0.0.0",
"pytest-cov>=7.0.0",
"ruff>=0.14.10",
"ty>=0.0.8",
+171 -100
View File
@@ -1,21 +1,24 @@
from __future__ import annotations
import asyncio
import inspect
import json
import logging
import os
import re
import shutil
import subprocess
import time
from collections import deque
from collections.abc import Awaitable, Callable
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, cast
from typing import Any
from weakref import WeakValueDictionary
import anyio
import typer
from anyio.abc import ByteReceiveStream, Process
from anyio.streams.text import TextReceiveStream
from . import __version__
from .config import ConfigError, load_telegram_config
@@ -61,32 +64,59 @@ def resolve_resume_session(text: str | None, reply_text: str | None) -> str | No
return extract_session_id(text) or extract_session_id(reply_text)
async def _drain_stderr(stderr: asyncio.StreamReader, tail: deque[str]) -> None:
try:
async def _iter_text_lines(stream: ByteReceiveStream) -> AsyncIterator[str]:
text_stream = TextReceiveStream(stream, errors="replace")
buffer = ""
while True:
line = await stderr.readline()
if not line:
try:
chunk = await text_stream.receive()
except anyio.EndOfStream:
if buffer:
yield buffer
return
decoded = line.decode(errors="replace")
logger.info("[codex][stderr] %s", decoded.rstrip())
tail.append(decoded)
buffer += chunk
while True:
split_at = buffer.find("\n")
if split_at < 0:
break
line = buffer[: split_at + 1]
buffer = buffer[split_at + 1 :]
yield line
async def _drain_stderr(stderr: ByteReceiveStream, tail: deque[str]) -> None:
try:
async for line in _iter_text_lines(stderr):
logger.info("[codex][stderr] %s", line.rstrip())
tail.append(line)
except Exception as e:
logger.debug("[codex][stderr] drain error: %s", e)
async def _wait_for_process(proc: Process, timeout: float) -> bool:
with anyio.move_on_after(timeout) as scope:
await proc.wait()
return scope.cancel_called
@asynccontextmanager
async def manage_subprocess(*args, **kwargs):
proc = await asyncio.create_subprocess_exec(*args, **kwargs)
async def manage_subprocess(*args, terminate_timeout: float = 2.0, **kwargs):
proc = await anyio.open_process(args, **kwargs)
try:
yield proc
finally:
if proc.returncode is None:
proc.terminate()
with anyio.CancelScope(shield=True):
try:
await asyncio.wait_for(proc.wait(), timeout=2.0)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
proc.terminate()
except ProcessLookupError:
pass
timed_out = await _wait_for_process(proc, terminate_timeout)
if timed_out:
logger.debug(
"[codex] terminate timed out pid=%s; leaving process to exit",
proc.pid,
)
TELEGRAM_MARKDOWN_LIMIT = 3500
@@ -212,17 +242,17 @@ class ProgressEdits:
self.last_rendered = last_rendered
self._event_seq = 0
self._published_seq = 0
self.wakeup = asyncio.Event()
self.task: asyncio.Task[None] | None = (
asyncio.create_task(self.run()) if self.progress_id is not None else None
)
self.wakeup = anyio.Event()
async def _wait_for_wakeup(self) -> None:
await self.wakeup.wait()
self.wakeup = anyio.Event()
async def run(self) -> None:
if self.progress_id is None:
return
while True:
await self.wakeup.wait()
self.wakeup.clear()
await self._wait_for_wakeup()
while self._published_seq < self._event_seq:
await self.sleep(
max(
@@ -250,7 +280,6 @@ class ProgressEdits:
self.last_rendered = rendered
self._published_seq = seq_at_render
self.wakeup.clear()
async def on_event(self, evt: dict[str, Any]) -> None:
if not self.renderer.note_event(evt):
@@ -260,12 +289,6 @@ class ProgressEdits:
self._event_seq += 1
self.wakeup.set()
async def shutdown(self) -> None:
if self.task is None:
return
self.task.cancel()
await asyncio.gather(self.task, return_exceptions=True)
class CodexExecRunner:
def __init__(
@@ -277,14 +300,14 @@ class CodexExecRunner:
self.extra_args = extra_args
# Per-session locks to prevent concurrent resumes to the same session_id.
self._session_locks: WeakValueDictionary[str, asyncio.Lock] = (
self._session_locks: WeakValueDictionary[str, anyio.Lock] = (
WeakValueDictionary()
)
async def _lock_for(self, session_id: str) -> asyncio.Lock:
async def _lock_for(self, session_id: str) -> anyio.Lock:
lock = self._session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
lock = anyio.Lock()
self._session_locks[session_id] = lock
return lock
@@ -306,19 +329,23 @@ class CodexExecRunner:
else:
args.append("-")
cancelled_exc_type = anyio.get_cancelled_exc_class()
cancelled_exc: BaseException | None = None
async with manage_subprocess(
*args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
proc_stdin = cast(asyncio.StreamWriter, proc.stdin)
proc_stdout = cast(asyncio.StreamReader, proc.stdout)
proc_stderr = cast(asyncio.StreamReader, proc.stderr)
if proc.stdin is None or proc.stdout is None or proc.stderr is None:
raise RuntimeError("codex exec failed to open subprocess pipes")
proc_stdin = proc.stdin
proc_stdout = proc.stdout
proc_stderr = proc.stderr
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
stderr_tail: deque[str] = deque(maxlen=200)
stderr_task = asyncio.create_task(_drain_stderr(proc_stderr, stderr_tail))
rc: int | None = None
found_session: str | None = session_id
last_agent_text: str | None = None
@@ -326,16 +353,16 @@ class CodexExecRunner:
cli_last_item: int | None = None
cancelled = False
rc: int | None = None
async with anyio.create_task_group() as tg:
tg.start_soon(_drain_stderr, proc_stderr, stderr_tail)
try:
proc_stdin.write(prompt.encode())
await proc_stdin.drain()
proc_stdin.close()
await proc_stdin.send(prompt.encode())
await proc_stdin.aclose()
async for raw_line in proc_stdout:
raw = raw_line.decode(errors="replace")
logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
async for raw_line in _iter_text_lines(proc_stdout):
raw = raw_line.rstrip("\n")
logger.debug("[codex][jsonl] %s", raw)
line = raw.strip()
if not line:
continue
@@ -367,21 +394,16 @@ class CodexExecRunner:
):
last_agent_text = item["text"]
saw_agent_message = True
except asyncio.CancelledError:
except cancelled_exc_type as exc:
cancelled = True
cancelled_exc = exc
tg.cancel_scope.cancel()
finally:
if cancelled:
if not stderr_task.done():
stderr_task.cancel()
task = cast(asyncio.Task, asyncio.current_task())
while task.cancelling():
task.uncancel()
if not cancelled:
rc = await proc.wait()
await asyncio.gather(stderr_task, return_exceptions=True)
if cancelled:
raise asyncio.CancelledError
raise cancelled_exc # type: ignore[misc]
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
if rc != 0:
@@ -406,12 +428,32 @@ class CodexExecRunner:
session_id: str | None,
on_event: EventCallback | None = None,
) -> tuple[str, str, bool]:
if not session_id:
return await self.run(prompt, session_id=None, on_event=on_event)
if session_id:
lock = await self._lock_for(session_id)
async with lock:
return await self.run(prompt, session_id=session_id, on_event=on_event)
session_lock: anyio.Lock | None = None
async def on_event_with_lock(evt: dict[str, Any]) -> None:
nonlocal session_lock
if session_lock is None and evt.get("type") == "thread.started":
thread_id = evt.get("thread_id")
if isinstance(thread_id, str) and thread_id:
session_lock = await self._lock_for(thread_id)
await session_lock.acquire()
if on_event is None:
return
res = on_event(evt)
if inspect.isawaitable(res):
await res
try:
return await self.run(prompt, session_id=None, on_event=on_event_with_lock)
finally:
if session_lock is not None:
session_lock.release()
@dataclass(frozen=True)
class BridgeConfig:
@@ -423,6 +465,12 @@ class BridgeConfig:
max_concurrency: int
@dataclass
class RunningTask:
scope: anyio.CancelScope
session_id: str | None = None
def _parse_bridge_config(
*,
final_notify: bool,
@@ -508,9 +556,9 @@ async def handle_message(
user_msg_id: int,
text: str,
resume_session: str | None,
running_tasks: dict[str, asyncio.Task[Any]] | None = None,
running_tasks: dict[int, RunningTask] | None = None,
clock: Callable[[], float] = time.monotonic,
sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
sleep: Callable[[float], Awaitable[None]] = anyio.sleep,
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
) -> None:
logger.debug(
@@ -565,31 +613,54 @@ async def handle_message(
last_rendered=last_rendered,
)
exec_task: asyncio.Task[tuple[str, str, bool]] | None = None
exec_scope = anyio.CancelScope()
cancelled = False
error: Exception | None = None
session_id: str | None = None
answer: str | None = None
saw_agent_message: bool | None = None
running_task: RunningTask | None = None
if running_tasks is not None and progress_id is not None:
running_task = RunningTask(scope=exec_scope)
running_tasks[progress_id] = running_task
if resume_session is not None:
running_task.session_id = resume_session
async def on_event(evt: dict[str, Any]) -> None:
if (
evt["type"] == "thread.started"
and running_tasks is not None
and exec_task is not None
running_task is not None
and running_task.session_id is None
and evt.get("type") == "thread.started"
):
running_tasks[evt["thread_id"]] = exec_task
thread_id = evt.get("thread_id")
if isinstance(thread_id, str) and thread_id:
running_task.session_id = thread_id
await edits.on_event(evt)
exec_task = asyncio.create_task(
cfg.runner.run_serialized(text, resume_session, on_event=on_event)
)
async with anyio.create_task_group() as tg:
if progress_id is not None:
tg.start_soon(edits.run)
cancelled = False
try:
session_id, answer, saw_agent_message = await exec_task
except asyncio.CancelledError:
with exec_scope:
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
text, resume_session, on_event=on_event
)
except Exception as e:
error = e
finally:
if running_task is not None:
if running_tasks is not None and progress_id is not None:
running_tasks.pop(progress_id, None)
if exec_scope.cancelled_caught and not cancelled and error is None:
cancelled = True
session_id = progress_renderer.resume_session or resume_session
except Exception as e:
await edits.shutdown()
if not cancelled and error is None:
await anyio.sleep(0)
tg.cancel_scope.cancel()
err = _clamp_tg_text(f"Error:\n{e}")
if error is not None:
err = _clamp_tg_text(f"Error:\n{error}")
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
await _send_or_edit_markdown(
cfg.bot,
@@ -601,16 +672,11 @@ async def handle_message(
limit=TELEGRAM_MARKDOWN_LIMIT,
)
return
finally:
if running_tasks is not None:
for sid, task in list(running_tasks.items()):
if task is exec_task:
running_tasks.pop(sid, None)
await edits.shutdown()
elapsed = clock() - started_at
if cancelled:
if session_id is None:
session_id = progress_renderer.resume_session or resume_session
logger.info(
"[handle] cancelled session_id=%s elapsed=%.1fs", session_id, elapsed
)
@@ -627,6 +693,9 @@ async def handle_message(
)
return
if session_id is None or answer is None or saw_agent_message is None:
raise RuntimeError("codex exec finished without a result")
status = "done" if saw_agent_message else "error"
progress_renderer.resume_session = session_id
final_md = progress_renderer.render_final(elapsed, answer, status=status)
@@ -679,7 +748,7 @@ async def poll_updates(cfg: BridgeConfig):
)
if updates is None:
logger.info("[loop] getUpdates failed")
await asyncio.sleep(2)
await anyio.sleep(2)
continue
logger.debug("[loop] updates: %s", updates)
@@ -696,7 +765,7 @@ async def poll_updates(cfg: BridgeConfig):
async def _handle_cancel(
cfg: BridgeConfig,
msg: dict[str, Any],
running_tasks: dict[str, asyncio.Task[Any]],
running_tasks: dict[int, RunningTask],
) -> None:
chat_id = msg["chat"]["id"]
user_msg_id = msg["message_id"]
@@ -710,8 +779,8 @@ async def _handle_cancel(
)
return
session_id = extract_session_id(reply.get("text"))
if not session_id:
progress_id = reply.get("message_id")
if progress_id is None:
await cfg.bot.send_message(
chat_id=chat_id,
text="nothing is currently running for that message.",
@@ -719,8 +788,8 @@ async def _handle_cancel(
)
return
task = running_tasks.get(session_id)
if not task or task.done():
running_task = running_tasks.get(int(progress_id))
if running_task is None:
await cfg.bot.send_message(
chat_id=chat_id,
text="nothing is currently running for that message.",
@@ -728,20 +797,20 @@ async def _handle_cancel(
)
return
logger.info("[cancel] cancelling session_id=%s", session_id)
task.cancel()
logger.info("[cancel] cancelling progress_message_id=%s", progress_id)
running_task.scope.cancel()
async def _run_main_loop(cfg: BridgeConfig) -> None:
worker_count = max(1, min(cfg.max_concurrency, 16))
queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue(
maxsize=worker_count * 2
send_stream, receive_stream = anyio.create_memory_object_stream(
max_buffer_size=worker_count * 2
)
running_tasks: dict[str, asyncio.Task[Any]] = {}
running_tasks: dict[int, RunningTask] = {}
async def worker() -> None:
while True:
chat_id, user_msg_id, text, resume_session = await queue.get()
chat_id, user_msg_id, text, resume_session = await receive_stream.receive()
try:
await handle_message(
cfg,
@@ -753,26 +822,28 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
)
except Exception:
logger.exception("[handle] worker failed")
finally:
queue.task_done()
try:
async with asyncio.TaskGroup() as tg:
async with anyio.create_task_group() as tg:
for _ in range(worker_count):
tg.create_task(worker())
tg.start_soon(worker)
async for msg in poll_updates(cfg):
text = msg["text"]
user_msg_id = msg["message_id"]
if text == "/cancel":
tg.create_task(_handle_cancel(cfg, msg, running_tasks))
tg.start_soon(_handle_cancel, cfg, msg, running_tasks)
continue
r = msg.get("reply_to_message") or {}
resume_session = resolve_resume_session(text, r.get("text"))
await queue.put((msg["chat"]["id"], user_msg_id, text, resume_session))
await send_stream.send(
(msg["chat"]["id"], user_msg_id, text, resume_session)
)
finally:
await send_stream.aclose()
await receive_stream.aclose()
await cfg.bot.close()
@@ -813,7 +884,7 @@ def run(
except ConfigError as e:
typer.echo(str(e), err=True)
raise typer.Exit(code=1)
asyncio.run(_run_main_loop(cfg))
anyio.run(_run_main_loop, cfg)
def main() -> None:
+7
View File
@@ -1,4 +1,11 @@
import sys
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
@pytest.fixture
def anyio_backend() -> str:
return "asyncio"
+154 -113
View File
@@ -1,5 +1,4 @@
import asyncio
import anyio
import pytest
from takopi.exec_bridge import (
@@ -187,7 +186,7 @@ class _FakeClock:
def __init__(self, start: float = 0.0) -> None:
self._now = start
self._sleep_until: float | None = None
self._sleep_event: asyncio.Event | None = None
self._sleep_event: anyio.Event | None = None
self.sleep_calls = 0
def __call__(self) -> float:
@@ -205,10 +204,10 @@ class _FakeClock:
async def sleep(self, delay: float) -> None:
self.sleep_calls += 1
if delay <= 0:
await asyncio.sleep(0)
await anyio.sleep(0)
return
self._sleep_until = self._now + delay
self._sleep_event = asyncio.Event()
self._sleep_event = anyio.Event()
await self._sleep_event.wait()
@@ -222,7 +221,7 @@ class _FakeRunnerWithEvents:
answer: str = "ok",
session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2",
advance_after: float | None = None,
hold: asyncio.Event | None = None,
hold: anyio.Event | None = None,
) -> None:
self._events = events
self._times = times
@@ -238,16 +237,17 @@ class _FakeRunnerWithEvents:
for when, event in zip(self._times, self._events, strict=False):
self._clock.set(when)
await on_event(event)
await asyncio.sleep(0)
await anyio.sleep(0)
if self._advance_after is not None:
self._clock.set(self._advance_after)
await asyncio.sleep(0)
await anyio.sleep(0)
if self._hold is not None:
await self._hold.wait()
return (self._session_id, self._answer, True)
def test_final_notify_sends_loud_final_message() -> None:
@pytest.mark.anyio
async def test_final_notify_sends_loud_final_message() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
@@ -261,22 +261,21 @@ def test_final_notify_sends_loud_final_message() -> None:
max_concurrency=1,
)
asyncio.run(
handle_message(
await handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
)
)
assert len(bot.send_calls) == 2
assert bot.send_calls[0]["disable_notification"] is True
assert bot.send_calls[1]["disable_notification"] is False
def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
@pytest.mark.anyio
async def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
@@ -290,22 +289,21 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
max_concurrency=1,
)
asyncio.run(
handle_message(
await handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
)
)
assert len(bot.send_calls) == 2
assert bot.send_calls[0]["disable_notification"] is True
assert bot.send_calls[1]["disable_notification"] is False
def test_progress_edits_are_rate_limited() -> None:
@pytest.mark.anyio
async def test_progress_edits_are_rate_limited() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
@@ -345,8 +343,7 @@ def test_progress_edits_are_rate_limited() -> None:
max_concurrency=1,
)
asyncio.run(
handle_message(
await handle_message(
cfg,
chat_id=123,
user_msg_id=10,
@@ -356,18 +353,18 @@ def test_progress_edits_are_rate_limited() -> None:
sleep=clock.sleep,
progress_edit_every=1.0,
)
)
assert len(bot.edit_calls) == 1
assert "echo 2" in bot.edit_calls[0]["text"]
def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
@pytest.mark.anyio
async def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
clock = _FakeClock()
hold = asyncio.Event()
hold = anyio.Event()
events = [
{
"type": "item.started",
@@ -404,9 +401,8 @@ def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
max_concurrency=1,
)
async def run_test() -> None:
task = asyncio.create_task(
handle_message(
async def run_handle_message() -> None:
await handle_message(
cfg,
chat_id=123,
user_msg_id=10,
@@ -416,12 +412,14 @@ def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
sleep=clock.sleep,
progress_edit_every=1.0,
)
)
async with anyio.create_task_group() as tg:
tg.start_soon(run_handle_message)
for _ in range(100):
if clock._sleep_until is not None:
break
await asyncio.sleep(0)
await anyio.sleep(0)
assert clock._sleep_until == pytest.approx(1.0)
@@ -430,23 +428,21 @@ def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
for _ in range(100):
if bot.edit_calls:
break
await asyncio.sleep(0)
await anyio.sleep(0)
assert len(bot.edit_calls) == 1
for _ in range(5):
await asyncio.sleep(0)
await anyio.sleep(0)
assert clock.sleep_calls == 1
assert clock._sleep_until is None
hold.set()
await task
asyncio.run(run_test())
def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
@pytest.mark.anyio
async def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
@@ -489,8 +485,7 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
max_concurrency=1,
)
asyncio.run(
handle_message(
await handle_message(
cfg,
chat_id=123,
user_msg_id=42,
@@ -500,7 +495,6 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
sleep=clock.sleep,
progress_edit_every=1.0,
)
)
assert bot.send_calls[0]["reply_to_message_id"] == 42
assert "working" in bot.send_calls[0]["text"]
@@ -510,7 +504,8 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
assert len(bot.delete_calls) == 1
def test_handle_cancel_without_reply_prompts_user() -> None:
@pytest.mark.anyio
async def test_handle_cancel_without_reply_prompts_user() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
@@ -526,13 +521,14 @@ def test_handle_cancel_without_reply_prompts_user() -> None:
msg = {"chat": {"id": 123}, "message_id": 10}
running_tasks: dict = {}
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
await _handle_cancel(cfg, msg, running_tasks)
assert len(bot.send_calls) == 1
assert "reply to the progress message" in bot.send_calls[0]["text"]
def test_handle_cancel_with_no_session_id_says_nothing_running() -> None:
@pytest.mark.anyio
async def test_handle_cancel_with_no_progress_message_says_nothing_running() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
@@ -548,79 +544,122 @@ def test_handle_cancel_with_no_session_id_says_nothing_running() -> None:
msg = {
"chat": {"id": 123},
"message_id": 10,
"reply_to_message": {"text": "no uuid here"},
"reply_to_message": {"text": "no message id"},
}
running_tasks: dict = {}
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
assert len(bot.send_calls) == 1
assert "nothing is currently running" in bot.send_calls[0]["text"]
def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
runner = _FakeRunner(answer="ok")
cfg = BridgeConfig(
bot=bot, # type: ignore[arg-type]
runner=runner, # type: ignore[arg-type]
chat_id=123,
final_notify=True,
startup_msg="",
max_concurrency=1,
)
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
msg = {
"chat": {"id": 123},
"message_id": 10,
"reply_to_message": {"text": f"resume: `{session_id}`"},
}
running_tasks: dict = {} # Session not in running_tasks
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
assert len(bot.send_calls) == 1
assert "nothing is currently running" in bot.send_calls[0]["text"]
def test_handle_cancel_cancels_running_task() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
runner = _FakeRunner(answer="ok")
cfg = BridgeConfig(
bot=bot, # type: ignore[arg-type]
runner=runner, # type: ignore[arg-type]
chat_id=123,
final_notify=True,
startup_msg="",
max_concurrency=1,
)
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
msg = {
"chat": {"id": 123},
"message_id": 10,
"reply_to_message": {"text": f"resume: `{session_id}`"},
}
async def run_test():
task = asyncio.create_task(asyncio.sleep(10))
running_tasks = {session_id: task}
await _handle_cancel(cfg, msg, running_tasks)
assert len(bot.send_calls) == 1
assert "nothing is currently running" in bot.send_calls[0]["text"]
@pytest.mark.anyio
async def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
runner = _FakeRunner(answer="ok")
cfg = BridgeConfig(
bot=bot, # type: ignore[arg-type]
runner=runner, # type: ignore[arg-type]
chat_id=123,
final_notify=True,
startup_msg="",
max_concurrency=1,
)
progress_id = 99
msg = {
"chat": {"id": 123},
"message_id": 10,
"reply_to_message": {"message_id": progress_id},
}
running_tasks: dict = {} # Progress message not in running_tasks
await _handle_cancel(cfg, msg, running_tasks)
assert len(bot.send_calls) == 1
assert "nothing is currently running" in bot.send_calls[0]["text"]
@pytest.mark.anyio
async def test_handle_cancel_cancels_running_task() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
runner = _FakeRunner(answer="ok")
cfg = BridgeConfig(
bot=bot, # type: ignore[arg-type]
runner=runner, # type: ignore[arg-type]
chat_id=123,
final_notify=True,
startup_msg="",
max_concurrency=1,
)
progress_id = 42
msg = {
"chat": {"id": 123},
"message_id": 10,
"reply_to_message": {"message_id": progress_id},
}
from takopi.exec_bridge import RunningTask
cancelled_event = anyio.Event()
cancel_scope = anyio.CancelScope()
running_task = RunningTask(scope=cancel_scope)
async def sleeper() -> None:
with cancel_scope:
try:
await task
except asyncio.CancelledError:
return True
return False
await anyio.sleep(10)
except anyio.get_cancelled_exc_class():
cancelled_event.set()
return
cancelled = asyncio.run(run_test())
async with anyio.create_task_group() as tg:
tg.start_soon(sleeper)
running_tasks = {progress_id: running_task}
await _handle_cancel(cfg, msg, running_tasks)
await cancelled_event.wait()
assert cancelled is True
assert len(bot.send_calls) == 0 # No error message sent
@pytest.mark.anyio
async def test_handle_cancel_only_cancels_matching_progress_message() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
runner = _FakeRunner(answer="ok")
cfg = BridgeConfig(
bot=bot, # type: ignore[arg-type]
runner=runner, # type: ignore[arg-type]
chat_id=123,
final_notify=True,
startup_msg="",
max_concurrency=1,
)
from takopi.exec_bridge import RunningTask
scope_first = anyio.CancelScope()
scope_second = anyio.CancelScope()
task_first = RunningTask(scope=scope_first)
task_second = RunningTask(scope=scope_second)
msg = {
"chat": {"id": 123},
"message_id": 10,
"reply_to_message": {"message_id": 1},
}
running_tasks = {1: task_first, 2: task_second}
await _handle_cancel(cfg, msg, running_tasks)
assert scope_first.cancel_called is True
assert scope_second.cancel_called is False
assert len(bot.send_calls) == 0
class _FakeRunnerCancellable:
def __init__(self, session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2"):
self._session_id = session_id
@@ -629,11 +668,12 @@ class _FakeRunnerCancellable:
on_event = kwargs.get("on_event")
if on_event:
await on_event({"type": "thread.started", "thread_id": self._session_id})
await asyncio.sleep(10) # Will be cancelled
await anyio.sleep(10) # Will be cancelled
return (self._session_id, "ok", True)
def test_handle_message_cancelled_renders_cancelled_state() -> None:
@pytest.mark.anyio
async def test_handle_message_cancelled_renders_cancelled_state() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
@@ -649,9 +689,8 @@ def test_handle_message_cancelled_renders_cancelled_state() -> None:
)
running_tasks: dict = {}
async def run_test():
task = asyncio.create_task(
handle_message(
async def run_handle_message() -> None:
await handle_message(
cfg,
chat_id=123,
user_msg_id=10,
@@ -659,13 +698,15 @@ def test_handle_message_cancelled_renders_cancelled_state() -> None:
resume_session=None,
running_tasks=running_tasks,
)
)
await asyncio.sleep(0.01) # Let task start and register
assert session_id in running_tasks
running_tasks[session_id].cancel()
await task
asyncio.run(run_test())
async with anyio.create_task_group() as tg:
tg.start_soon(run_handle_message)
for _ in range(100):
if running_tasks:
break
await anyio.sleep(0)
assert running_tasks
running_tasks[next(iter(running_tasks))].scope.cancel()
assert len(bot.send_calls) == 1 # Progress message
assert len(bot.edit_calls) >= 1
+68 -11
View File
@@ -1,11 +1,13 @@
import asyncio
import anyio
import pytest
from takopi.exec_bridge import CodexExecRunner
from takopi.exec_bridge import CodexExecRunner, EventCallback
def test_run_serialized_serializes_same_session() -> None:
@pytest.mark.anyio
async def test_run_serialized_serializes_same_session() -> None:
runner = CodexExecRunner(codex_cmd="codex", extra_args=[])
gate = asyncio.Event()
gate = anyio.Event()
in_flight = 0
max_in_flight = 0
@@ -19,13 +21,68 @@ def test_run_serialized_serializes_same_session() -> None:
runner.run = run_stub # type: ignore[assignment]
async def run_test() -> None:
t1 = asyncio.create_task(runner.run_serialized("a", "sid"))
t2 = asyncio.create_task(runner.run_serialized("b", "sid"))
await asyncio.sleep(0)
async with anyio.create_task_group() as tg:
tg.start_soon(runner.run_serialized, "a", "sid")
tg.start_soon(runner.run_serialized, "b", "sid")
await anyio.sleep(0)
gate.set()
await asyncio.gather(t1, t2)
asyncio.run(run_test())
assert max_in_flight == 1
@pytest.mark.anyio
async def test_run_serialized_allows_parallel_new_sessions() -> None:
runner = CodexExecRunner(codex_cmd="codex", extra_args=[])
gate = anyio.Event()
in_flight = 0
max_in_flight = 0
async def run_stub(*_args, **_kwargs):
nonlocal in_flight, max_in_flight
in_flight += 1
max_in_flight = max(max_in_flight, in_flight)
await gate.wait()
in_flight -= 1
return ("sid", "ok", True)
runner.run = run_stub # type: ignore[assignment]
async with anyio.create_task_group() as tg:
tg.start_soon(runner.run_serialized, "a", None)
tg.start_soon(runner.run_serialized, "b", None)
with anyio.move_on_after(1):
while max_in_flight < 2:
await anyio.sleep(0)
gate.set()
assert max_in_flight == 2
@pytest.mark.anyio
async def test_new_session_holds_lock_for_resumes() -> None:
runner = CodexExecRunner(codex_cmd="codex", extra_args=[])
finish = anyio.Event()
resume_started = anyio.Event()
async def run_stub(
_prompt: str,
session_id: str | None,
on_event: EventCallback | None = None,
) -> tuple[str, str, bool]:
if session_id is None:
if on_event:
await on_event({"type": "thread.started", "thread_id": "sid"})
await finish.wait()
return ("sid", "ok", True)
resume_started.set()
return ("sid", "ok", True)
runner.run = run_stub # type: ignore[assignment]
async with anyio.create_task_group() as tg:
tg.start_soon(runner.run_serialized, "first", None)
await anyio.sleep(0)
tg.start_soon(runner.run_serialized, "resume", "sid")
await anyio.sleep(0)
assert not resume_started.is_set()
finish.set()
+7 -17
View File
@@ -1,29 +1,19 @@
import asyncio
import sys
import pytest
from takopi import exec_bridge
def test_manage_subprocess_kills_when_terminate_times_out(monkeypatch) -> None:
async def fake_wait_for(awaitable, *args, **kwargs):
if hasattr(awaitable, "close"):
awaitable.close()
elif hasattr(awaitable, "cancel"):
awaitable.cancel()
raise asyncio.TimeoutError
monkeypatch.setattr(exec_bridge.asyncio, "wait_for", fake_wait_for)
async def run() -> int | None:
@pytest.mark.anyio
async def test_manage_subprocess_kills_when_terminate_times_out() -> None:
async with exec_bridge.manage_subprocess(
sys.executable,
"-c",
"import signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(10)",
terminate_timeout=0.01,
) as proc:
assert proc.returncode is None
return proc.returncode
rc = asyncio.run(run())
assert rc is not None
assert rc != 0
assert proc.returncode is not None
assert proc.returncode != 0
+8 -11
View File
@@ -1,4 +1,3 @@
import asyncio
import logging
import httpx
@@ -8,7 +7,8 @@ from takopi.logging import RedactTokenFilter
from takopi.telegram import TelegramClient
def test_telegram_429_no_retry() -> None:
@pytest.mark.anyio
async def test_telegram_429_no_retry() -> None:
calls: list[int] = []
def handler(request: httpx.Request) -> httpx.Response:
@@ -25,21 +25,21 @@ def test_telegram_429_no_retry() -> None:
transport = httpx.MockTransport(handler)
async def run() -> dict | None:
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", client=client)
return await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
result = await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
finally:
await client.aclose()
result = asyncio.run(run())
assert result is None
assert len(calls) == 1
def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> None:
@pytest.mark.anyio
async def test_no_token_in_logs_on_http_error(
caplog: pytest.LogCaptureFixture,
) -> None:
token = "123:abcDEF_ghij"
redactor = RedactTokenFilter()
root_logger = logging.getLogger()
@@ -50,7 +50,7 @@ def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> Non
transport = httpx.MockTransport(handler)
async def run() -> None:
caplog.set_level(logging.ERROR)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient(token, client=client)
@@ -58,9 +58,6 @@ def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> Non
finally:
await client.aclose()
caplog.set_level(logging.ERROR)
asyncio.run(run())
root_logger.removeFilter(redactor)
assert token not in caplog.text
Generated
+17
View File
@@ -331,6 +331,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
]
[[package]]
name = "pytest-anyio"
version = "0.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/00/44/a02e5877a671b0940f21a7a0d9704c22097b123ed5cdbcca9cab39f17acc/pytest-anyio-0.0.0.tar.gz", hash = "sha256:b41234e9e9ad7ea1dbfefcc1d6891b23d5ef7c9f07ccf804c13a9cc338571fd3", size = 1560, upload-time = "2021-06-29T22:57:30.846Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/25/bd6493ae85d0a281b6a0f248d0fdb1d9aa2b31f18bcd4a8800cf397d8209/pytest_anyio-0.0.0-py2.py3-none-any.whl", hash = "sha256:dc8b5c4741cb16ff90be37fddd585ca943ed12bbeb563de7ace6cd94441d8746", size = 1999, upload-time = "2021-06-29T22:57:29.158Z" },
]
[[package]]
name = "pytest-cov"
version = "7.0.0"
@@ -420,6 +433,7 @@ name = "takopi"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "anyio" },
{ name = "httpx" },
{ name = "markdown-it-py" },
{ name = "rich" },
@@ -430,6 +444,7 @@ dependencies = [
[package.dev-dependencies]
dev = [
{ name = "pytest" },
{ name = "pytest-anyio" },
{ name = "pytest-cov" },
{ name = "ruff" },
{ name = "ty" },
@@ -437,6 +452,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "anyio", specifier = ">=4.12.0" },
{ name = "httpx", specifier = ">=0.28.1" },
{ name = "markdown-it-py" },
{ name = "rich", specifier = ">=14.2.0" },
@@ -447,6 +463,7 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
{ name = "pytest", specifier = ">=9.0.2" },
{ name = "pytest-anyio", specifier = ">=0.0.0" },
{ name = "pytest-cov", specifier = ">=7.0.0" },
{ name = "ruff", specifier = ">=0.14.10" },
{ name = "ty", specifier = ">=0.0.8" },