From 2d8fbc8a5a0cf054f90b22279405a0c3cb12ed02 Mon Sep 17 00:00:00 2001 From: banteg <4562643+banteg@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:00:37 +0400 Subject: [PATCH] feat: queue telegram requests with rate limits (#54) --- docs/developing.md | 2 + docs/specification.md | 4 +- docs/transports/telegram.md | 75 +++++ src/takopi/bridge.py | 52 +--- src/takopi/onboarding.py | 19 +- src/takopi/telegram.py | 542 ++++++++++++++++++++++++++++++---- tests/test_exec_bridge.py | 94 +++--- tests/test_telegram_client.py | 11 +- tests/test_telegram_queue.py | 251 ++++++++++++++++ 9 files changed, 898 insertions(+), 152 deletions(-) create mode 100644 docs/transports/telegram.md create mode 100644 tests/test_telegram_queue.py diff --git a/docs/developing.md b/docs/developing.md index 882926d..bd952f6 100644 --- a/docs/developing.md +++ b/docs/developing.md @@ -78,6 +78,8 @@ The orchestrator module containing: | `BotClient` | Protocol defining the bot client interface | | `TelegramClient` | HTTP client for Telegram Bot API (send, edit, delete messages) | +See `docs/transports/telegram.md` for outbox behavior, rate limiting, and retry rules. + ### `runners/codex.py` - Codex runner | Component | Purpose | diff --git a/docs/specification.md b/docs/specification.md index 6b9c2ae..732d337 100644 --- a/docs/specification.md +++ b/docs/specification.md @@ -247,7 +247,7 @@ The bridge MUST: * Resolve resume token (per §3.4) * Schedule runs per thread (per §6.2) * Start runner execution with cancellation support -* Maintain a progress message with rate-limited edits +* Maintain a progress message while avoiding excessive edits * Publish a final message containing status, answer, and resume line (when known) * Support `/cancel` for in-flight runs @@ -280,7 +280,7 @@ Runs that start as new threads: ### 6.3 Progress message behavior * The bridge SHOULD send an initial progress message quickly (e.g., “Running…”). -* The bridge SHOULD edit the progress message no more frequently than every **2 seconds**. +* The bridge SHOULD avoid excessive edits and respect transport constraints (implementation-defined). * The bridge SHOULD skip edits when rendered content is unchanged. * Once `started` is observed, the progress view SHOULD include the canonical ResumeLine. diff --git a/docs/transports/telegram.md b/docs/transports/telegram.md new file mode 100644 index 0000000..cd7a74b --- /dev/null +++ b/docs/transports/telegram.md @@ -0,0 +1,75 @@ +# Telegram Transport + +## Overview + +`TelegramClient` is the single transport for Telegram writes. It owns a +`TelegramOutbox` that serializes send/edit/delete operations, applies +coalescing, and enforces rate limits + retry-after backoff. + +This document captures current behavior so transport changes stay intentional. + +## Flow + +1. CLI emits JSON events. +2. We render progress on every step and diff against the last output. +3. Only deltas enqueue a Telegram edit. +4. High-value messages enqueue a send. +5. All writes go through the outbox. + +## Outbox model + +- Single worker processes one op at a time. +- Each op is keyed; only one pending op per key. +- New ops with the same key overwrite the payload but **do not** reset + `queued_at` (fairness). + +Keys (include `chat_id` to avoid cross-chat collisions): + +- `("edit", chat_id, message_id)` for edits (coalesced). +- `("delete", chat_id, message_id)` for deletes. +- `("send", chat_id, replace_message_id)` when replacing a progress message. +- Unique key for normal sends. + +Scheduling: + +- Ordered by `(priority, queued_at)`. +- Priorities: send=0, delete=1, edit=2. +- Within a priority tier, the oldest pending op runs first. +- `updated_at` is kept for debugging only. + +## Rate limiting + backoff + +- Per-chat pacing is computed from `private_chat_rps` and `group_chat_rps`. + Defaults: 1.0 msg/s for private, 20/60 msg/s for groups (≈1 message every 3s). +- Pacing is currently enforced via a single global `next_at`; per-chat + `next_at` is a future consideration if we ever run multiple chats in parallel. +- The worker waits until `max(next_at, retry_at)` before executing the next op. +- On 429, `RetryAfter` is raised using `parameters.retry_after` when present; + if missing, we fall back to a 5s delay. The outbox sets `retry_at` and + requeues the op if no newer op for the same key has arrived. + +## Error handling + +- Non-429 errors are logged and dropped (no retry). +- On `RetryAfter`, the op is retried unless a newer op superseded the same key. + +## Replace progress messages + +`send_message(replace_message_id=...)`: + +- Drops any pending edit for that progress message. +- Enqueues the send at highest priority. +- If the send succeeds, enqueues a delete for the old progress message. + +This keeps the final message first and avoids deleting progress if the send +fails. + +## getUpdates + +`get_updates` bypasses the outbox and retries on `RetryAfter` by sleeping +for the provided delay. + +## Close semantics + +`TelegramClient.close()` shuts down the outbox and closes the HTTP client. +Pending ops are failed with `None` (best-effort). diff --git a/src/takopi/bridge.py b/src/takopi/bridge.py index 3564bb5..bd7008c 100644 --- a/src/takopi/bridge.py +++ b/src/takopi/bridge.py @@ -154,9 +154,6 @@ def _format_error(error: Exception) -> str: return "\n".join(messages) -PROGRESS_EDIT_EVERY_S = 2.0 - - async def _send_or_edit_markdown( bot: BotClient, *, @@ -164,6 +161,7 @@ async def _send_or_edit_markdown( parts: MarkdownParts, edit_message_id: int | None = None, reply_to_message_id: int | None = None, + replace_message_id: int | None = None, disable_notification: bool = False, prepared: tuple[str, list[dict[str, Any]]] | None = None, ) -> tuple[dict[str, Any] | None, bool]: @@ -200,6 +198,7 @@ async def _send_or_edit_markdown( entities=entities, reply_to_message_id=reply_to_message_id, disable_notification=disable_notification, + replace_message_id=replace_message_id, ), False, ) @@ -214,10 +213,7 @@ class ProgressEdits: progress_id: int | None, renderer: ExecProgressRenderer, started_at: float, - progress_edit_every: float, clock: Callable[[], float], - sleep: Callable[[float], Awaitable[None]], - last_edit_at: float, last_rendered: str | None, ) -> None: self.bot = bot @@ -225,10 +221,7 @@ class ProgressEdits: self.progress_id = progress_id self.renderer = renderer self.started_at = started_at - self.progress_edit_every = progress_edit_every self.clock = clock - self.sleep = sleep - self.last_edit_at = last_edit_at self.last_rendered = last_rendered self.event_seq = 0 self.rendered_seq = 0 @@ -244,13 +237,6 @@ class ProgressEdits: except anyio.EndOfStream: return - await self.sleep( - max( - 0.0, - self.last_edit_at + self.progress_edit_every - self.clock(), - ) - ) - seq_at_render = self.event_seq now = self.clock() parts = self.renderer.render_progress_parts(now - self.started_at) @@ -262,15 +248,14 @@ class ProgressEdits: message_id=self.progress_id, rendered=rendered, ) - self.last_edit_at = now - edited = await self.bot.edit_message_text( + await self.bot.edit_message_text( chat_id=self.chat_id, message_id=self.progress_id, text=rendered, entities=entities, + wait=False, ) - if edited is not None: - self.last_rendered = rendered + self.last_rendered = rendered self.rendered_seq = seq_at_render @@ -295,7 +280,6 @@ class BridgeConfig: chat_id: int final_notify: bool startup_msg: str - progress_edit_every: float = PROGRESS_EDIT_EVERY_S @dataclass @@ -338,7 +322,6 @@ async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None: @dataclass(frozen=True, slots=True) class ProgressMessageState: message_id: int | None - last_edit_at: float last_rendered: str | None @@ -352,7 +335,6 @@ async def send_initial_progress( clock: Callable[[], float], ) -> ProgressMessageState: progress_id: int | None = None - last_edit_at = 0.0 last_rendered: str | None = None initial_parts = renderer.render_progress_parts(0.0, label=label) @@ -372,7 +354,6 @@ async def send_initial_progress( ) if progress_msg is not None: progress_id = int(progress_msg["message_id"]) - last_edit_at = clock() last_rendered = initial_rendered logger.debug( "progress.sent", @@ -382,7 +363,6 @@ async def send_initial_progress( return ProgressMessageState( message_id=progress_id, - last_edit_at=last_edit_at, last_rendered=last_rendered, ) @@ -455,7 +435,6 @@ async def send_result_message( disable_notification: bool, edit_message_id: int | None, prepared: tuple[str, list[dict[str, Any]]] | None = None, - delete_tag: str = "final", ) -> None: final_msg, edited = await _send_or_edit_markdown( cfg.bot, @@ -463,19 +442,12 @@ async def send_result_message( parts=parts, edit_message_id=edit_message_id, reply_to_message_id=user_msg_id, + replace_message_id=progress_id, disable_notification=disable_notification, prepared=prepared, ) if final_msg is None: return - if progress_id is not None and (edit_message_id is None or not edited): - logger.debug( - "telegram.delete_message", - chat_id=chat_id, - message_id=progress_id, - tag=delete_tag, - ) - await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id) async def handle_message( @@ -491,8 +463,6 @@ async def handle_message( on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None = None, clock: Callable[[], float] = time.monotonic, - sleep: Callable[[float], Awaitable[None]] = anyio.sleep, - progress_edit_every: float = PROGRESS_EDIT_EVERY_S, ) -> None: logger.info( "handle.incoming", @@ -526,10 +496,7 @@ async def handle_message( progress_id=progress_id, renderer=progress_renderer, started_at=started_at, - progress_edit_every=progress_edit_every, clock=clock, - sleep=sleep, - last_edit_at=progress_state.last_edit_at, last_rendered=progress_state.last_rendered, ) @@ -606,7 +573,6 @@ async def handle_message( parts=final_parts, disable_notification=True, edit_message_id=progress_id, - delete_tag="error", ) return @@ -628,7 +594,6 @@ async def handle_message( parts=final_parts, disable_notification=True, edit_message_id=progress_id, - delete_tag="cancel", ) return @@ -685,7 +650,6 @@ async def handle_message( disable_notification=False, edit_message_id=edit_message_id, prepared=(final_rendered, final_entities), - delete_tag="final", ) @@ -888,7 +852,6 @@ async def run_main_loop( strip_resume_line=cfg.router.is_resume_line, running_tasks=running_tasks, on_thread_known=on_thread_known, - progress_edit_every=cfg.progress_edit_every, ) except Exception as exc: logger.exception( @@ -926,6 +889,9 @@ async def run_main_loop( reply_id = r.get("message_id") if resume_token is None and reply_id is not None: running_task = running_tasks.get(int(reply_id)) + if running_task is None: + await anyio.sleep(0) + running_task = running_tasks.get(int(reply_id)) if running_task is not None: tg.start_soon( _send_with_resume, diff --git a/src/takopi/onboarding.py b/src/takopi/onboarding.py index 4587745..7d7790c 100644 --- a/src/takopi/onboarding.py +++ b/src/takopi/onboarding.py @@ -25,7 +25,7 @@ from .backends_helpers import install_issue from .config import ConfigError, HOME_CONFIG_PATH, load_telegram_config from .engines import list_backends from .logging import suppress_logs -from .telegram import TelegramClient +from .telegram import TelegramClient, TelegramRetryAfter @dataclass(slots=True) @@ -132,7 +132,12 @@ def _render_config(token: str, chat_id: int, default_engine: str | None) -> str: async def _get_bot_info(token: str) -> dict[str, Any] | None: bot = TelegramClient(token) try: - return await bot.get_me() + for _ in range(3): + try: + return await bot.get_me() + except TelegramRetryAfter as exc: + await anyio.sleep(exc.retry_after) + return None finally: await bot.close() @@ -148,9 +153,13 @@ async def _wait_for_chat(token: str) -> ChatInfo: if drained: offset = drained[-1]["update_id"] + 1 while True: - updates = await bot.get_updates( - offset=offset, timeout_s=50, allowed_updates=allowed_updates - ) + try: + updates = await bot.get_updates( + offset=offset, timeout_s=50, allowed_updates=allowed_updates + ) + except TelegramRetryAfter as exc: + await anyio.sleep(exc.retry_after) + continue if updates is None: await anyio.sleep(1) continue diff --git a/src/takopi/telegram.py b/src/takopi/telegram.py index 0e46e46..a25394f 100644 --- a/src/takopi/telegram.py +++ b/src/takopi/telegram.py @@ -1,14 +1,39 @@ from __future__ import annotations -from typing import Any, Protocol +import itertools +import time +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable, Hashable, Protocol, TYPE_CHECKING import httpx +import anyio + from .logging import get_logger logger = get_logger(__name__) +SEND_PRIORITY = 0 +DELETE_PRIORITY = 1 +EDIT_PRIORITY = 2 + + +class RetryAfter(Exception): + def __init__(self, retry_after: float, description: str | None = None) -> None: + super().__init__(description or f"retry after {retry_after}") + self.retry_after = float(retry_after) + self.description = description + + +class TelegramRetryAfter(RetryAfter): + pass + + +def is_group_chat_id(chat_id: int) -> bool: + return chat_id < 0 + + class BotClient(Protocol): async def close(self) -> None: ... @@ -27,6 +52,8 @@ class BotClient(Protocol): disable_notification: bool | None = False, entities: list[dict] | None = None, parse_mode: str | None = None, + *, + replace_message_id: int | None = None, ) -> dict | None: ... async def edit_message_text( @@ -36,9 +63,15 @@ class BotClient(Protocol): text: str, entities: list[dict] | None = None, parse_mode: str | None = None, + *, + wait: bool = True, ) -> dict | None: ... - async def delete_message(self, chat_id: int, message_id: int) -> bool: ... + async def delete_message( + self, + chat_id: int, + message_id: int, + ) -> bool: ... async def set_my_commands( self, @@ -51,27 +84,287 @@ class BotClient(Protocol): async def get_me(self) -> dict | None: ... +if TYPE_CHECKING: + from anyio.abc import TaskGroup +else: + TaskGroup = object + + +@dataclass(slots=True) +class OutboxOp: + execute: Callable[[], Awaitable[Any]] + priority: int + queued_at: float + updated_at: float + chat_id: int | None + label: str | None = None + done: anyio.Event = field(default_factory=anyio.Event) + result: Any = None + + def set_result(self, result: Any) -> None: + if self.done.is_set(): + return + self.result = result + self.done.set() + + +class TelegramOutbox: + def __init__( + self, + *, + interval_for_chat: Callable[[int | None], float], + clock: Callable[[], float] = time.monotonic, + sleep: Callable[[float], Awaitable[None]] = anyio.sleep, + on_error: Callable[[OutboxOp, Exception], None] | None = None, + on_outbox_error: Callable[[Exception], None] | None = None, + ) -> None: + self._interval_for_chat = interval_for_chat + self._clock = clock + self._sleep = sleep + self._on_error = on_error + self._on_outbox_error = on_outbox_error + self._pending: dict[Hashable, OutboxOp] = {} + self._cond = anyio.Condition() + self._start_lock = anyio.Lock() + self._closed = False + self._tg: TaskGroup | None = None + self.next_at = 0.0 + self.retry_at = 0.0 + + async def ensure_worker(self) -> None: + async with self._start_lock: + if self._tg is not None or self._closed: + return + self._tg = await anyio.create_task_group().__aenter__() + self._tg.start_soon(self.run) + + async def enqueue(self, *, key: Hashable, op: OutboxOp, wait: bool = True) -> Any: + await self.ensure_worker() + async with self._cond: + if self._closed: + op.set_result(None) + return op.result + previous = self._pending.get(key) + if previous is not None: + op.queued_at = previous.queued_at + previous.set_result(None) + else: + op.queued_at = op.updated_at + self._pending[key] = op + self._cond.notify() + if not wait: + return None + await op.done.wait() + return op.result + + async def drop_pending(self, *, key: Hashable) -> None: + async with self._cond: + pending = self._pending.pop(key, None) + if pending is not None: + pending.set_result(None) + self._cond.notify() + + async def close(self) -> None: + async with self._cond: + self._closed = True + self.fail_pending() + self._cond.notify_all() + if self._tg is not None: + await self._tg.__aexit__(None, None, None) + self._tg = None + + def fail_pending(self) -> None: + for pending in list(self._pending.values()): + pending.set_result(None) + self._pending.clear() + + def pick_locked(self) -> tuple[Hashable, OutboxOp] | None: + if not self._pending: + return None + return min( + self._pending.items(), + key=lambda item: (item[1].priority, item[1].queued_at), + ) + + async def execute_op(self, op: OutboxOp) -> Any: + try: + return await op.execute() + except Exception as exc: + if isinstance(exc, RetryAfter): + raise + if self._on_error is not None: + self._on_error(op, exc) + return None + + async def sleep_until(self, deadline: float) -> None: + delay = deadline - self._clock() + if delay > 0: + await self._sleep(delay) + + async def run(self) -> None: + cancel_exc = anyio.get_cancelled_exc_class() + try: + while True: + async with self._cond: + while not self._pending and not self._closed: + await self._cond.wait() + if self._closed and not self._pending: + return + blocked_until = max(self.next_at, self.retry_at) + if self._clock() < blocked_until: + await self.sleep_until(blocked_until) + continue + async with self._cond: + if self._closed and not self._pending: + return + picked = self.pick_locked() + if picked is None: + continue + key, op = picked + self._pending.pop(key, None) + started_at = self._clock() + try: + result = await self.execute_op(op) + except RetryAfter as exc: + self.retry_at = max(self.retry_at, self._clock() + exc.retry_after) + async with self._cond: + if self._closed: + op.set_result(None) + elif key not in self._pending: + self._pending[key] = op + self._cond.notify() + else: + op.set_result(None) + continue + self.next_at = started_at + self._interval_for_chat(op.chat_id) + op.set_result(result) + except cancel_exc: + return + except Exception as exc: + async with self._cond: + self._closed = True + self.fail_pending() + self._cond.notify_all() + if self._on_outbox_error is not None: + self._on_outbox_error(exc) + return + + +def retry_after_from_payload(payload: dict[str, Any]) -> float | None: + params = payload.get("parameters") + if isinstance(params, dict): + retry_after = params.get("retry_after") + if isinstance(retry_after, (int, float)): + return float(retry_after) + return None + + class TelegramClient: def __init__( self, - token: str, + token: str | None = None, + *, + client: BotClient | None = None, timeout_s: float = 120, - client: httpx.AsyncClient | None = None, + http_client: httpx.AsyncClient | None = None, + clock: Callable[[], float] = time.monotonic, + sleep: Callable[[float], Awaitable[None]] = anyio.sleep, + private_chat_rps: float = 1.0, + group_chat_rps: float = 20.0 / 60.0, ) -> None: - if not token: - raise ValueError("Telegram token is empty") - self._base = f"https://api.telegram.org/bot{token}" - self._client = client or httpx.AsyncClient(timeout=timeout_s) - self._owns_client = client is None + if client is not None: + if token is not None or http_client is not None: + raise ValueError("Provide either token or client, not both.") + self._client_override = client + self._base = None + self._http_client = None + self._owns_http_client = False + else: + if token is None or not token: + raise ValueError("Telegram token is empty") + self._client_override = None + self._base = f"https://api.telegram.org/bot{token}" + self._http_client = http_client or httpx.AsyncClient(timeout=timeout_s) + self._owns_http_client = http_client is None + self._clock = clock + self._sleep = sleep + self._private_interval = ( + 0.0 if private_chat_rps <= 0 else 1.0 / private_chat_rps + ) + self._group_interval = 0.0 if group_chat_rps <= 0 else 1.0 / group_chat_rps + self._outbox = TelegramOutbox( + interval_for_chat=self.interval_for_chat, + clock=clock, + sleep=sleep, + on_error=self.log_request_error, + on_outbox_error=self.log_outbox_failure, + ) + self._seq = itertools.count() + + def interval_for_chat(self, chat_id: int | None) -> float: + if chat_id is None: + return self._private_interval + if is_group_chat_id(chat_id): + return self._group_interval + return self._private_interval + + def log_request_error(self, request: OutboxOp, exc: Exception) -> None: + logger.error( + "telegram.outbox.request_failed", + method=request.label, + error=str(exc), + error_type=exc.__class__.__name__, + ) + + def log_outbox_failure(self, exc: Exception) -> None: + logger.error( + "telegram.outbox.failed", + error=str(exc), + error_type=exc.__class__.__name__, + ) + + async def drop_pending_edits(self, *, chat_id: int, message_id: int) -> None: + await self._outbox.drop_pending(key=("edit", chat_id, message_id)) + + def unique_key(self, prefix: str) -> tuple[str, int]: + return (prefix, next(self._seq)) + + async def enqueue_op( + self, + *, + key: Hashable, + label: str, + execute: Callable[[], Awaitable[Any]], + priority: int, + chat_id: int | None, + wait: bool = True, + ) -> Any: + request = OutboxOp( + execute=execute, + priority=priority, + queued_at=0.0, + updated_at=self._clock(), + chat_id=chat_id, + label=label, + ) + return await self._outbox.enqueue(key=key, op=request, wait=wait) async def close(self) -> None: - if self._owns_client: - await self._client.aclose() + await self._outbox.close() + if self._client_override is not None: + await self._client_override.close() + return + if self._owns_http_client and self._http_client is not None: + await self._http_client.aclose() async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None: + if self._http_client is None or self._base is None: + raise RuntimeError("TelegramClient is configured without an HTTP client.") logger.debug("telegram.request", method=method, payload=json_data) try: - resp = await self._client.post(f"{self._base}/{method}", json=json_data) + resp = await self._http_client.post( + f"{self._base}/{method}", json=json_data + ) except httpx.HTTPError as e: url = getattr(e.request, "url", None) logger.error( @@ -86,6 +379,23 @@ class TelegramClient: try: resp.raise_for_status() except httpx.HTTPStatusError as e: + if resp.status_code == 429: + retry_after: float | None = None + try: + payload = resp.json() + except Exception: + payload = None + if isinstance(payload, dict): + retry_after = retry_after_from_payload(payload) + retry_after = 5.0 if retry_after is None else retry_after + logger.warning( + "telegram.rate_limited", + method=method, + status=resp.status_code, + url=str(resp.request.url), + retry_after=retry_after, + ) + raise TelegramRetryAfter(retry_after) from e body = resp.text logger.error( "telegram.http_error", @@ -122,6 +432,16 @@ class TelegramClient: return None if not payload.get("ok"): + if payload.get("error_code") == 429: + retry_after = retry_after_from_payload(payload) + retry_after = 5.0 if retry_after is None else retry_after + logger.warning( + "telegram.rate_limited", + method=method, + url=str(resp.request.url), + retry_after=retry_after, + ) + raise TelegramRetryAfter(retry_after) logger.error( "telegram.api_error", method=method, @@ -139,12 +459,23 @@ class TelegramClient: timeout_s: int = 50, allowed_updates: list[str] | None = None, ) -> list[dict] | None: - params: dict[str, Any] = {"timeout": timeout_s} - if offset is not None: - params["offset"] = offset - if allowed_updates is not None: - params["allowed_updates"] = allowed_updates - return await self._post("getUpdates", params) # type: ignore[return-value] + while True: + try: + if self._client_override is not None: + return await self._client_override.get_updates( + offset=offset, + timeout_s=timeout_s, + allowed_updates=allowed_updates, + ) + params: dict[str, Any] = {"timeout": timeout_s} + if offset is not None: + params["offset"] = offset + if allowed_updates is not None: + params["allowed_updates"] = allowed_updates + result = await self._post("getUpdates", params) + return result if isinstance(result, list) else None + except TelegramRetryAfter as exc: + await self._sleep(exc.retry_after) async def send_message( self, @@ -154,20 +485,48 @@ class TelegramClient: disable_notification: bool | None = False, entities: list[dict] | None = None, parse_mode: str | None = None, + *, + replace_message_id: int | None = None, ) -> dict | None: - params: dict[str, Any] = { - "chat_id": chat_id, - "text": text, - } - if disable_notification is not None: - params["disable_notification"] = disable_notification - if reply_to_message_id is not None: - params["reply_to_message_id"] = reply_to_message_id - if entities is not None: - params["entities"] = entities - if parse_mode is not None: - params["parse_mode"] = parse_mode - return await self._post("sendMessage", params) # type: ignore[return-value] + async def execute() -> dict | None: + if self._client_override is not None: + return await self._client_override.send_message( + chat_id=chat_id, + text=text, + reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, + entities=entities, + parse_mode=parse_mode, + replace_message_id=replace_message_id, + ) + params: dict[str, Any] = {"chat_id": chat_id, "text": text} + if disable_notification is not None: + params["disable_notification"] = disable_notification + if reply_to_message_id is not None: + params["reply_to_message_id"] = reply_to_message_id + if entities is not None: + params["entities"] = entities + if parse_mode is not None: + params["parse_mode"] = parse_mode + result = await self._post("sendMessage", params) + return result if isinstance(result, dict) else None + + if replace_message_id is not None: + await self._outbox.drop_pending(key=("edit", chat_id, replace_message_id)) + result = await self.enqueue_op( + key=( + ("send", chat_id, replace_message_id) + if replace_message_id is not None + else self.unique_key("send") + ), + label="send_message", + execute=execute, + priority=SEND_PRIORITY, + chat_id=chat_id, + ) + if replace_message_id is not None and result is not None: + await self.delete_message(chat_id=chat_id, message_id=replace_message_id) + return result async def edit_message_text( self, @@ -176,27 +535,68 @@ class TelegramClient: text: str, entities: list[dict] | None = None, parse_mode: str | None = None, + *, + wait: bool = True, ) -> dict | None: - params: dict[str, Any] = { - "chat_id": chat_id, - "message_id": message_id, - "text": text, - } - if entities is not None: - params["entities"] = entities - if parse_mode is not None: - params["parse_mode"] = parse_mode - return await self._post("editMessageText", params) # type: ignore[return-value] - - async def delete_message(self, chat_id: int, message_id: int) -> bool: - res = await self._post( - "deleteMessage", - { + async def execute() -> dict | None: + if self._client_override is not None: + return await self._client_override.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text=text, + entities=entities, + parse_mode=parse_mode, + wait=wait, + ) + params: dict[str, Any] = { "chat_id": chat_id, "message_id": message_id, - }, + "text": text, + } + if entities is not None: + params["entities"] = entities + if parse_mode is not None: + params["parse_mode"] = parse_mode + result = await self._post("editMessageText", params) + return result if isinstance(result, dict) else None + + return await self.enqueue_op( + key=("edit", chat_id, message_id), + label="edit_message_text", + execute=execute, + priority=EDIT_PRIORITY, + chat_id=chat_id, + wait=wait, + ) + + async def delete_message( + self, + chat_id: int, + message_id: int, + ) -> bool: + await self.drop_pending_edits(chat_id=chat_id, message_id=message_id) + + async def execute() -> bool: + if self._client_override is not None: + return await self._client_override.delete_message( + chat_id=chat_id, + message_id=message_id, + ) + result = await self._post( + "deleteMessage", + {"chat_id": chat_id, "message_id": message_id}, + ) + return bool(result) + + return bool( + await self.enqueue_op( + key=("delete", chat_id, message_id), + label="delete_message", + execute=execute, + priority=DELETE_PRIORITY, + chat_id=chat_id, + ) ) - return bool(res) async def set_my_commands( self, @@ -205,14 +605,42 @@ class TelegramClient: scope: dict[str, Any] | None = None, language_code: str | None = None, ) -> bool: - params: dict[str, Any] = {"commands": commands} - if scope is not None: - params["scope"] = scope - if language_code is not None: - params["language_code"] = language_code - res = await self._post("setMyCommands", params) - return bool(res) + async def execute() -> bool: + if self._client_override is not None: + return await self._client_override.set_my_commands( + commands, + scope=scope, + language_code=language_code, + ) + params: dict[str, Any] = {"commands": commands} + if scope is not None: + params["scope"] = scope + if language_code is not None: + params["language_code"] = language_code + result = await self._post("setMyCommands", params) + return bool(result) + + return bool( + await self.enqueue_op( + key=self.unique_key("set_my_commands"), + label="set_my_commands", + execute=execute, + priority=SEND_PRIORITY, + chat_id=None, + ) + ) async def get_me(self) -> dict | None: - res = await self._post("getMe", {}) - return res if isinstance(res, dict) else None + async def execute() -> dict | None: + if self._client_override is not None: + return await self._client_override.get_me() + result = await self._post("getMe", {}) + return result if isinstance(result, dict) else None + + return await self.enqueue_op( + key=self.unique_key("get_me"), + label="get_me", + execute=execute, + priority=SEND_PRIORITY, + chat_id=None, + ) diff --git a/tests/test_exec_bridge.py b/tests/test_exec_bridge.py index 4740f9a..99f16f5 100644 --- a/tests/test_exec_bridge.py +++ b/tests/test_exec_bridge.py @@ -8,6 +8,7 @@ from takopi.model import EngineId, ResumeToken, TakopiEvent from takopi.render import MarkdownParts, prepare_telegram from takopi.router import AutoRouter, RunnerEntry from takopi.runners.codex import CodexRunner +from takopi.telegram import TelegramClient from takopi.runners.mock import Advance, Emit, Raise, Return, ScriptRunner, Sleep, Wait from tests.factories import action_completed, action_started @@ -189,7 +190,10 @@ class _FakeBot: disable_notification: bool | None = False, entities: list[dict] | None = None, parse_mode: str | None = None, + *, + replace_message_id: int | None = None, ) -> dict: + _ = replace_message_id self.send_calls.append( { "chat_id": chat_id, @@ -211,7 +215,10 @@ class _FakeBot: text: str, entities: list[dict] | None = None, parse_mode: str | None = None, + *, + wait: bool = True, ) -> dict: + _ = wait self.edit_calls.append( { "chat_id": chat_id, @@ -223,7 +230,11 @@ class _FakeBot: ) return {"message_id": message_id} - async def delete_message(self, chat_id: int, message_id: int) -> bool: + async def delete_message( + self, + chat_id: int, + message_id: int, + ) -> bool: self.delete_calls.append({"chat_id": chat_id, "message_id": message_id}) return True @@ -281,15 +292,33 @@ class _FakeClock: self._sleep_event = None async def sleep(self, delay: float) -> None: - self.sleep_calls += 1 if delay <= 0: await anyio.sleep(0) return + self.sleep_calls += 1 self._sleep_until = self._now + delay self._sleep_event = anyio.Event() await self._sleep_event.wait() +def _queued_bot( + bot: "_FakeBot", *, clock: "_FakeClock | None" = None +) -> TelegramClient: + if clock is None: + return TelegramClient( + client=bot, + private_chat_rps=0.0, + group_chat_rps=0.0, + ) + return TelegramClient( + client=bot, + clock=clock, + sleep=clock.sleep, + private_chat_rps=0.0, + group_chat_rps=0.0, + ) + + def _return_runner( *, answer: str = "ok", resume_value: str | None = None ) -> ScriptRunner: @@ -307,7 +336,7 @@ async def test_final_notify_sends_loud_final_message() -> None: bot = _FakeBot() runner = _return_runner(answer="ok") cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=True, @@ -335,7 +364,7 @@ async def test_handle_message_strips_resume_line_from_prompt() -> None: bot = _FakeBot() runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=True, @@ -366,7 +395,7 @@ async def test_long_final_message_edits_progress_message() -> None: bot = _FakeBot() runner = _return_runner(answer="x" * 10_000) cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=False, @@ -384,7 +413,8 @@ async def test_long_final_message_edits_progress_message() -> None: assert len(bot.send_calls) == 1 assert bot.send_calls[0]["disable_notification"] is True - assert len(bot.edit_calls) == 1 + assert bot.edit_calls + assert "done" in bot.edit_calls[-1]["text"].lower() @pytest.mark.anyio @@ -408,7 +438,7 @@ async def test_progress_edits_are_rate_limited() -> None: advance=clock.set, ) cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot, clock=clock), router=_make_router(runner), chat_id=123, final_notify=True, @@ -423,12 +453,10 @@ async def test_progress_edits_are_rate_limited() -> None: text="hi", resume_token=None, clock=clock, - sleep=clock.sleep, - progress_edit_every=1.0, ) - assert len(bot.edit_calls) == 1 - assert "echo 2" in bot.edit_calls[0]["text"] + assert bot.edit_calls + assert "working" in bot.edit_calls[-1]["text"].lower() @pytest.mark.anyio @@ -453,7 +481,7 @@ async def test_progress_edits_do_not_sleep_again_without_new_events() -> None: advance=clock.set, ) cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot, clock=clock), router=_make_router(runner), chat_id=123, final_notify=True, @@ -469,33 +497,18 @@ async def test_progress_edits_do_not_sleep_again_without_new_events() -> None: text="hi", resume_token=None, clock=clock, - sleep=clock.sleep, - progress_edit_every=1.0, ) async with anyio.create_task_group() as tg: tg.start_soon(run_handle_message) - for _ in range(100): - if clock._sleep_until is not None: - break - await anyio.sleep(0) - - assert clock._sleep_until == pytest.approx(1.0) - - clock.set(1.0) - for _ in range(100): if bot.edit_calls: break await anyio.sleep(0) - assert len(bot.edit_calls) == 1 - - for _ in range(5): - await anyio.sleep(0) - - assert clock.sleep_calls == 1 + assert bot.edit_calls + assert clock.sleep_calls == 0 assert clock._sleep_until is None hold.set() @@ -529,7 +542,7 @@ async def test_bridge_flow_sends_progress_edits_and_final_resume() -> None: resume_value=session_id, ) cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot, clock=clock), router=_make_router(runner), chat_id=123, final_notify=True, @@ -544,8 +557,6 @@ async def test_bridge_flow_sends_progress_edits_and_final_resume() -> None: text="do it", resume_token=None, clock=clock, - sleep=clock.sleep, - progress_edit_every=1.0, ) assert bot.send_calls[0]["reply_to_message_id"] == 42 @@ -564,7 +575,7 @@ async def test_handle_cancel_without_reply_prompts_user() -> None: bot = _FakeBot() runner = _return_runner(answer="ok") cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=True, @@ -586,7 +597,7 @@ async def test_handle_cancel_with_no_progress_message_says_nothing_running() -> bot = _FakeBot() runner = _return_runner(answer="ok") cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=True, @@ -612,7 +623,7 @@ async def test_handle_cancel_with_finished_task_says_nothing_running() -> None: bot = _FakeBot() runner = _return_runner(answer="ok") cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=True, @@ -639,7 +650,7 @@ async def test_handle_cancel_cancels_running_task() -> None: bot = _FakeBot() runner = _return_runner(answer="ok") cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=True, @@ -669,7 +680,7 @@ async def test_handle_cancel_only_cancels_matching_progress_message() -> None: bot = _FakeBot() runner = _return_runner(answer="ok") cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=True, @@ -714,7 +725,7 @@ async def test_handle_message_cancelled_renders_cancelled_state() -> None: resume_value=session_id, ) cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=True, @@ -764,7 +775,7 @@ async def test_handle_message_error_preserves_resume_token() -> None: resume_value=session_id, ) cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=True, @@ -873,6 +884,8 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None: disable_notification: bool | None = False, entities: list[dict] | None = None, parse_mode: str | None = None, + *, + replace_message_id: int | None = None, ) -> dict: msg = await super().send_message( chat_id=chat_id, @@ -881,6 +894,7 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None: disable_notification=disable_notification, entities=entities, parse_mode=parse_mode, + replace_message_id=replace_message_id, ) if self.progress_id is None and reply_to_message_id is not None: self.progress_id = int(msg["message_id"]) @@ -895,7 +909,7 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None: resume_value=resume_value, ) cfg = BridgeConfig( - bot=bot, + bot=_queued_bot(bot), router=_make_router(runner), chat_id=123, final_notify=True, diff --git a/tests/test_telegram_client.py b/tests/test_telegram_client.py index 1686172..f7e2e88 100644 --- a/tests/test_telegram_client.py +++ b/tests/test_telegram_client.py @@ -2,7 +2,7 @@ import httpx import pytest from takopi.logging import setup_logging -from takopi.telegram import TelegramClient +from takopi.telegram import TelegramClient, TelegramRetryAfter @pytest.mark.anyio @@ -25,12 +25,13 @@ async def test_telegram_429_no_retry() -> None: client = httpx.AsyncClient(transport=transport) try: - tg = TelegramClient("123:abcDEF_ghij", client=client) - result = await tg._post("sendMessage", {"chat_id": 1, "text": "hi"}) + tg = TelegramClient("123:abcDEF_ghij", http_client=client) + with pytest.raises(TelegramRetryAfter) as exc: + await tg._post("sendMessage", {"chat_id": 1, "text": "hi"}) finally: await client.aclose() - assert result is None + assert exc.value.retry_after == 3 assert len(calls) == 1 @@ -48,7 +49,7 @@ async def test_no_token_in_logs_on_http_error( client = httpx.AsyncClient(transport=transport) try: - tg = TelegramClient(token, client=client) + tg = TelegramClient(token, http_client=client) await tg._post("getUpdates", {"timeout": 1}) finally: await client.aclose() diff --git a/tests/test_telegram_queue.py b/tests/test_telegram_queue.py new file mode 100644 index 0000000..928419d --- /dev/null +++ b/tests/test_telegram_queue.py @@ -0,0 +1,251 @@ +import anyio +import pytest + +from takopi.telegram import TelegramClient, TelegramRetryAfter + + +class _FakeBot: + def __init__(self) -> None: + self.calls: list[str] = [] + self.edit_calls: list[str] = [] + self.delete_calls: list[tuple[int, int]] = [] + self._edit_attempts = 0 + self._updates_attempts = 0 + self.retry_after: float | None = None + self.updates_retry_after: float | None = None + + async def send_message( + self, + chat_id: int, + text: str, + reply_to_message_id: int | None = None, + disable_notification: bool | None = False, + entities: list[dict] | None = None, + parse_mode: str | None = None, + *, + replace_message_id: int | None = None, + ) -> dict: + _ = reply_to_message_id + _ = disable_notification + _ = entities + _ = parse_mode + _ = replace_message_id + self.calls.append("send_message") + return {"message_id": 1} + + async def edit_message_text( + self, + chat_id: int, + message_id: int, + text: str, + entities: list[dict] | None = None, + parse_mode: str | None = None, + *, + wait: bool = True, + ) -> dict: + _ = chat_id + _ = message_id + _ = entities + _ = parse_mode + _ = wait + self.calls.append("edit_message_text") + self.edit_calls.append(text) + if self.retry_after is not None and self._edit_attempts == 0: + self._edit_attempts += 1 + raise TelegramRetryAfter(self.retry_after) + self._edit_attempts += 1 + return {"message_id": message_id} + + async def delete_message( + self, + chat_id: int, + message_id: int, + ) -> bool: + self.calls.append("delete_message") + self.delete_calls.append((chat_id, message_id)) + return True + + async def set_my_commands( + self, + commands: list[dict], + *, + scope: dict | None = None, + language_code: str | None = None, + ) -> bool: + _ = commands + _ = scope + _ = language_code + return True + + async def get_updates( + self, + offset: int | None, + timeout_s: int = 50, + allowed_updates: list[str] | None = None, + ) -> list[dict] | None: + _ = offset + _ = timeout_s + _ = allowed_updates + if self.updates_retry_after is not None and self._updates_attempts == 0: + self._updates_attempts += 1 + raise TelegramRetryAfter(self.updates_retry_after) + self._updates_attempts += 1 + return [] + + async def close(self) -> None: + return None + + async def get_me(self) -> dict | None: + return {"id": 1} + + +@pytest.mark.anyio +async def test_edits_coalesce_latest() -> None: + class _BlockingBot(_FakeBot): + def __init__(self) -> None: + super().__init__() + self.edit_started = anyio.Event() + self.release = anyio.Event() + self._block_first = True + + async def edit_message_text( + self, + chat_id: int, + message_id: int, + text: str, + entities: list[dict] | None = None, + parse_mode: str | None = None, + *, + wait: bool = True, + ) -> dict: + if self._block_first: + self._block_first = False + self.edit_started.set() + await self.release.wait() + return await super().edit_message_text( + chat_id=chat_id, + message_id=message_id, + text=text, + entities=entities, + parse_mode=parse_mode, + wait=wait, + ) + + bot = _BlockingBot() + client = TelegramClient(client=bot, private_chat_rps=0.0, group_chat_rps=0.0) + + await client.edit_message_text( + chat_id=1, + message_id=1, + text="first", + wait=False, + ) + + with anyio.fail_after(1): + await bot.edit_started.wait() + + await client.edit_message_text( + chat_id=1, + message_id=1, + text="second", + wait=False, + ) + await client.edit_message_text( + chat_id=1, + message_id=1, + text="third", + wait=False, + ) + + bot.release.set() + + with anyio.fail_after(1): + while len(bot.edit_calls) < 2: + await anyio.sleep(0) + + assert bot.edit_calls == ["first", "third"] + + +@pytest.mark.anyio +async def test_send_preempts_pending_edit() -> None: + bot = _FakeBot() + client = TelegramClient(client=bot, private_chat_rps=10.0, group_chat_rps=10.0) + + await client.edit_message_text( + chat_id=1, + message_id=1, + text="first", + ) + + await client.edit_message_text( + chat_id=1, + message_id=1, + text="progress", + wait=False, + ) + + with anyio.fail_after(1): + await client.send_message(chat_id=1, text="final") + + await anyio.sleep(0.2) + assert bot.calls[0] == "edit_message_text" + assert bot.calls[1] == "send_message" + assert bot.calls[-1] == "edit_message_text" + + +@pytest.mark.anyio +async def test_delete_drops_pending_edits() -> None: + bot = _FakeBot() + client = TelegramClient(client=bot, private_chat_rps=10.0, group_chat_rps=10.0) + + await client.edit_message_text( + chat_id=1, + message_id=1, + text="first", + ) + + await client.edit_message_text( + chat_id=1, + message_id=1, + text="progress", + wait=False, + ) + + with anyio.fail_after(1): + await client.delete_message( + chat_id=1, + message_id=1, + ) + + await anyio.sleep(0.2) + assert bot.delete_calls == [(1, 1)] + assert bot.edit_calls == ["first"] + + +@pytest.mark.anyio +async def test_retry_after_retries_once() -> None: + bot = _FakeBot() + bot.retry_after = 0.0 + client = TelegramClient(client=bot, private_chat_rps=0.0, group_chat_rps=0.0) + + result = await client.edit_message_text( + chat_id=1, + message_id=1, + text="retry", + ) + + assert result == {"message_id": 1} + assert bot._edit_attempts == 2 + + +@pytest.mark.anyio +async def test_get_updates_retries_on_retry_after() -> None: + bot = _FakeBot() + bot.updates_retry_after = 0.0 + client = TelegramClient(client=bot, private_chat_rps=0.0, group_chat_rps=0.0) + + with anyio.fail_after(1): + updates = await client.get_updates(offset=None, timeout_s=0) + + assert updates == [] + assert bot._updates_attempts == 2