feat: better progress edits, simpler telegram client (#5)

This commit is contained in:
banteg
2025-12-30 23:22:40 +04:00
committed by GitHub
parent cb6de41e57
commit 6687a435c9
5 changed files with 342 additions and 226 deletions
+2 -2
View File
@@ -38,7 +38,7 @@ The orchestrator module containing:
| `CodexExecRunner` | Spawns `codex exec`, streams JSONL, handles cancellation | | `CodexExecRunner` | Spawns `codex exec`, streams JSONL, handles cancellation |
| `poll_updates()` | Async generator that drains backlog, long-polls updates, filters messages | | `poll_updates()` | Async generator that drains backlog, long-polls updates, filters messages |
| `_run_main_loop()` | TaskGroup-based main loop that spawns per-message handlers | | `_run_main_loop()` | TaskGroup-based main loop that spawns per-message handlers |
| `_handle_message()` | Per-message handler with progress updates | | `handle_message()` | Per-message handler with progress updates |
| `extract_session_id()` | Parses `resume: <uuid>` from message text | | `extract_session_id()` | Parses `resume: <uuid>` from message text |
| `truncate_for_telegram()` | Smart truncation preserving resume lines | | `truncate_for_telegram()` | Smart truncation preserving resume lines |
@@ -122,7 +122,7 @@ poll_updates() drains backlog, long-polls, filters chat_id == from_id == cfg.cha
_run_main_loop() spawns tasks in TaskGroup _run_main_loop() spawns tasks in TaskGroup
_handle_message() spawned as task handle_message() spawned as task
Send initial progress message (silent) Send initial progress message (silent)
+156 -137
View File
@@ -155,26 +155,17 @@ async def _send_or_edit_markdown(
reply_to_message_id: int | None = None, reply_to_message_id: int | None = None,
disable_notification: bool = False, disable_notification: bool = False,
limit: int = TELEGRAM_MARKDOWN_LIMIT, limit: int = TELEGRAM_MARKDOWN_LIMIT,
) -> tuple[dict[str, Any], bool]: ) -> tuple[dict[str, Any] | None, bool]:
if edit_message_id is not None: if edit_message_id is not None:
rendered, entities = prepare_telegram(text, limit=limit) rendered, entities = prepare_telegram(text, limit=limit)
try: edited = await bot.edit_message_text(
return ( chat_id=chat_id,
await bot.edit_message_text( message_id=edit_message_id,
chat_id=chat_id, text=rendered,
message_id=edit_message_id, entities=entities,
text=rendered, )
entities=entities, if edited is not None:
), return (edited, True)
True,
)
except Exception as e:
logger.info(
"[tg] edit failed chat_id=%s message_id=%s: %s",
chat_id,
edit_message_id,
e,
)
rendered, entities = prepare_telegram(text, limit=limit) rendered, entities = prepare_telegram(text, limit=limit)
return ( return (
@@ -192,6 +183,90 @@ async def _send_or_edit_markdown(
EventCallback = Callable[[dict[str, Any]], Awaitable[None] | None] EventCallback = Callable[[dict[str, Any]], Awaitable[None] | None]
class ProgressEdits:
def __init__(
self,
*,
bot: TelegramClient,
chat_id: int,
progress_id: int | None,
renderer: ExecProgressRenderer,
started_at: float,
progress_edit_every: float,
clock: Callable[[], float],
sleep: Callable[[float], Awaitable[None]],
limit: int,
last_edit_at: float,
last_rendered: str | None,
) -> None:
self.bot = bot
self.chat_id = chat_id
self.progress_id = progress_id
self.renderer = renderer
self.started_at = started_at
self.progress_edit_every = progress_edit_every
self.clock = clock
self.sleep = sleep
self.limit = limit
self.last_edit_at = last_edit_at
self.last_rendered = last_rendered
self._event_seq = 0
self._published_seq = 0
self.wakeup = asyncio.Event()
self.task: asyncio.Task[None] | None = (
asyncio.create_task(self.run()) if self.progress_id is not None else None
)
async def run(self) -> None:
if self.progress_id is None:
return
while True:
await self.wakeup.wait()
self.wakeup.clear()
while self._published_seq < self._event_seq:
await self.sleep(
max(
0.0,
self.last_edit_at + self.progress_edit_every - self.clock(),
)
)
seq_at_render = self._event_seq
now = self.clock()
md = self.renderer.render_progress(now - self.started_at)
rendered, entities = prepare_telegram(md, limit=self.limit)
if rendered != self.last_rendered:
logger.debug(
"[progress] edit message_id=%s md=%s", self.progress_id, md
)
self.last_edit_at = now
edited = await self.bot.edit_message_text(
chat_id=self.chat_id,
message_id=self.progress_id,
text=rendered,
entities=entities,
)
if edited is not None:
self.last_rendered = rendered
self._published_seq = seq_at_render
self.wakeup.clear()
async def on_event(self, evt: dict[str, Any]) -> None:
if not self.renderer.note_event(evt):
return
if self.progress_id is None:
return
self._event_seq += 1
self.wakeup.set()
async def shutdown(self) -> None:
if self.task is None:
return
self.task.cancel()
await asyncio.gather(self.task, return_exceptions=True)
class CodexExecRunner: class CodexExecRunner:
def __init__( def __init__(
self, self,
@@ -402,25 +477,20 @@ def _parse_bridge_config(
async def _send_startup(cfg: BridgeConfig) -> None: async def _send_startup(cfg: BridgeConfig) -> None:
try: logger.debug("[startup] message: %s", cfg.startup_msg)
logger.debug("[startup] message: %s", cfg.startup_msg) sent = await cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg)
await cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg) if sent is not None:
logger.info("[startup] sent startup message to chat_id=%s", cfg.chat_id) logger.info("[startup] sent startup message to chat_id=%s", cfg.chat_id)
except Exception as e:
logger.info(
"[startup] failed to send startup message to chat_id=%s: %s", cfg.chat_id, e
)
async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None: async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
drained = 0 drained = 0
while True: while True:
try: updates = await cfg.bot.get_updates(
updates = await cfg.bot.get_updates( offset=offset, timeout_s=0, allowed_updates=["message"]
offset=offset, timeout_s=0, allowed_updates=["message"] )
) if updates is None:
except Exception as e: logger.info("[startup] backlog drain failed")
logger.info("[startup] backlog drain failed: %s", e)
return offset return offset
logger.debug("[startup] backlog updates: %s", updates) logger.debug("[startup] backlog updates: %s", updates)
if not updates: if not updates:
@@ -431,7 +501,7 @@ async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
drained += len(updates) drained += len(updates)
async def _handle_message( async def handle_message(
cfg: BridgeConfig, cfg: BridgeConfig,
*, *,
chat_id: int, chat_id: int,
@@ -440,6 +510,7 @@ async def _handle_message(
resume_session: str | None, resume_session: str | None,
running_tasks: dict[str, asyncio.Task[Any]] | None = None, running_tasks: dict[str, asyncio.Task[Any]] | None = None,
clock: Callable[[], float] = time.monotonic, clock: Callable[[], float] = time.monotonic,
sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
progress_edit_every: float = PROGRESS_EDIT_EVERY_S, progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
) -> None: ) -> None:
logger.debug( logger.debug(
@@ -453,104 +524,57 @@ async def _handle_message(
progress_renderer = ExecProgressRenderer(max_actions=5) progress_renderer = ExecProgressRenderer(max_actions=5)
progress_id: int | None = None progress_id: int | None = None
last_edit_at = 0.0 last_edit_at = 0.0
edit_task: asyncio.Task[None] | None = None
last_rendered: str | None = None last_rendered: str | None = None
pending_rendered: str | None = None
async def _edit_progress( initial_md = progress_renderer.render_progress(0.0)
md: str, rendered: str, entities: list[dict[str, Any]] | None initial_rendered, initial_entities = prepare_telegram(
) -> None: initial_md, limit=TELEGRAM_MARKDOWN_LIMIT
nonlocal last_rendered, pending_rendered )
if progress_id is None: logger.debug(
return "[progress] send reply_to=%s md=%s rendered=%s entities=%s",
logger.debug( user_msg_id,
"[progress] edit message_id=%s md=%s rendered=%s entities=%s", initial_md,
progress_id, initial_rendered,
md, initial_entities,
rendered, )
entities, progress_msg = await cfg.bot.send_message(
) chat_id=chat_id,
try: text=initial_rendered,
await cfg.bot.edit_message_text( entities=initial_entities,
chat_id=chat_id, reply_to_message_id=user_msg_id,
message_id=progress_id, disable_notification=True,
text=rendered, )
entities=entities, if progress_msg is not None:
)
last_rendered = rendered
except Exception as e:
logger.info(
"[progress] edit failed chat_id=%s message_id=%s: %s",
chat_id,
progress_id,
e,
)
finally:
if pending_rendered == rendered:
pending_rendered = None
try:
initial_md = progress_renderer.render_progress(0.0)
initial_rendered, initial_entities = prepare_telegram(
initial_md, limit=TELEGRAM_MARKDOWN_LIMIT
)
logger.debug(
"[progress] send reply_to=%s md=%s rendered=%s entities=%s",
user_msg_id,
initial_md,
initial_rendered,
initial_entities,
)
progress_msg = await cfg.bot.send_message(
chat_id=chat_id,
text=initial_rendered,
entities=initial_entities,
reply_to_message_id=user_msg_id,
disable_notification=True,
)
progress_id = int(progress_msg["message_id"]) progress_id = int(progress_msg["message_id"])
last_edit_at = clock() last_edit_at = clock()
last_rendered = initial_rendered last_rendered = initial_rendered
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id) logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
except Exception as e:
logger.info( edits = ProgressEdits(
"[handle] failed to send progress message chat_id=%s: %s", chat_id, e bot=cfg.bot,
) chat_id=chat_id,
progress_id=progress_id,
renderer=progress_renderer,
started_at=started_at,
progress_edit_every=progress_edit_every,
clock=clock,
sleep=sleep,
limit=TELEGRAM_MARKDOWN_LIMIT,
last_edit_at=last_edit_at,
last_rendered=last_rendered,
)
exec_task: asyncio.Task[tuple[str, str, bool]] | None = None exec_task: asyncio.Task[tuple[str, str, bool]] | None = None
tracked_session_id: str | None = None
async def on_event(evt: dict[str, Any]) -> None: async def on_event(evt: dict[str, Any]) -> None:
nonlocal last_edit_at, edit_task, pending_rendered, tracked_session_id
if progress_id is None:
return
if not progress_renderer.note_event(evt):
return
if ( if (
evt["type"] == "thread.started" evt["type"] == "thread.started"
and running_tasks is not None and running_tasks is not None
and exec_task is not None and exec_task is not None
): ):
tracked_session_id = progress_renderer.resume_session running_tasks[evt["thread_id"]] = exec_task
if tracked_session_id: await edits.on_event(evt)
running_tasks[tracked_session_id] = exec_task
now = clock()
if (now - last_edit_at) < progress_edit_every:
return
if edit_task is not None and not edit_task.done():
return
elapsed = now - started_at
md = progress_renderer.render_progress(elapsed)
rendered, entities = prepare_telegram(md, limit=TELEGRAM_MARKDOWN_LIMIT)
if rendered == last_rendered or rendered == pending_rendered:
return
last_edit_at = now
pending_rendered = rendered
edit_task = asyncio.create_task(_edit_progress(md, rendered, entities))
exec_task = asyncio.create_task( exec_task = asyncio.create_task(
cfg.runner.run_serialized(text, resume_session, on_event=on_event) cfg.runner.run_serialized(text, resume_session, on_event=on_event)
@@ -561,10 +585,9 @@ async def _handle_message(
session_id, answer, saw_agent_message = await exec_task session_id, answer, saw_agent_message = await exec_task
except asyncio.CancelledError: except asyncio.CancelledError:
cancelled = True cancelled = True
session_id = tracked_session_id or resume_session session_id = progress_renderer.resume_session or resume_session
except Exception as e: except Exception as e:
if edit_task is not None: await edits.shutdown()
await asyncio.gather(edit_task, return_exceptions=True)
err = _clamp_tg_text(f"Error:\n{e}") err = _clamp_tg_text(f"Error:\n{e}")
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err) logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
@@ -579,14 +602,12 @@ async def _handle_message(
) )
return return
finally: finally:
if tracked_session_id and running_tasks is not None and exec_task is not None: if running_tasks is not None:
# Avoid removing a newer task for the same session_id if another run for sid, task in list(running_tasks.items()):
# registered while this one was finishing. if task is exec_task:
if running_tasks.get(tracked_session_id) is exec_task: running_tasks.pop(sid, None)
running_tasks.pop(tracked_session_id, None)
if edit_task is not None: await edits.shutdown()
await asyncio.gather(edit_task, return_exceptions=True)
elapsed = clock() - started_at elapsed = clock() - started_at
if cancelled: if cancelled:
@@ -631,7 +652,7 @@ async def _handle_message(
final_entities, final_entities,
) )
_, edited = await _send_or_edit_markdown( final_msg, edited = await _send_or_edit_markdown(
cfg.bot, cfg.bot,
chat_id=chat_id, chat_id=chat_id,
text=final_md, text=final_md,
@@ -640,12 +661,11 @@ async def _handle_message(
disable_notification=False, disable_notification=False,
limit=TELEGRAM_MARKDOWN_LIMIT, limit=TELEGRAM_MARKDOWN_LIMIT,
) )
if final_msg is None:
return
if progress_id is not None and (edit_message_id is None or not edited): if progress_id is not None and (edit_message_id is None or not edited):
try: logger.debug("[final] delete progress message_id=%s", progress_id)
logger.debug("[final] delete progress message_id=%s", progress_id) await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
except Exception:
pass
async def poll_updates(cfg: BridgeConfig): async def poll_updates(cfg: BridgeConfig):
@@ -654,12 +674,11 @@ async def poll_updates(cfg: BridgeConfig):
await _send_startup(cfg) await _send_startup(cfg)
while True: while True:
try: updates = await cfg.bot.get_updates(
updates = await cfg.bot.get_updates( offset=offset, timeout_s=50, allowed_updates=["message"]
offset=offset, timeout_s=50, allowed_updates=["message"] )
) if updates is None:
except Exception as e: logger.info("[loop] getUpdates failed")
logger.info("[loop] getUpdates failed: %s", e)
await asyncio.sleep(2) await asyncio.sleep(2)
continue continue
logger.debug("[loop] updates: %s", updates) logger.debug("[loop] updates: %s", updates)
@@ -724,7 +743,7 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
while True: while True:
chat_id, user_msg_id, text, resume_session = await queue.get() chat_id, user_msg_id, text, resume_session = await queue.get()
try: try:
await _handle_message( await handle_message(
cfg, cfg,
chat_id=chat_id, chat_id=chat_id,
user_msg_id=user_msg_id, user_msg_id=user_msg_id,
+43 -39
View File
@@ -1,8 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
import httpx import httpx
@@ -13,67 +11,73 @@ logger = logging.getLogger(__name__)
logger.addFilter(RedactTokenFilter()) logger.addFilter(RedactTokenFilter())
class TelegramAPIError(RuntimeError):
def __init__(
self, method: str, payload: dict[str, Any], status_code: int | None
) -> None:
desc = payload.get("description") or str(payload)
super().__init__(f"{method} failed: {desc}")
self.payload = payload
self.status_code = status_code
class TelegramClient: class TelegramClient:
def __init__( def __init__(
self, self,
token: str, token: str,
timeout_s: float = 120, timeout_s: float = 120,
client: httpx.AsyncClient | None = None, client: httpx.AsyncClient | None = None,
sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
) -> None: ) -> None:
if not token: if not token:
raise ValueError("Telegram token is empty") raise ValueError("Telegram token is empty")
self._base = f"https://api.telegram.org/bot{token}" self._base = f"https://api.telegram.org/bot{token}"
self._client = client or httpx.AsyncClient(timeout=timeout_s) self._client = client or httpx.AsyncClient(timeout=timeout_s)
self._owns_client = client is None self._owns_client = client is None
self._sleep = sleep
async def close(self) -> None: async def close(self) -> None:
if self._owns_client: if self._owns_client:
await self._client.aclose() await self._client.aclose()
async def _post(self, method: str, json_data: dict[str, Any]) -> Any: async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
logger.debug("[telegram] request %s: %s", method, json_data)
try: try:
logger.debug("[telegram] request %s: %s", method, json_data)
resp = await self._client.post(f"{self._base}/{method}", json=json_data) resp = await self._client.post(f"{self._base}/{method}", json=json_data)
payload: dict[str, Any] | None = None
try:
payload = resp.json()
except Exception:
resp.raise_for_status()
raise
if not payload.get("ok"):
params = payload.get("parameters") or {}
retry_after = params.get("retry_after")
if resp.status_code == 429 and isinstance(retry_after, int):
logger.warning(
"[telegram] 429 retry_after=%s method=%s", retry_after, method
)
await self._sleep(retry_after)
return await self._post(method, json_data)
raise TelegramAPIError(method, payload, resp.status_code)
logger.debug("[telegram] response %s: %s", method, payload)
return payload["result"]
except httpx.HTTPError as e: except httpx.HTTPError as e:
logger.error("Telegram network error: %s", e) url = getattr(e.request, "url", None)
raise logger.error(
"[telegram] network error method=%s url=%s: %s", method, url, e
)
return None
try:
payload = resp.json()
except Exception as e:
logger.error(
"[telegram] bad response method=%s status=%s url=%s: %s",
method,
resp.status_code,
resp.request.url,
e,
)
return None
if not isinstance(payload, dict):
logger.error(
"[telegram] invalid response method=%s url=%s: %r",
method,
resp.request.url,
payload,
)
return None
if not payload.get("ok"):
logger.error(
"[telegram] api error method=%s url=%s: %s",
method,
resp.request.url,
payload,
)
return None
logger.debug("[telegram] response %s: %s", method, payload)
return payload.get("result")
async def get_updates( async def get_updates(
self, self,
offset: int | None, offset: int | None,
timeout_s: int = 50, timeout_s: int = 50,
allowed_updates: list[str] | None = None, allowed_updates: list[str] | None = None,
) -> list[dict]: ) -> list[dict] | None:
params: dict[str, Any] = {"timeout": timeout_s} params: dict[str, Any] = {"timeout": timeout_s}
if offset is not None: if offset is not None:
params["offset"] = offset params["offset"] = offset
@@ -89,7 +93,7 @@ class TelegramClient:
disable_notification: bool | None = False, disable_notification: bool | None = False,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
) -> dict: ) -> dict | None:
params: dict[str, Any] = { params: dict[str, Any] = {
"chat_id": chat_id, "chat_id": chat_id,
"text": text, "text": text,
@@ -111,7 +115,7 @@ class TelegramClient:
text: str, text: str,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
) -> dict: ) -> dict | None:
params: dict[str, Any] = { params: dict[str, Any] = {
"chat_id": chat_id, "chat_id": chat_id,
"message_id": message_id, "message_id": message_id,
+129 -24
View File
@@ -186,12 +186,30 @@ class _FakeRunner:
class _FakeClock: class _FakeClock:
def __init__(self, start: float = 0.0) -> None: def __init__(self, start: float = 0.0) -> None:
self._now = start self._now = start
self._sleep_until: float | None = None
self._sleep_event: asyncio.Event | None = None
self.sleep_calls = 0
def __call__(self) -> float: def __call__(self) -> float:
return self._now return self._now
def set(self, value: float) -> None: def set(self, value: float) -> None:
self._now = value self._now = value
if self._sleep_until is None or self._sleep_event is None:
return
if self._sleep_until <= self._now:
self._sleep_event.set()
self._sleep_until = None
self._sleep_event = None
async def sleep(self, delay: float) -> None:
self.sleep_calls += 1
if delay <= 0:
await asyncio.sleep(0)
return
self._sleep_until = self._now + delay
self._sleep_event = asyncio.Event()
await self._sleep_event.wait()
class _FakeRunnerWithEvents: class _FakeRunnerWithEvents:
@@ -203,12 +221,16 @@ class _FakeRunnerWithEvents:
clock: _FakeClock, clock: _FakeClock,
answer: str = "ok", answer: str = "ok",
session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2", session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2",
advance_after: float | None = None,
hold: asyncio.Event | None = None,
) -> None: ) -> None:
self._events = events self._events = events
self._times = times self._times = times
self._clock = clock self._clock = clock
self._answer = answer self._answer = answer
self._session_id = session_id self._session_id = session_id
self._advance_after = advance_after
self._hold = hold
async def run_serialized(self, *_args, **kwargs) -> tuple[str, str, bool]: async def run_serialized(self, *_args, **kwargs) -> tuple[str, str, bool]:
on_event = kwargs.get("on_event") on_event = kwargs.get("on_event")
@@ -217,11 +239,16 @@ class _FakeRunnerWithEvents:
self._clock.set(when) self._clock.set(when)
await on_event(event) await on_event(event)
await asyncio.sleep(0) await asyncio.sleep(0)
if self._advance_after is not None:
self._clock.set(self._advance_after)
await asyncio.sleep(0)
if self._hold is not None:
await self._hold.wait()
return (self._session_id, self._answer, True) return (self._session_id, self._answer, True)
def test_final_notify_sends_loud_final_message() -> None: def test_final_notify_sends_loud_final_message() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_message from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot() bot = _FakeBot()
runner = _FakeRunner(answer="ok") runner = _FakeRunner(answer="ok")
@@ -235,7 +262,7 @@ def test_final_notify_sends_loud_final_message() -> None:
) )
asyncio.run( asyncio.run(
_handle_message( handle_message(
cfg, cfg,
chat_id=123, chat_id=123,
user_msg_id=10, user_msg_id=10,
@@ -250,7 +277,7 @@ def test_final_notify_sends_loud_final_message() -> None:
def test_new_final_message_forces_notification_when_too_long_to_edit() -> None: def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_message from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot() bot = _FakeBot()
runner = _FakeRunner(answer="x" * 10_000) runner = _FakeRunner(answer="x" * 10_000)
@@ -264,7 +291,7 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
) )
asyncio.run( asyncio.run(
_handle_message( handle_message(
cfg, cfg,
chat_id=123, chat_id=123,
user_msg_id=10, user_msg_id=10,
@@ -279,7 +306,7 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
def test_progress_edits_are_rate_limited() -> None: def test_progress_edits_are_rate_limited() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_message from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot() bot = _FakeBot()
clock = _FakeClock() clock = _FakeClock()
@@ -294,13 +321,61 @@ def test_progress_edits_are_rate_limited() -> None:
}, },
}, },
{ {
"type": "item.completed", "type": "item.started",
"item": {
"id": "item_1",
"type": "command_execution",
"command": "echo 2",
"status": "in_progress",
},
},
]
runner = _FakeRunnerWithEvents(
events=events,
times=[0.2, 0.4],
clock=clock,
advance_after=1.0,
)
cfg = BridgeConfig(
bot=bot, # type: ignore[arg-type]
runner=runner, # type: ignore[arg-type]
chat_id=123,
final_notify=True,
startup_msg="",
max_concurrency=1,
)
asyncio.run(
handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
clock=clock,
sleep=clock.sleep,
progress_edit_every=1.0,
)
)
assert len(bot.edit_calls) == 1
assert "echo 2" in bot.edit_calls[0]["text"]
def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
clock = _FakeClock()
hold = asyncio.Event()
events = [
{
"type": "item.started",
"item": { "item": {
"id": "item_0", "id": "item_0",
"type": "command_execution", "type": "command_execution",
"command": "echo 1", "command": "echo 1",
"exit_code": 0, "status": "in_progress",
"status": "completed",
}, },
}, },
{ {
@@ -315,8 +390,10 @@ def test_progress_edits_are_rate_limited() -> None:
] ]
runner = _FakeRunnerWithEvents( runner = _FakeRunnerWithEvents(
events=events, events=events,
times=[0.2, 0.4, 1.2], times=[0.2, 0.4],
clock=clock, clock=clock,
advance_after=None,
hold=hold,
) )
cfg = BridgeConfig( cfg = BridgeConfig(
bot=bot, # type: ignore[arg-type] bot=bot, # type: ignore[arg-type]
@@ -327,23 +404,50 @@ def test_progress_edits_are_rate_limited() -> None:
max_concurrency=1, max_concurrency=1,
) )
asyncio.run( async def run_test() -> None:
_handle_message( task = asyncio.create_task(
cfg, handle_message(
chat_id=123, cfg,
user_msg_id=10, chat_id=123,
text="hi", user_msg_id=10,
resume_session=None, text="hi",
clock=clock, resume_session=None,
progress_edit_every=1.0, clock=clock,
sleep=clock.sleep,
progress_edit_every=1.0,
)
) )
)
assert len(bot.edit_calls) == 1 for _ in range(100):
if clock._sleep_until is not None:
break
await asyncio.sleep(0)
assert clock._sleep_until == pytest.approx(1.0)
clock.set(1.0)
for _ in range(100):
if bot.edit_calls:
break
await asyncio.sleep(0)
assert len(bot.edit_calls) == 1
for _ in range(5):
await asyncio.sleep(0)
assert clock.sleep_calls == 1
assert clock._sleep_until is None
hold.set()
await task
asyncio.run(run_test())
def test_bridge_flow_sends_progress_edits_and_final_resume() -> None: def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_message from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot() bot = _FakeBot()
clock = _FakeClock() clock = _FakeClock()
@@ -386,13 +490,14 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
) )
asyncio.run( asyncio.run(
_handle_message( handle_message(
cfg, cfg,
chat_id=123, chat_id=123,
user_msg_id=42, user_msg_id=42,
text="do it", text="do it",
resume_session=None, resume_session=None,
clock=clock, clock=clock,
sleep=clock.sleep,
progress_edit_every=1.0, progress_edit_every=1.0,
) )
) )
@@ -529,7 +634,7 @@ class _FakeRunnerCancellable:
def test_handle_message_cancelled_renders_cancelled_state() -> None: def test_handle_message_cancelled_renders_cancelled_state() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_message from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot() bot = _FakeBot()
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2" session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
@@ -546,7 +651,7 @@ def test_handle_message_cancelled_renders_cancelled_state() -> None:
async def run_test(): async def run_test():
task = asyncio.create_task( task = asyncio.create_task(
_handle_message( handle_message(
cfg, cfg,
chat_id=123, chat_id=123,
user_msg_id=10, user_msg_id=10,
+12 -24
View File
@@ -8,46 +8,35 @@ from takopi.logging import RedactTokenFilter
from takopi.telegram import TelegramClient from takopi.telegram import TelegramClient
def test_telegram_429_retry_after_calls_sleep() -> None: def test_telegram_429_no_retry() -> None:
calls: list[int] = [] calls: list[int] = []
sleeps: list[float] = []
async def fake_sleep(seconds: float) -> None:
sleeps.append(seconds)
def handler(request: httpx.Request) -> httpx.Response: def handler(request: httpx.Request) -> httpx.Response:
calls.append(1) calls.append(1)
if len(calls) == 1:
return httpx.Response(
429,
json={
"ok": False,
"description": "retry",
"parameters": {"retry_after": 3},
},
request=request,
)
return httpx.Response( return httpx.Response(
200, 429,
json={"ok": True, "result": {"message_id": 1}}, json={
"ok": False,
"description": "retry",
"parameters": {"retry_after": 3},
},
request=request, request=request,
) )
transport = httpx.MockTransport(handler) transport = httpx.MockTransport(handler)
async def run() -> dict: async def run() -> dict | None:
client = httpx.AsyncClient(transport=transport) client = httpx.AsyncClient(transport=transport)
try: try:
tg = TelegramClient("123:abcDEF_ghij", client=client, sleep=fake_sleep) tg = TelegramClient("123:abcDEF_ghij", client=client)
return await tg._post("sendMessage", {"chat_id": 1, "text": "hi"}) return await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
finally: finally:
await client.aclose() await client.aclose()
result = asyncio.run(run()) result = asyncio.run(run())
assert result == {"message_id": 1} assert result is None
assert sleeps == [3] assert len(calls) == 1
assert len(calls) == 2
def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> None: def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> None:
@@ -70,8 +59,7 @@ def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> Non
await client.aclose() await client.aclose()
caplog.set_level(logging.ERROR) caplog.set_level(logging.ERROR)
with pytest.raises(httpx.HTTPStatusError): asyncio.run(run())
asyncio.run(run())
root_logger.removeFilter(redactor) root_logger.removeFilter(redactor)