refactor: migrate exec bridge to anyio and harden cancellation (#6)
This commit is contained in:
+3
-3
@@ -44,8 +44,8 @@ The orchestrator module containing:
|
||||
|
||||
**Key patterns:**
|
||||
- Per-session locks prevent concurrent resumes to the same `session_id`
|
||||
- Worker pool with `asyncio.Queue` limits concurrency (default: 16 workers)
|
||||
- `asyncio.TaskGroup` manages worker tasks
|
||||
- Worker pool with an AnyIO memory stream limits concurrency (default: 16 workers)
|
||||
- AnyIO task groups manage worker tasks
|
||||
- Progress edits are throttled to ~2s intervals
|
||||
- Subprocess stderr is drained to a bounded deque for error reporting
|
||||
- `poll_updates()` uses Telegram `getUpdates` long-polling with a single server-side updates
|
||||
@@ -154,5 +154,5 @@ Same as above, but:
|
||||
|----------|----------|
|
||||
| `codex exec` fails (rc≠0) | Shows stderr tail in error message |
|
||||
| Telegram API error | Logged, edit skipped (progress continues) |
|
||||
| Cancellation | Subprocess terminated, CancelledError re-raised |
|
||||
| Cancellation | Cancel scope triggers terminate; cancellation is detected via `cancelled_caught` |
|
||||
| No agent_message | Final shows "error" status |
|
||||
|
||||
@@ -7,6 +7,7 @@ readme = "readme.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"anyio>=4.12.0",
|
||||
"httpx>=0.28.1",
|
||||
"markdown-it-py",
|
||||
"rich>=14.2.0",
|
||||
@@ -38,6 +39,7 @@ build-backend = "uv_build"
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=9.0.2",
|
||||
"pytest-anyio>=0.0.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
"ruff>=0.14.10",
|
||||
"ty>=0.0.8",
|
||||
|
||||
+212
-141
@@ -1,21 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import anyio
|
||||
import typer
|
||||
from anyio.abc import ByteReceiveStream, Process
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
|
||||
from . import __version__
|
||||
from .config import ConfigError, load_telegram_config
|
||||
@@ -61,32 +64,59 @@ def resolve_resume_session(text: str | None, reply_text: str | None) -> str | No
|
||||
return extract_session_id(text) or extract_session_id(reply_text)
|
||||
|
||||
|
||||
async def _drain_stderr(stderr: asyncio.StreamReader, tail: deque[str]) -> None:
|
||||
try:
|
||||
async def _iter_text_lines(stream: ByteReceiveStream) -> AsyncIterator[str]:
|
||||
text_stream = TextReceiveStream(stream, errors="replace")
|
||||
buffer = ""
|
||||
while True:
|
||||
try:
|
||||
chunk = await text_stream.receive()
|
||||
except anyio.EndOfStream:
|
||||
if buffer:
|
||||
yield buffer
|
||||
return
|
||||
buffer += chunk
|
||||
while True:
|
||||
line = await stderr.readline()
|
||||
if not line:
|
||||
return
|
||||
decoded = line.decode(errors="replace")
|
||||
logger.info("[codex][stderr] %s", decoded.rstrip())
|
||||
tail.append(decoded)
|
||||
split_at = buffer.find("\n")
|
||||
if split_at < 0:
|
||||
break
|
||||
line = buffer[: split_at + 1]
|
||||
buffer = buffer[split_at + 1 :]
|
||||
yield line
|
||||
|
||||
|
||||
async def _drain_stderr(stderr: ByteReceiveStream, tail: deque[str]) -> None:
|
||||
try:
|
||||
async for line in _iter_text_lines(stderr):
|
||||
logger.info("[codex][stderr] %s", line.rstrip())
|
||||
tail.append(line)
|
||||
except Exception as e:
|
||||
logger.debug("[codex][stderr] drain error: %s", e)
|
||||
|
||||
|
||||
async def _wait_for_process(proc: Process, timeout: float) -> bool:
|
||||
with anyio.move_on_after(timeout) as scope:
|
||||
await proc.wait()
|
||||
return scope.cancel_called
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def manage_subprocess(*args, **kwargs):
|
||||
proc = await asyncio.create_subprocess_exec(*args, **kwargs)
|
||||
async def manage_subprocess(*args, terminate_timeout: float = 2.0, **kwargs):
|
||||
proc = await anyio.open_process(args, **kwargs)
|
||||
try:
|
||||
yield proc
|
||||
finally:
|
||||
if proc.returncode is None:
|
||||
proc.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(proc.wait(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
with anyio.CancelScope(shield=True):
|
||||
try:
|
||||
proc.terminate()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
timed_out = await _wait_for_process(proc, terminate_timeout)
|
||||
if timed_out:
|
||||
logger.debug(
|
||||
"[codex] terminate timed out pid=%s; leaving process to exit",
|
||||
proc.pid,
|
||||
)
|
||||
|
||||
|
||||
TELEGRAM_MARKDOWN_LIMIT = 3500
|
||||
@@ -212,17 +242,17 @@ class ProgressEdits:
|
||||
self.last_rendered = last_rendered
|
||||
self._event_seq = 0
|
||||
self._published_seq = 0
|
||||
self.wakeup = asyncio.Event()
|
||||
self.task: asyncio.Task[None] | None = (
|
||||
asyncio.create_task(self.run()) if self.progress_id is not None else None
|
||||
)
|
||||
self.wakeup = anyio.Event()
|
||||
|
||||
async def _wait_for_wakeup(self) -> None:
|
||||
await self.wakeup.wait()
|
||||
self.wakeup = anyio.Event()
|
||||
|
||||
async def run(self) -> None:
|
||||
if self.progress_id is None:
|
||||
return
|
||||
while True:
|
||||
await self.wakeup.wait()
|
||||
self.wakeup.clear()
|
||||
await self._wait_for_wakeup()
|
||||
while self._published_seq < self._event_seq:
|
||||
await self.sleep(
|
||||
max(
|
||||
@@ -250,7 +280,6 @@ class ProgressEdits:
|
||||
self.last_rendered = rendered
|
||||
|
||||
self._published_seq = seq_at_render
|
||||
self.wakeup.clear()
|
||||
|
||||
async def on_event(self, evt: dict[str, Any]) -> None:
|
||||
if not self.renderer.note_event(evt):
|
||||
@@ -260,12 +289,6 @@ class ProgressEdits:
|
||||
self._event_seq += 1
|
||||
self.wakeup.set()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.task is None:
|
||||
return
|
||||
self.task.cancel()
|
||||
await asyncio.gather(self.task, return_exceptions=True)
|
||||
|
||||
|
||||
class CodexExecRunner:
|
||||
def __init__(
|
||||
@@ -277,14 +300,14 @@ class CodexExecRunner:
|
||||
self.extra_args = extra_args
|
||||
|
||||
# Per-session locks to prevent concurrent resumes to the same session_id.
|
||||
self._session_locks: WeakValueDictionary[str, asyncio.Lock] = (
|
||||
self._session_locks: WeakValueDictionary[str, anyio.Lock] = (
|
||||
WeakValueDictionary()
|
||||
)
|
||||
|
||||
async def _lock_for(self, session_id: str) -> asyncio.Lock:
|
||||
async def _lock_for(self, session_id: str) -> anyio.Lock:
|
||||
lock = self._session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
lock = anyio.Lock()
|
||||
self._session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
@@ -306,19 +329,23 @@ class CodexExecRunner:
|
||||
else:
|
||||
args.append("-")
|
||||
|
||||
cancelled_exc_type = anyio.get_cancelled_exc_class()
|
||||
cancelled_exc: BaseException | None = None
|
||||
async with manage_subprocess(
|
||||
*args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as proc:
|
||||
proc_stdin = cast(asyncio.StreamWriter, proc.stdin)
|
||||
proc_stdout = cast(asyncio.StreamReader, proc.stdout)
|
||||
proc_stderr = cast(asyncio.StreamReader, proc.stderr)
|
||||
if proc.stdin is None or proc.stdout is None or proc.stderr is None:
|
||||
raise RuntimeError("codex exec failed to open subprocess pipes")
|
||||
proc_stdin = proc.stdin
|
||||
proc_stdout = proc.stdout
|
||||
proc_stderr = proc.stderr
|
||||
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
||||
|
||||
stderr_tail: deque[str] = deque(maxlen=200)
|
||||
stderr_task = asyncio.create_task(_drain_stderr(proc_stderr, stderr_tail))
|
||||
rc: int | None = None
|
||||
|
||||
found_session: str | None = session_id
|
||||
last_agent_text: str | None = None
|
||||
@@ -326,62 +353,57 @@ class CodexExecRunner:
|
||||
cli_last_item: int | None = None
|
||||
|
||||
cancelled = False
|
||||
rc: int | None = None
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(_drain_stderr, proc_stderr, stderr_tail)
|
||||
|
||||
try:
|
||||
proc_stdin.write(prompt.encode())
|
||||
await proc_stdin.drain()
|
||||
proc_stdin.close()
|
||||
try:
|
||||
await proc_stdin.send(prompt.encode())
|
||||
await proc_stdin.aclose()
|
||||
|
||||
async for raw_line in proc_stdout:
|
||||
raw = raw_line.decode(errors="replace")
|
||||
logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
evt = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("[codex][jsonl] invalid line: %r", line)
|
||||
continue
|
||||
|
||||
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
|
||||
for out in out_lines:
|
||||
logger.info("[codex] %s", out)
|
||||
|
||||
if on_event is not None:
|
||||
async for raw_line in _iter_text_lines(proc_stdout):
|
||||
raw = raw_line.rstrip("\n")
|
||||
logger.debug("[codex][jsonl] %s", raw)
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
res = on_event(evt)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
except Exception as e:
|
||||
logger.info("[codex][on_event] callback error: %s", e)
|
||||
evt = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("[codex][jsonl] invalid line: %r", line)
|
||||
continue
|
||||
|
||||
if evt["type"] == "thread.started":
|
||||
found_session = evt.get("thread_id") or found_session
|
||||
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
|
||||
for out in out_lines:
|
||||
logger.info("[codex] %s", out)
|
||||
|
||||
if evt["type"] == "item.completed":
|
||||
item = evt.get("item") or {}
|
||||
if item.get("type") == "agent_message" and isinstance(
|
||||
item.get("text"), str
|
||||
):
|
||||
last_agent_text = item["text"]
|
||||
saw_agent_message = True
|
||||
except asyncio.CancelledError:
|
||||
cancelled = True
|
||||
finally:
|
||||
if cancelled:
|
||||
if not stderr_task.done():
|
||||
stderr_task.cancel()
|
||||
task = cast(asyncio.Task, asyncio.current_task())
|
||||
while task.cancelling():
|
||||
task.uncancel()
|
||||
if not cancelled:
|
||||
rc = await proc.wait()
|
||||
await asyncio.gather(stderr_task, return_exceptions=True)
|
||||
if on_event is not None:
|
||||
try:
|
||||
res = on_event(evt)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
except Exception as e:
|
||||
logger.info("[codex][on_event] callback error: %s", e)
|
||||
|
||||
if evt["type"] == "thread.started":
|
||||
found_session = evt.get("thread_id") or found_session
|
||||
|
||||
if evt["type"] == "item.completed":
|
||||
item = evt.get("item") or {}
|
||||
if item.get("type") == "agent_message" and isinstance(
|
||||
item.get("text"), str
|
||||
):
|
||||
last_agent_text = item["text"]
|
||||
saw_agent_message = True
|
||||
except cancelled_exc_type as exc:
|
||||
cancelled = True
|
||||
cancelled_exc = exc
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
if not cancelled:
|
||||
rc = await proc.wait()
|
||||
|
||||
if cancelled:
|
||||
raise asyncio.CancelledError
|
||||
raise cancelled_exc # type: ignore[misc]
|
||||
|
||||
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
||||
if rc != 0:
|
||||
@@ -406,11 +428,31 @@ class CodexExecRunner:
|
||||
session_id: str | None,
|
||||
on_event: EventCallback | None = None,
|
||||
) -> tuple[str, str, bool]:
|
||||
if not session_id:
|
||||
return await self.run(prompt, session_id=None, on_event=on_event)
|
||||
lock = await self._lock_for(session_id)
|
||||
async with lock:
|
||||
return await self.run(prompt, session_id=session_id, on_event=on_event)
|
||||
if session_id:
|
||||
lock = await self._lock_for(session_id)
|
||||
async with lock:
|
||||
return await self.run(prompt, session_id=session_id, on_event=on_event)
|
||||
|
||||
session_lock: anyio.Lock | None = None
|
||||
|
||||
async def on_event_with_lock(evt: dict[str, Any]) -> None:
|
||||
nonlocal session_lock
|
||||
if session_lock is None and evt.get("type") == "thread.started":
|
||||
thread_id = evt.get("thread_id")
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
session_lock = await self._lock_for(thread_id)
|
||||
await session_lock.acquire()
|
||||
if on_event is None:
|
||||
return
|
||||
res = on_event(evt)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
try:
|
||||
return await self.run(prompt, session_id=None, on_event=on_event_with_lock)
|
||||
finally:
|
||||
if session_lock is not None:
|
||||
session_lock.release()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -423,6 +465,12 @@ class BridgeConfig:
|
||||
max_concurrency: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunningTask:
|
||||
scope: anyio.CancelScope
|
||||
session_id: str | None = None
|
||||
|
||||
|
||||
def _parse_bridge_config(
|
||||
*,
|
||||
final_notify: bool,
|
||||
@@ -508,9 +556,9 @@ async def handle_message(
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
resume_session: str | None,
|
||||
running_tasks: dict[str, asyncio.Task[Any]] | None = None,
|
||||
running_tasks: dict[int, RunningTask] | None = None,
|
||||
clock: Callable[[], float] = time.monotonic,
|
||||
sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
|
||||
sleep: Callable[[float], Awaitable[None]] = anyio.sleep,
|
||||
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
@@ -565,31 +613,54 @@ async def handle_message(
|
||||
last_rendered=last_rendered,
|
||||
)
|
||||
|
||||
exec_task: asyncio.Task[tuple[str, str, bool]] | None = None
|
||||
exec_scope = anyio.CancelScope()
|
||||
cancelled = False
|
||||
error: Exception | None = None
|
||||
session_id: str | None = None
|
||||
answer: str | None = None
|
||||
saw_agent_message: bool | None = None
|
||||
running_task: RunningTask | None = None
|
||||
if running_tasks is not None and progress_id is not None:
|
||||
running_task = RunningTask(scope=exec_scope)
|
||||
running_tasks[progress_id] = running_task
|
||||
if resume_session is not None:
|
||||
running_task.session_id = resume_session
|
||||
|
||||
async def on_event(evt: dict[str, Any]) -> None:
|
||||
if (
|
||||
evt["type"] == "thread.started"
|
||||
and running_tasks is not None
|
||||
and exec_task is not None
|
||||
running_task is not None
|
||||
and running_task.session_id is None
|
||||
and evt.get("type") == "thread.started"
|
||||
):
|
||||
running_tasks[evt["thread_id"]] = exec_task
|
||||
thread_id = evt.get("thread_id")
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
running_task.session_id = thread_id
|
||||
await edits.on_event(evt)
|
||||
|
||||
exec_task = asyncio.create_task(
|
||||
cfg.runner.run_serialized(text, resume_session, on_event=on_event)
|
||||
)
|
||||
async with anyio.create_task_group() as tg:
|
||||
if progress_id is not None:
|
||||
tg.start_soon(edits.run)
|
||||
|
||||
cancelled = False
|
||||
try:
|
||||
session_id, answer, saw_agent_message = await exec_task
|
||||
except asyncio.CancelledError:
|
||||
cancelled = True
|
||||
session_id = progress_renderer.resume_session or resume_session
|
||||
except Exception as e:
|
||||
await edits.shutdown()
|
||||
try:
|
||||
with exec_scope:
|
||||
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
|
||||
text, resume_session, on_event=on_event
|
||||
)
|
||||
except Exception as e:
|
||||
error = e
|
||||
finally:
|
||||
if running_task is not None:
|
||||
if running_tasks is not None and progress_id is not None:
|
||||
running_tasks.pop(progress_id, None)
|
||||
if exec_scope.cancelled_caught and not cancelled and error is None:
|
||||
cancelled = True
|
||||
session_id = progress_renderer.resume_session or resume_session
|
||||
if not cancelled and error is None:
|
||||
await anyio.sleep(0)
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
err = _clamp_tg_text(f"Error:\n{e}")
|
||||
if error is not None:
|
||||
err = _clamp_tg_text(f"Error:\n{error}")
|
||||
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
|
||||
await _send_or_edit_markdown(
|
||||
cfg.bot,
|
||||
@@ -601,16 +672,11 @@ async def handle_message(
|
||||
limit=TELEGRAM_MARKDOWN_LIMIT,
|
||||
)
|
||||
return
|
||||
finally:
|
||||
if running_tasks is not None:
|
||||
for sid, task in list(running_tasks.items()):
|
||||
if task is exec_task:
|
||||
running_tasks.pop(sid, None)
|
||||
|
||||
await edits.shutdown()
|
||||
|
||||
elapsed = clock() - started_at
|
||||
if cancelled:
|
||||
if session_id is None:
|
||||
session_id = progress_renderer.resume_session or resume_session
|
||||
logger.info(
|
||||
"[handle] cancelled session_id=%s elapsed=%.1fs", session_id, elapsed
|
||||
)
|
||||
@@ -627,6 +693,9 @@ async def handle_message(
|
||||
)
|
||||
return
|
||||
|
||||
if session_id is None or answer is None or saw_agent_message is None:
|
||||
raise RuntimeError("codex exec finished without a result")
|
||||
|
||||
status = "done" if saw_agent_message else "error"
|
||||
progress_renderer.resume_session = session_id
|
||||
final_md = progress_renderer.render_final(elapsed, answer, status=status)
|
||||
@@ -679,7 +748,7 @@ async def poll_updates(cfg: BridgeConfig):
|
||||
)
|
||||
if updates is None:
|
||||
logger.info("[loop] getUpdates failed")
|
||||
await asyncio.sleep(2)
|
||||
await anyio.sleep(2)
|
||||
continue
|
||||
logger.debug("[loop] updates: %s", updates)
|
||||
|
||||
@@ -696,7 +765,7 @@ async def poll_updates(cfg: BridgeConfig):
|
||||
async def _handle_cancel(
|
||||
cfg: BridgeConfig,
|
||||
msg: dict[str, Any],
|
||||
running_tasks: dict[str, asyncio.Task[Any]],
|
||||
running_tasks: dict[int, RunningTask],
|
||||
) -> None:
|
||||
chat_id = msg["chat"]["id"]
|
||||
user_msg_id = msg["message_id"]
|
||||
@@ -710,8 +779,8 @@ async def _handle_cancel(
|
||||
)
|
||||
return
|
||||
|
||||
session_id = extract_session_id(reply.get("text"))
|
||||
if not session_id:
|
||||
progress_id = reply.get("message_id")
|
||||
if progress_id is None:
|
||||
await cfg.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text="nothing is currently running for that message.",
|
||||
@@ -719,8 +788,8 @@ async def _handle_cancel(
|
||||
)
|
||||
return
|
||||
|
||||
task = running_tasks.get(session_id)
|
||||
if not task or task.done():
|
||||
running_task = running_tasks.get(int(progress_id))
|
||||
if running_task is None:
|
||||
await cfg.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text="nothing is currently running for that message.",
|
||||
@@ -728,20 +797,20 @@ async def _handle_cancel(
|
||||
)
|
||||
return
|
||||
|
||||
logger.info("[cancel] cancelling session_id=%s", session_id)
|
||||
task.cancel()
|
||||
logger.info("[cancel] cancelling progress_message_id=%s", progress_id)
|
||||
running_task.scope.cancel()
|
||||
|
||||
|
||||
async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
worker_count = max(1, min(cfg.max_concurrency, 16))
|
||||
queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue(
|
||||
maxsize=worker_count * 2
|
||||
send_stream, receive_stream = anyio.create_memory_object_stream(
|
||||
max_buffer_size=worker_count * 2
|
||||
)
|
||||
running_tasks: dict[str, asyncio.Task[Any]] = {}
|
||||
running_tasks: dict[int, RunningTask] = {}
|
||||
|
||||
async def worker() -> None:
|
||||
while True:
|
||||
chat_id, user_msg_id, text, resume_session = await queue.get()
|
||||
chat_id, user_msg_id, text, resume_session = await receive_stream.receive()
|
||||
try:
|
||||
await handle_message(
|
||||
cfg,
|
||||
@@ -753,26 +822,28 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("[handle] worker failed")
|
||||
finally:
|
||||
queue.task_done()
|
||||
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
async with anyio.create_task_group() as tg:
|
||||
for _ in range(worker_count):
|
||||
tg.create_task(worker())
|
||||
tg.start_soon(worker)
|
||||
async for msg in poll_updates(cfg):
|
||||
text = msg["text"]
|
||||
user_msg_id = msg["message_id"]
|
||||
|
||||
if text == "/cancel":
|
||||
tg.create_task(_handle_cancel(cfg, msg, running_tasks))
|
||||
tg.start_soon(_handle_cancel, cfg, msg, running_tasks)
|
||||
continue
|
||||
|
||||
r = msg.get("reply_to_message") or {}
|
||||
resume_session = resolve_resume_session(text, r.get("text"))
|
||||
|
||||
await queue.put((msg["chat"]["id"], user_msg_id, text, resume_session))
|
||||
await send_stream.send(
|
||||
(msg["chat"]["id"], user_msg_id, text, resume_session)
|
||||
)
|
||||
finally:
|
||||
await send_stream.aclose()
|
||||
await receive_stream.aclose()
|
||||
await cfg.bot.close()
|
||||
|
||||
|
||||
@@ -813,7 +884,7 @@ def run(
|
||||
except ConfigError as e:
|
||||
typer.echo(str(e), err=True)
|
||||
raise typer.Exit(code=1)
|
||||
asyncio.run(_run_main_loop(cfg))
|
||||
anyio.run(_run_main_loop, cfg)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def anyio_backend() -> str:
|
||||
return "asyncio"
|
||||
|
||||
+152
-111
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from takopi.exec_bridge import (
|
||||
@@ -187,7 +186,7 @@ class _FakeClock:
|
||||
def __init__(self, start: float = 0.0) -> None:
|
||||
self._now = start
|
||||
self._sleep_until: float | None = None
|
||||
self._sleep_event: asyncio.Event | None = None
|
||||
self._sleep_event: anyio.Event | None = None
|
||||
self.sleep_calls = 0
|
||||
|
||||
def __call__(self) -> float:
|
||||
@@ -205,10 +204,10 @@ class _FakeClock:
|
||||
async def sleep(self, delay: float) -> None:
|
||||
self.sleep_calls += 1
|
||||
if delay <= 0:
|
||||
await asyncio.sleep(0)
|
||||
await anyio.sleep(0)
|
||||
return
|
||||
self._sleep_until = self._now + delay
|
||||
self._sleep_event = asyncio.Event()
|
||||
self._sleep_event = anyio.Event()
|
||||
await self._sleep_event.wait()
|
||||
|
||||
|
||||
@@ -222,7 +221,7 @@ class _FakeRunnerWithEvents:
|
||||
answer: str = "ok",
|
||||
session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2",
|
||||
advance_after: float | None = None,
|
||||
hold: asyncio.Event | None = None,
|
||||
hold: anyio.Event | None = None,
|
||||
) -> None:
|
||||
self._events = events
|
||||
self._times = times
|
||||
@@ -238,16 +237,17 @@ class _FakeRunnerWithEvents:
|
||||
for when, event in zip(self._times, self._events, strict=False):
|
||||
self._clock.set(when)
|
||||
await on_event(event)
|
||||
await asyncio.sleep(0)
|
||||
await anyio.sleep(0)
|
||||
if self._advance_after is not None:
|
||||
self._clock.set(self._advance_after)
|
||||
await asyncio.sleep(0)
|
||||
await anyio.sleep(0)
|
||||
if self._hold is not None:
|
||||
await self._hold.wait()
|
||||
return (self._session_id, self._answer, True)
|
||||
|
||||
|
||||
def test_final_notify_sends_loud_final_message() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_final_notify_sends_loud_final_message() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||
|
||||
bot = _FakeBot()
|
||||
@@ -261,14 +261,12 @@ def test_final_notify_sends_loud_final_message() -> None:
|
||||
max_concurrency=1,
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="hi",
|
||||
resume_session=None,
|
||||
)
|
||||
await handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="hi",
|
||||
resume_session=None,
|
||||
)
|
||||
|
||||
assert len(bot.send_calls) == 2
|
||||
@@ -276,7 +274,8 @@ def test_final_notify_sends_loud_final_message() -> None:
|
||||
assert bot.send_calls[1]["disable_notification"] is False
|
||||
|
||||
|
||||
def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||
|
||||
bot = _FakeBot()
|
||||
@@ -290,14 +289,12 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
||||
max_concurrency=1,
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="hi",
|
||||
resume_session=None,
|
||||
)
|
||||
await handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="hi",
|
||||
resume_session=None,
|
||||
)
|
||||
|
||||
assert len(bot.send_calls) == 2
|
||||
@@ -305,7 +302,8 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
||||
assert bot.send_calls[1]["disable_notification"] is False
|
||||
|
||||
|
||||
def test_progress_edits_are_rate_limited() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_progress_edits_are_rate_limited() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||
|
||||
bot = _FakeBot()
|
||||
@@ -345,29 +343,28 @@ def test_progress_edits_are_rate_limited() -> None:
|
||||
max_concurrency=1,
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="hi",
|
||||
resume_session=None,
|
||||
clock=clock,
|
||||
sleep=clock.sleep,
|
||||
progress_edit_every=1.0,
|
||||
)
|
||||
await handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="hi",
|
||||
resume_session=None,
|
||||
clock=clock,
|
||||
sleep=clock.sleep,
|
||||
progress_edit_every=1.0,
|
||||
)
|
||||
|
||||
assert len(bot.edit_calls) == 1
|
||||
assert "echo 2" in bot.edit_calls[0]["text"]
|
||||
|
||||
|
||||
def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||
|
||||
bot = _FakeBot()
|
||||
clock = _FakeClock()
|
||||
hold = asyncio.Event()
|
||||
hold = anyio.Event()
|
||||
events = [
|
||||
{
|
||||
"type": "item.started",
|
||||
@@ -404,24 +401,25 @@ def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
|
||||
max_concurrency=1,
|
||||
)
|
||||
|
||||
async def run_test() -> None:
|
||||
task = asyncio.create_task(
|
||||
handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="hi",
|
||||
resume_session=None,
|
||||
clock=clock,
|
||||
sleep=clock.sleep,
|
||||
progress_edit_every=1.0,
|
||||
)
|
||||
async def run_handle_message() -> None:
|
||||
await handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="hi",
|
||||
resume_session=None,
|
||||
clock=clock,
|
||||
sleep=clock.sleep,
|
||||
progress_edit_every=1.0,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(run_handle_message)
|
||||
|
||||
for _ in range(100):
|
||||
if clock._sleep_until is not None:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
await anyio.sleep(0)
|
||||
|
||||
assert clock._sleep_until == pytest.approx(1.0)
|
||||
|
||||
@@ -430,23 +428,21 @@ def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
|
||||
for _ in range(100):
|
||||
if bot.edit_calls:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
await anyio.sleep(0)
|
||||
|
||||
assert len(bot.edit_calls) == 1
|
||||
|
||||
for _ in range(5):
|
||||
await asyncio.sleep(0)
|
||||
await anyio.sleep(0)
|
||||
|
||||
assert clock.sleep_calls == 1
|
||||
assert clock._sleep_until is None
|
||||
|
||||
hold.set()
|
||||
await task
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
|
||||
def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||
|
||||
bot = _FakeBot()
|
||||
@@ -489,17 +485,15 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
|
||||
max_concurrency=1,
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=42,
|
||||
text="do it",
|
||||
resume_session=None,
|
||||
clock=clock,
|
||||
sleep=clock.sleep,
|
||||
progress_edit_every=1.0,
|
||||
)
|
||||
await handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=42,
|
||||
text="do it",
|
||||
resume_session=None,
|
||||
clock=clock,
|
||||
sleep=clock.sleep,
|
||||
progress_edit_every=1.0,
|
||||
)
|
||||
|
||||
assert bot.send_calls[0]["reply_to_message_id"] == 42
|
||||
@@ -510,7 +504,8 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
|
||||
assert len(bot.delete_calls) == 1
|
||||
|
||||
|
||||
def test_handle_cancel_without_reply_prompts_user() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_handle_cancel_without_reply_prompts_user() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, _handle_cancel
|
||||
|
||||
bot = _FakeBot()
|
||||
@@ -526,13 +521,14 @@ def test_handle_cancel_without_reply_prompts_user() -> None:
|
||||
msg = {"chat": {"id": 123}, "message_id": 10}
|
||||
running_tasks: dict = {}
|
||||
|
||||
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
|
||||
await _handle_cancel(cfg, msg, running_tasks)
|
||||
|
||||
assert len(bot.send_calls) == 1
|
||||
assert "reply to the progress message" in bot.send_calls[0]["text"]
|
||||
|
||||
|
||||
def test_handle_cancel_with_no_session_id_says_nothing_running() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_handle_cancel_with_no_progress_message_says_nothing_running() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, _handle_cancel
|
||||
|
||||
bot = _FakeBot()
|
||||
@@ -548,17 +544,18 @@ def test_handle_cancel_with_no_session_id_says_nothing_running() -> None:
|
||||
msg = {
|
||||
"chat": {"id": 123},
|
||||
"message_id": 10,
|
||||
"reply_to_message": {"text": "no uuid here"},
|
||||
"reply_to_message": {"text": "no message id"},
|
||||
}
|
||||
running_tasks: dict = {}
|
||||
|
||||
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
|
||||
await _handle_cancel(cfg, msg, running_tasks)
|
||||
|
||||
assert len(bot.send_calls) == 1
|
||||
assert "nothing is currently running" in bot.send_calls[0]["text"]
|
||||
|
||||
|
||||
def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, _handle_cancel
|
||||
|
||||
bot = _FakeBot()
|
||||
@@ -571,21 +568,22 @@ def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
|
||||
startup_msg="",
|
||||
max_concurrency=1,
|
||||
)
|
||||
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
|
||||
progress_id = 99
|
||||
msg = {
|
||||
"chat": {"id": 123},
|
||||
"message_id": 10,
|
||||
"reply_to_message": {"text": f"resume: `{session_id}`"},
|
||||
"reply_to_message": {"message_id": progress_id},
|
||||
}
|
||||
running_tasks: dict = {} # Session not in running_tasks
|
||||
running_tasks: dict = {} # Progress message not in running_tasks
|
||||
|
||||
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
|
||||
await _handle_cancel(cfg, msg, running_tasks)
|
||||
|
||||
assert len(bot.send_calls) == 1
|
||||
assert "nothing is currently running" in bot.send_calls[0]["text"]
|
||||
|
||||
|
||||
def test_handle_cancel_cancels_running_task() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_handle_cancel_cancels_running_task() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, _handle_cancel
|
||||
|
||||
bot = _FakeBot()
|
||||
@@ -598,29 +596,70 @@ def test_handle_cancel_cancels_running_task() -> None:
|
||||
startup_msg="",
|
||||
max_concurrency=1,
|
||||
)
|
||||
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
|
||||
progress_id = 42
|
||||
msg = {
|
||||
"chat": {"id": 123},
|
||||
"message_id": 10,
|
||||
"reply_to_message": {"text": f"resume: `{session_id}`"},
|
||||
"reply_to_message": {"message_id": progress_id},
|
||||
}
|
||||
|
||||
async def run_test():
|
||||
task = asyncio.create_task(asyncio.sleep(10))
|
||||
running_tasks = {session_id: task}
|
||||
from takopi.exec_bridge import RunningTask
|
||||
|
||||
cancelled_event = anyio.Event()
|
||||
cancel_scope = anyio.CancelScope()
|
||||
running_task = RunningTask(scope=cancel_scope)
|
||||
|
||||
async def sleeper() -> None:
|
||||
with cancel_scope:
|
||||
try:
|
||||
await anyio.sleep(10)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
cancelled_event.set()
|
||||
return
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(sleeper)
|
||||
running_tasks = {progress_id: running_task}
|
||||
await _handle_cancel(cfg, msg, running_tasks)
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
return True
|
||||
return False
|
||||
await cancelled_event.wait()
|
||||
|
||||
cancelled = asyncio.run(run_test())
|
||||
|
||||
assert cancelled is True
|
||||
assert len(bot.send_calls) == 0 # No error message sent
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_handle_cancel_only_cancels_matching_progress_message() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, _handle_cancel
|
||||
|
||||
bot = _FakeBot()
|
||||
runner = _FakeRunner(answer="ok")
|
||||
cfg = BridgeConfig(
|
||||
bot=bot, # type: ignore[arg-type]
|
||||
runner=runner, # type: ignore[arg-type]
|
||||
chat_id=123,
|
||||
final_notify=True,
|
||||
startup_msg="",
|
||||
max_concurrency=1,
|
||||
)
|
||||
from takopi.exec_bridge import RunningTask
|
||||
|
||||
scope_first = anyio.CancelScope()
|
||||
scope_second = anyio.CancelScope()
|
||||
task_first = RunningTask(scope=scope_first)
|
||||
task_second = RunningTask(scope=scope_second)
|
||||
msg = {
|
||||
"chat": {"id": 123},
|
||||
"message_id": 10,
|
||||
"reply_to_message": {"message_id": 1},
|
||||
}
|
||||
running_tasks = {1: task_first, 2: task_second}
|
||||
|
||||
await _handle_cancel(cfg, msg, running_tasks)
|
||||
|
||||
assert scope_first.cancel_called is True
|
||||
assert scope_second.cancel_called is False
|
||||
assert len(bot.send_calls) == 0
|
||||
|
||||
|
||||
class _FakeRunnerCancellable:
|
||||
def __init__(self, session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2"):
|
||||
self._session_id = session_id
|
||||
@@ -629,11 +668,12 @@ class _FakeRunnerCancellable:
|
||||
on_event = kwargs.get("on_event")
|
||||
if on_event:
|
||||
await on_event({"type": "thread.started", "thread_id": self._session_id})
|
||||
await asyncio.sleep(10) # Will be cancelled
|
||||
await anyio.sleep(10) # Will be cancelled
|
||||
return (self._session_id, "ok", True)
|
||||
|
||||
|
||||
def test_handle_message_cancelled_renders_cancelled_state() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_handle_message_cancelled_renders_cancelled_state() -> None:
|
||||
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||
|
||||
bot = _FakeBot()
|
||||
@@ -649,23 +689,24 @@ def test_handle_message_cancelled_renders_cancelled_state() -> None:
|
||||
)
|
||||
running_tasks: dict = {}
|
||||
|
||||
async def run_test():
|
||||
task = asyncio.create_task(
|
||||
handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="do something",
|
||||
resume_session=None,
|
||||
running_tasks=running_tasks,
|
||||
)
|
||||
async def run_handle_message() -> None:
|
||||
await handle_message(
|
||||
cfg,
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="do something",
|
||||
resume_session=None,
|
||||
running_tasks=running_tasks,
|
||||
)
|
||||
await asyncio.sleep(0.01) # Let task start and register
|
||||
assert session_id in running_tasks
|
||||
running_tasks[session_id].cancel()
|
||||
await task
|
||||
|
||||
asyncio.run(run_test())
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(run_handle_message)
|
||||
for _ in range(100):
|
||||
if running_tasks:
|
||||
break
|
||||
await anyio.sleep(0)
|
||||
assert running_tasks
|
||||
running_tasks[next(iter(running_tasks))].scope.cancel()
|
||||
|
||||
assert len(bot.send_calls) == 1 # Progress message
|
||||
assert len(bot.edit_calls) >= 1
|
||||
|
||||
+68
-11
@@ -1,11 +1,13 @@
|
||||
import asyncio
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from takopi.exec_bridge import CodexExecRunner
|
||||
from takopi.exec_bridge import CodexExecRunner, EventCallback
|
||||
|
||||
|
||||
def test_run_serialized_serializes_same_session() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_run_serialized_serializes_same_session() -> None:
|
||||
runner = CodexExecRunner(codex_cmd="codex", extra_args=[])
|
||||
gate = asyncio.Event()
|
||||
gate = anyio.Event()
|
||||
in_flight = 0
|
||||
max_in_flight = 0
|
||||
|
||||
@@ -19,13 +21,68 @@ def test_run_serialized_serializes_same_session() -> None:
|
||||
|
||||
runner.run = run_stub # type: ignore[assignment]
|
||||
|
||||
async def run_test() -> None:
|
||||
t1 = asyncio.create_task(runner.run_serialized("a", "sid"))
|
||||
t2 = asyncio.create_task(runner.run_serialized("b", "sid"))
|
||||
await asyncio.sleep(0)
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(runner.run_serialized, "a", "sid")
|
||||
tg.start_soon(runner.run_serialized, "b", "sid")
|
||||
await anyio.sleep(0)
|
||||
gate.set()
|
||||
await asyncio.gather(t1, t2)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
assert max_in_flight == 1
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_serialized_allows_parallel_new_sessions() -> None:
|
||||
runner = CodexExecRunner(codex_cmd="codex", extra_args=[])
|
||||
gate = anyio.Event()
|
||||
in_flight = 0
|
||||
max_in_flight = 0
|
||||
|
||||
async def run_stub(*_args, **_kwargs):
|
||||
nonlocal in_flight, max_in_flight
|
||||
in_flight += 1
|
||||
max_in_flight = max(max_in_flight, in_flight)
|
||||
await gate.wait()
|
||||
in_flight -= 1
|
||||
return ("sid", "ok", True)
|
||||
|
||||
runner.run = run_stub # type: ignore[assignment]
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(runner.run_serialized, "a", None)
|
||||
tg.start_soon(runner.run_serialized, "b", None)
|
||||
with anyio.move_on_after(1):
|
||||
while max_in_flight < 2:
|
||||
await anyio.sleep(0)
|
||||
gate.set()
|
||||
|
||||
assert max_in_flight == 2
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_new_session_holds_lock_for_resumes() -> None:
|
||||
runner = CodexExecRunner(codex_cmd="codex", extra_args=[])
|
||||
finish = anyio.Event()
|
||||
resume_started = anyio.Event()
|
||||
|
||||
async def run_stub(
|
||||
_prompt: str,
|
||||
session_id: str | None,
|
||||
on_event: EventCallback | None = None,
|
||||
) -> tuple[str, str, bool]:
|
||||
if session_id is None:
|
||||
if on_event:
|
||||
await on_event({"type": "thread.started", "thread_id": "sid"})
|
||||
await finish.wait()
|
||||
return ("sid", "ok", True)
|
||||
resume_started.set()
|
||||
return ("sid", "ok", True)
|
||||
|
||||
runner.run = run_stub # type: ignore[assignment]
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(runner.run_serialized, "first", None)
|
||||
await anyio.sleep(0)
|
||||
tg.start_soon(runner.run_serialized, "resume", "sid")
|
||||
await anyio.sleep(0)
|
||||
assert not resume_started.is_set()
|
||||
finish.set()
|
||||
|
||||
+13
-23
@@ -1,29 +1,19 @@
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from takopi import exec_bridge
|
||||
|
||||
|
||||
def test_manage_subprocess_kills_when_terminate_times_out(monkeypatch) -> None:
|
||||
async def fake_wait_for(awaitable, *args, **kwargs):
|
||||
if hasattr(awaitable, "close"):
|
||||
awaitable.close()
|
||||
elif hasattr(awaitable, "cancel"):
|
||||
awaitable.cancel()
|
||||
raise asyncio.TimeoutError
|
||||
@pytest.mark.anyio
|
||||
async def test_manage_subprocess_kills_when_terminate_times_out() -> None:
|
||||
async with exec_bridge.manage_subprocess(
|
||||
sys.executable,
|
||||
"-c",
|
||||
"import signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(10)",
|
||||
terminate_timeout=0.01,
|
||||
) as proc:
|
||||
assert proc.returncode is None
|
||||
|
||||
monkeypatch.setattr(exec_bridge.asyncio, "wait_for", fake_wait_for)
|
||||
|
||||
async def run() -> int | None:
|
||||
async with exec_bridge.manage_subprocess(
|
||||
sys.executable,
|
||||
"-c",
|
||||
"import signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(10)",
|
||||
) as proc:
|
||||
assert proc.returncode is None
|
||||
return proc.returncode
|
||||
|
||||
rc = asyncio.run(run())
|
||||
|
||||
assert rc is not None
|
||||
assert rc != 0
|
||||
assert proc.returncode is not None
|
||||
assert proc.returncode != 0
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
@@ -8,7 +7,8 @@ from takopi.logging import RedactTokenFilter
|
||||
from takopi.telegram import TelegramClient
|
||||
|
||||
|
||||
def test_telegram_429_no_retry() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_telegram_429_no_retry() -> None:
|
||||
calls: list[int] = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
@@ -25,21 +25,21 @@ def test_telegram_429_no_retry() -> None:
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
|
||||
async def run() -> dict | None:
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", client=client)
|
||||
return await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
result = asyncio.run(run())
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", client=client)
|
||||
result = await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
assert result is None
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_no_token_in_logs_on_http_error(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
token = "123:abcDEF_ghij"
|
||||
redactor = RedactTokenFilter()
|
||||
root_logger = logging.getLogger()
|
||||
@@ -50,16 +50,13 @@ def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> Non
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
|
||||
async def run() -> None:
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient(token, client=client)
|
||||
await tg._post("getUpdates", {"timeout": 1})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
caplog.set_level(logging.ERROR)
|
||||
asyncio.run(run())
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient(token, client=client)
|
||||
await tg._post("getUpdates", {"timeout": 1})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
root_logger.removeFilter(redactor)
|
||||
|
||||
|
||||
@@ -331,6 +331,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-anyio"
|
||||
version = "0.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "pytest" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/00/44/a02e5877a671b0940f21a7a0d9704c22097b123ed5cdbcca9cab39f17acc/pytest-anyio-0.0.0.tar.gz", hash = "sha256:b41234e9e9ad7ea1dbfefcc1d6891b23d5ef7c9f07ccf804c13a9cc338571fd3", size = 1560, upload-time = "2021-06-29T22:57:30.846Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/25/bd6493ae85d0a281b6a0f248d0fdb1d9aa2b31f18bcd4a8800cf397d8209/pytest_anyio-0.0.0-py2.py3-none-any.whl", hash = "sha256:dc8b5c4741cb16ff90be37fddd585ca943ed12bbeb563de7ace6cd94441d8746", size = 1999, upload-time = "2021-06-29T22:57:29.158Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "7.0.0"
|
||||
@@ -420,6 +433,7 @@ name = "takopi"
|
||||
version = "0.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "httpx" },
|
||||
{ name = "markdown-it-py" },
|
||||
{ name = "rich" },
|
||||
@@ -430,6 +444,7 @@ dependencies = [
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-anyio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "ruff" },
|
||||
{ name = "ty" },
|
||||
@@ -437,6 +452,7 @@ dev = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "anyio", specifier = ">=4.12.0" },
|
||||
{ name = "httpx", specifier = ">=0.28.1" },
|
||||
{ name = "markdown-it-py" },
|
||||
{ name = "rich", specifier = ">=14.2.0" },
|
||||
@@ -447,6 +463,7 @@ requires-dist = [
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "pytest", specifier = ">=9.0.2" },
|
||||
{ name = "pytest-anyio", specifier = ">=0.0.0" },
|
||||
{ name = "pytest-cov", specifier = ">=7.0.0" },
|
||||
{ name = "ruff", specifier = ">=0.14.10" },
|
||||
{ name = "ty", specifier = ">=0.0.8" },
|
||||
|
||||
Reference in New Issue
Block a user