feat: better progress edits, simpler telegram client (#5)
This commit is contained in:
+2
-2
@@ -38,7 +38,7 @@ The orchestrator module containing:
|
|||||||
| `CodexExecRunner` | Spawns `codex exec`, streams JSONL, handles cancellation |
|
| `CodexExecRunner` | Spawns `codex exec`, streams JSONL, handles cancellation |
|
||||||
| `poll_updates()` | Async generator that drains backlog, long-polls updates, filters messages |
|
| `poll_updates()` | Async generator that drains backlog, long-polls updates, filters messages |
|
||||||
| `_run_main_loop()` | TaskGroup-based main loop that spawns per-message handlers |
|
| `_run_main_loop()` | TaskGroup-based main loop that spawns per-message handlers |
|
||||||
| `_handle_message()` | Per-message handler with progress updates |
|
| `handle_message()` | Per-message handler with progress updates |
|
||||||
| `extract_session_id()` | Parses `resume: <uuid>` from message text |
|
| `extract_session_id()` | Parses `resume: <uuid>` from message text |
|
||||||
| `truncate_for_telegram()` | Smart truncation preserving resume lines |
|
| `truncate_for_telegram()` | Smart truncation preserving resume lines |
|
||||||
|
|
||||||
@@ -122,7 +122,7 @@ poll_updates() drains backlog, long-polls, filters chat_id == from_id == cfg.cha
|
|||||||
↓
|
↓
|
||||||
_run_main_loop() spawns tasks in TaskGroup
|
_run_main_loop() spawns tasks in TaskGroup
|
||||||
↓
|
↓
|
||||||
_handle_message() spawned as task
|
handle_message() spawned as task
|
||||||
↓
|
↓
|
||||||
Send initial progress message (silent)
|
Send initial progress message (silent)
|
||||||
↓
|
↓
|
||||||
|
|||||||
+156
-137
@@ -155,26 +155,17 @@ async def _send_or_edit_markdown(
|
|||||||
reply_to_message_id: int | None = None,
|
reply_to_message_id: int | None = None,
|
||||||
disable_notification: bool = False,
|
disable_notification: bool = False,
|
||||||
limit: int = TELEGRAM_MARKDOWN_LIMIT,
|
limit: int = TELEGRAM_MARKDOWN_LIMIT,
|
||||||
) -> tuple[dict[str, Any], bool]:
|
) -> tuple[dict[str, Any] | None, bool]:
|
||||||
if edit_message_id is not None:
|
if edit_message_id is not None:
|
||||||
rendered, entities = prepare_telegram(text, limit=limit)
|
rendered, entities = prepare_telegram(text, limit=limit)
|
||||||
try:
|
edited = await bot.edit_message_text(
|
||||||
return (
|
chat_id=chat_id,
|
||||||
await bot.edit_message_text(
|
message_id=edit_message_id,
|
||||||
chat_id=chat_id,
|
text=rendered,
|
||||||
message_id=edit_message_id,
|
entities=entities,
|
||||||
text=rendered,
|
)
|
||||||
entities=entities,
|
if edited is not None:
|
||||||
),
|
return (edited, True)
|
||||||
True,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(
|
|
||||||
"[tg] edit failed chat_id=%s message_id=%s: %s",
|
|
||||||
chat_id,
|
|
||||||
edit_message_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
|
|
||||||
rendered, entities = prepare_telegram(text, limit=limit)
|
rendered, entities = prepare_telegram(text, limit=limit)
|
||||||
return (
|
return (
|
||||||
@@ -192,6 +183,90 @@ async def _send_or_edit_markdown(
|
|||||||
EventCallback = Callable[[dict[str, Any]], Awaitable[None] | None]
|
EventCallback = Callable[[dict[str, Any]], Awaitable[None] | None]
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressEdits:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
bot: TelegramClient,
|
||||||
|
chat_id: int,
|
||||||
|
progress_id: int | None,
|
||||||
|
renderer: ExecProgressRenderer,
|
||||||
|
started_at: float,
|
||||||
|
progress_edit_every: float,
|
||||||
|
clock: Callable[[], float],
|
||||||
|
sleep: Callable[[float], Awaitable[None]],
|
||||||
|
limit: int,
|
||||||
|
last_edit_at: float,
|
||||||
|
last_rendered: str | None,
|
||||||
|
) -> None:
|
||||||
|
self.bot = bot
|
||||||
|
self.chat_id = chat_id
|
||||||
|
self.progress_id = progress_id
|
||||||
|
self.renderer = renderer
|
||||||
|
self.started_at = started_at
|
||||||
|
self.progress_edit_every = progress_edit_every
|
||||||
|
self.clock = clock
|
||||||
|
self.sleep = sleep
|
||||||
|
self.limit = limit
|
||||||
|
self.last_edit_at = last_edit_at
|
||||||
|
self.last_rendered = last_rendered
|
||||||
|
self._event_seq = 0
|
||||||
|
self._published_seq = 0
|
||||||
|
self.wakeup = asyncio.Event()
|
||||||
|
self.task: asyncio.Task[None] | None = (
|
||||||
|
asyncio.create_task(self.run()) if self.progress_id is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
if self.progress_id is None:
|
||||||
|
return
|
||||||
|
while True:
|
||||||
|
await self.wakeup.wait()
|
||||||
|
self.wakeup.clear()
|
||||||
|
while self._published_seq < self._event_seq:
|
||||||
|
await self.sleep(
|
||||||
|
max(
|
||||||
|
0.0,
|
||||||
|
self.last_edit_at + self.progress_edit_every - self.clock(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_at_render = self._event_seq
|
||||||
|
now = self.clock()
|
||||||
|
md = self.renderer.render_progress(now - self.started_at)
|
||||||
|
rendered, entities = prepare_telegram(md, limit=self.limit)
|
||||||
|
if rendered != self.last_rendered:
|
||||||
|
logger.debug(
|
||||||
|
"[progress] edit message_id=%s md=%s", self.progress_id, md
|
||||||
|
)
|
||||||
|
self.last_edit_at = now
|
||||||
|
edited = await self.bot.edit_message_text(
|
||||||
|
chat_id=self.chat_id,
|
||||||
|
message_id=self.progress_id,
|
||||||
|
text=rendered,
|
||||||
|
entities=entities,
|
||||||
|
)
|
||||||
|
if edited is not None:
|
||||||
|
self.last_rendered = rendered
|
||||||
|
|
||||||
|
self._published_seq = seq_at_render
|
||||||
|
self.wakeup.clear()
|
||||||
|
|
||||||
|
async def on_event(self, evt: dict[str, Any]) -> None:
|
||||||
|
if not self.renderer.note_event(evt):
|
||||||
|
return
|
||||||
|
if self.progress_id is None:
|
||||||
|
return
|
||||||
|
self._event_seq += 1
|
||||||
|
self.wakeup.set()
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
if self.task is None:
|
||||||
|
return
|
||||||
|
self.task.cancel()
|
||||||
|
await asyncio.gather(self.task, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
class CodexExecRunner:
|
class CodexExecRunner:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -402,25 +477,20 @@ def _parse_bridge_config(
|
|||||||
|
|
||||||
|
|
||||||
async def _send_startup(cfg: BridgeConfig) -> None:
|
async def _send_startup(cfg: BridgeConfig) -> None:
|
||||||
try:
|
logger.debug("[startup] message: %s", cfg.startup_msg)
|
||||||
logger.debug("[startup] message: %s", cfg.startup_msg)
|
sent = await cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg)
|
||||||
await cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg)
|
if sent is not None:
|
||||||
logger.info("[startup] sent startup message to chat_id=%s", cfg.chat_id)
|
logger.info("[startup] sent startup message to chat_id=%s", cfg.chat_id)
|
||||||
except Exception as e:
|
|
||||||
logger.info(
|
|
||||||
"[startup] failed to send startup message to chat_id=%s: %s", cfg.chat_id, e
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
||||||
drained = 0
|
drained = 0
|
||||||
while True:
|
while True:
|
||||||
try:
|
updates = await cfg.bot.get_updates(
|
||||||
updates = await cfg.bot.get_updates(
|
offset=offset, timeout_s=0, allowed_updates=["message"]
|
||||||
offset=offset, timeout_s=0, allowed_updates=["message"]
|
)
|
||||||
)
|
if updates is None:
|
||||||
except Exception as e:
|
logger.info("[startup] backlog drain failed")
|
||||||
logger.info("[startup] backlog drain failed: %s", e)
|
|
||||||
return offset
|
return offset
|
||||||
logger.debug("[startup] backlog updates: %s", updates)
|
logger.debug("[startup] backlog updates: %s", updates)
|
||||||
if not updates:
|
if not updates:
|
||||||
@@ -431,7 +501,7 @@ async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
|||||||
drained += len(updates)
|
drained += len(updates)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_message(
|
async def handle_message(
|
||||||
cfg: BridgeConfig,
|
cfg: BridgeConfig,
|
||||||
*,
|
*,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
@@ -440,6 +510,7 @@ async def _handle_message(
|
|||||||
resume_session: str | None,
|
resume_session: str | None,
|
||||||
running_tasks: dict[str, asyncio.Task[Any]] | None = None,
|
running_tasks: dict[str, asyncio.Task[Any]] | None = None,
|
||||||
clock: Callable[[], float] = time.monotonic,
|
clock: Callable[[], float] = time.monotonic,
|
||||||
|
sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
|
||||||
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
|
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -453,104 +524,57 @@ async def _handle_message(
|
|||||||
progress_renderer = ExecProgressRenderer(max_actions=5)
|
progress_renderer = ExecProgressRenderer(max_actions=5)
|
||||||
|
|
||||||
progress_id: int | None = None
|
progress_id: int | None = None
|
||||||
|
|
||||||
last_edit_at = 0.0
|
last_edit_at = 0.0
|
||||||
edit_task: asyncio.Task[None] | None = None
|
|
||||||
last_rendered: str | None = None
|
last_rendered: str | None = None
|
||||||
pending_rendered: str | None = None
|
|
||||||
|
|
||||||
async def _edit_progress(
|
initial_md = progress_renderer.render_progress(0.0)
|
||||||
md: str, rendered: str, entities: list[dict[str, Any]] | None
|
initial_rendered, initial_entities = prepare_telegram(
|
||||||
) -> None:
|
initial_md, limit=TELEGRAM_MARKDOWN_LIMIT
|
||||||
nonlocal last_rendered, pending_rendered
|
)
|
||||||
if progress_id is None:
|
logger.debug(
|
||||||
return
|
"[progress] send reply_to=%s md=%s rendered=%s entities=%s",
|
||||||
logger.debug(
|
user_msg_id,
|
||||||
"[progress] edit message_id=%s md=%s rendered=%s entities=%s",
|
initial_md,
|
||||||
progress_id,
|
initial_rendered,
|
||||||
md,
|
initial_entities,
|
||||||
rendered,
|
)
|
||||||
entities,
|
progress_msg = await cfg.bot.send_message(
|
||||||
)
|
chat_id=chat_id,
|
||||||
try:
|
text=initial_rendered,
|
||||||
await cfg.bot.edit_message_text(
|
entities=initial_entities,
|
||||||
chat_id=chat_id,
|
reply_to_message_id=user_msg_id,
|
||||||
message_id=progress_id,
|
disable_notification=True,
|
||||||
text=rendered,
|
)
|
||||||
entities=entities,
|
if progress_msg is not None:
|
||||||
)
|
|
||||||
last_rendered = rendered
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(
|
|
||||||
"[progress] edit failed chat_id=%s message_id=%s: %s",
|
|
||||||
chat_id,
|
|
||||||
progress_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if pending_rendered == rendered:
|
|
||||||
pending_rendered = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
initial_md = progress_renderer.render_progress(0.0)
|
|
||||||
initial_rendered, initial_entities = prepare_telegram(
|
|
||||||
initial_md, limit=TELEGRAM_MARKDOWN_LIMIT
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"[progress] send reply_to=%s md=%s rendered=%s entities=%s",
|
|
||||||
user_msg_id,
|
|
||||||
initial_md,
|
|
||||||
initial_rendered,
|
|
||||||
initial_entities,
|
|
||||||
)
|
|
||||||
progress_msg = await cfg.bot.send_message(
|
|
||||||
chat_id=chat_id,
|
|
||||||
text=initial_rendered,
|
|
||||||
entities=initial_entities,
|
|
||||||
reply_to_message_id=user_msg_id,
|
|
||||||
disable_notification=True,
|
|
||||||
)
|
|
||||||
progress_id = int(progress_msg["message_id"])
|
progress_id = int(progress_msg["message_id"])
|
||||||
last_edit_at = clock()
|
last_edit_at = clock()
|
||||||
last_rendered = initial_rendered
|
last_rendered = initial_rendered
|
||||||
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
|
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
|
||||||
except Exception as e:
|
|
||||||
logger.info(
|
edits = ProgressEdits(
|
||||||
"[handle] failed to send progress message chat_id=%s: %s", chat_id, e
|
bot=cfg.bot,
|
||||||
)
|
chat_id=chat_id,
|
||||||
|
progress_id=progress_id,
|
||||||
|
renderer=progress_renderer,
|
||||||
|
started_at=started_at,
|
||||||
|
progress_edit_every=progress_edit_every,
|
||||||
|
clock=clock,
|
||||||
|
sleep=sleep,
|
||||||
|
limit=TELEGRAM_MARKDOWN_LIMIT,
|
||||||
|
last_edit_at=last_edit_at,
|
||||||
|
last_rendered=last_rendered,
|
||||||
|
)
|
||||||
|
|
||||||
exec_task: asyncio.Task[tuple[str, str, bool]] | None = None
|
exec_task: asyncio.Task[tuple[str, str, bool]] | None = None
|
||||||
tracked_session_id: str | None = None
|
|
||||||
|
|
||||||
async def on_event(evt: dict[str, Any]) -> None:
|
async def on_event(evt: dict[str, Any]) -> None:
|
||||||
nonlocal last_edit_at, edit_task, pending_rendered, tracked_session_id
|
|
||||||
if progress_id is None:
|
|
||||||
return
|
|
||||||
if not progress_renderer.note_event(evt):
|
|
||||||
return
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
evt["type"] == "thread.started"
|
evt["type"] == "thread.started"
|
||||||
and running_tasks is not None
|
and running_tasks is not None
|
||||||
and exec_task is not None
|
and exec_task is not None
|
||||||
):
|
):
|
||||||
tracked_session_id = progress_renderer.resume_session
|
running_tasks[evt["thread_id"]] = exec_task
|
||||||
if tracked_session_id:
|
await edits.on_event(evt)
|
||||||
running_tasks[tracked_session_id] = exec_task
|
|
||||||
|
|
||||||
now = clock()
|
|
||||||
if (now - last_edit_at) < progress_edit_every:
|
|
||||||
return
|
|
||||||
if edit_task is not None and not edit_task.done():
|
|
||||||
return
|
|
||||||
elapsed = now - started_at
|
|
||||||
md = progress_renderer.render_progress(elapsed)
|
|
||||||
rendered, entities = prepare_telegram(md, limit=TELEGRAM_MARKDOWN_LIMIT)
|
|
||||||
if rendered == last_rendered or rendered == pending_rendered:
|
|
||||||
return
|
|
||||||
last_edit_at = now
|
|
||||||
pending_rendered = rendered
|
|
||||||
edit_task = asyncio.create_task(_edit_progress(md, rendered, entities))
|
|
||||||
|
|
||||||
exec_task = asyncio.create_task(
|
exec_task = asyncio.create_task(
|
||||||
cfg.runner.run_serialized(text, resume_session, on_event=on_event)
|
cfg.runner.run_serialized(text, resume_session, on_event=on_event)
|
||||||
@@ -561,10 +585,9 @@ async def _handle_message(
|
|||||||
session_id, answer, saw_agent_message = await exec_task
|
session_id, answer, saw_agent_message = await exec_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
cancelled = True
|
cancelled = True
|
||||||
session_id = tracked_session_id or resume_session
|
session_id = progress_renderer.resume_session or resume_session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if edit_task is not None:
|
await edits.shutdown()
|
||||||
await asyncio.gather(edit_task, return_exceptions=True)
|
|
||||||
|
|
||||||
err = _clamp_tg_text(f"Error:\n{e}")
|
err = _clamp_tg_text(f"Error:\n{e}")
|
||||||
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
|
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
|
||||||
@@ -579,14 +602,12 @@ async def _handle_message(
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
finally:
|
finally:
|
||||||
if tracked_session_id and running_tasks is not None and exec_task is not None:
|
if running_tasks is not None:
|
||||||
# Avoid removing a newer task for the same session_id if another run
|
for sid, task in list(running_tasks.items()):
|
||||||
# registered while this one was finishing.
|
if task is exec_task:
|
||||||
if running_tasks.get(tracked_session_id) is exec_task:
|
running_tasks.pop(sid, None)
|
||||||
running_tasks.pop(tracked_session_id, None)
|
|
||||||
|
|
||||||
if edit_task is not None:
|
await edits.shutdown()
|
||||||
await asyncio.gather(edit_task, return_exceptions=True)
|
|
||||||
|
|
||||||
elapsed = clock() - started_at
|
elapsed = clock() - started_at
|
||||||
if cancelled:
|
if cancelled:
|
||||||
@@ -631,7 +652,7 @@ async def _handle_message(
|
|||||||
final_entities,
|
final_entities,
|
||||||
)
|
)
|
||||||
|
|
||||||
_, edited = await _send_or_edit_markdown(
|
final_msg, edited = await _send_or_edit_markdown(
|
||||||
cfg.bot,
|
cfg.bot,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
text=final_md,
|
text=final_md,
|
||||||
@@ -640,12 +661,11 @@ async def _handle_message(
|
|||||||
disable_notification=False,
|
disable_notification=False,
|
||||||
limit=TELEGRAM_MARKDOWN_LIMIT,
|
limit=TELEGRAM_MARKDOWN_LIMIT,
|
||||||
)
|
)
|
||||||
|
if final_msg is None:
|
||||||
|
return
|
||||||
if progress_id is not None and (edit_message_id is None or not edited):
|
if progress_id is not None and (edit_message_id is None or not edited):
|
||||||
try:
|
logger.debug("[final] delete progress message_id=%s", progress_id)
|
||||||
logger.debug("[final] delete progress message_id=%s", progress_id)
|
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
|
||||||
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def poll_updates(cfg: BridgeConfig):
|
async def poll_updates(cfg: BridgeConfig):
|
||||||
@@ -654,12 +674,11 @@ async def poll_updates(cfg: BridgeConfig):
|
|||||||
await _send_startup(cfg)
|
await _send_startup(cfg)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
updates = await cfg.bot.get_updates(
|
||||||
updates = await cfg.bot.get_updates(
|
offset=offset, timeout_s=50, allowed_updates=["message"]
|
||||||
offset=offset, timeout_s=50, allowed_updates=["message"]
|
)
|
||||||
)
|
if updates is None:
|
||||||
except Exception as e:
|
logger.info("[loop] getUpdates failed")
|
||||||
logger.info("[loop] getUpdates failed: %s", e)
|
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
continue
|
continue
|
||||||
logger.debug("[loop] updates: %s", updates)
|
logger.debug("[loop] updates: %s", updates)
|
||||||
@@ -724,7 +743,7 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
|
|||||||
while True:
|
while True:
|
||||||
chat_id, user_msg_id, text, resume_session = await queue.get()
|
chat_id, user_msg_id, text, resume_session = await queue.get()
|
||||||
try:
|
try:
|
||||||
await _handle_message(
|
await handle_message(
|
||||||
cfg,
|
cfg,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
user_msg_id=user_msg_id,
|
user_msg_id=user_msg_id,
|
||||||
|
|||||||
+43
-39
@@ -1,8 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -13,67 +11,73 @@ logger = logging.getLogger(__name__)
|
|||||||
logger.addFilter(RedactTokenFilter())
|
logger.addFilter(RedactTokenFilter())
|
||||||
|
|
||||||
|
|
||||||
class TelegramAPIError(RuntimeError):
|
|
||||||
def __init__(
|
|
||||||
self, method: str, payload: dict[str, Any], status_code: int | None
|
|
||||||
) -> None:
|
|
||||||
desc = payload.get("description") or str(payload)
|
|
||||||
super().__init__(f"{method} failed: {desc}")
|
|
||||||
self.payload = payload
|
|
||||||
self.status_code = status_code
|
|
||||||
|
|
||||||
|
|
||||||
class TelegramClient:
|
class TelegramClient:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
timeout_s: float = 120,
|
timeout_s: float = 120,
|
||||||
client: httpx.AsyncClient | None = None,
|
client: httpx.AsyncClient | None = None,
|
||||||
sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if not token:
|
if not token:
|
||||||
raise ValueError("Telegram token is empty")
|
raise ValueError("Telegram token is empty")
|
||||||
self._base = f"https://api.telegram.org/bot{token}"
|
self._base = f"https://api.telegram.org/bot{token}"
|
||||||
self._client = client or httpx.AsyncClient(timeout=timeout_s)
|
self._client = client or httpx.AsyncClient(timeout=timeout_s)
|
||||||
self._owns_client = client is None
|
self._owns_client = client is None
|
||||||
self._sleep = sleep
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if self._owns_client:
|
if self._owns_client:
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
|
|
||||||
async def _post(self, method: str, json_data: dict[str, Any]) -> Any:
|
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
|
||||||
|
logger.debug("[telegram] request %s: %s", method, json_data)
|
||||||
try:
|
try:
|
||||||
logger.debug("[telegram] request %s: %s", method, json_data)
|
|
||||||
resp = await self._client.post(f"{self._base}/{method}", json=json_data)
|
resp = await self._client.post(f"{self._base}/{method}", json=json_data)
|
||||||
payload: dict[str, Any] | None = None
|
|
||||||
try:
|
|
||||||
payload = resp.json()
|
|
||||||
except Exception:
|
|
||||||
resp.raise_for_status()
|
|
||||||
raise
|
|
||||||
if not payload.get("ok"):
|
|
||||||
params = payload.get("parameters") or {}
|
|
||||||
retry_after = params.get("retry_after")
|
|
||||||
if resp.status_code == 429 and isinstance(retry_after, int):
|
|
||||||
logger.warning(
|
|
||||||
"[telegram] 429 retry_after=%s method=%s", retry_after, method
|
|
||||||
)
|
|
||||||
await self._sleep(retry_after)
|
|
||||||
return await self._post(method, json_data)
|
|
||||||
raise TelegramAPIError(method, payload, resp.status_code)
|
|
||||||
logger.debug("[telegram] response %s: %s", method, payload)
|
|
||||||
return payload["result"]
|
|
||||||
except httpx.HTTPError as e:
|
except httpx.HTTPError as e:
|
||||||
logger.error("Telegram network error: %s", e)
|
url = getattr(e.request, "url", None)
|
||||||
raise
|
logger.error(
|
||||||
|
"[telegram] network error method=%s url=%s: %s", method, url, e
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = resp.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"[telegram] bad response method=%s status=%s url=%s: %s",
|
||||||
|
method,
|
||||||
|
resp.status_code,
|
||||||
|
resp.request.url,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
logger.error(
|
||||||
|
"[telegram] invalid response method=%s url=%s: %r",
|
||||||
|
method,
|
||||||
|
resp.request.url,
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not payload.get("ok"):
|
||||||
|
logger.error(
|
||||||
|
"[telegram] api error method=%s url=%s: %s",
|
||||||
|
method,
|
||||||
|
resp.request.url,
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug("[telegram] response %s: %s", method, payload)
|
||||||
|
return payload.get("result")
|
||||||
|
|
||||||
async def get_updates(
|
async def get_updates(
|
||||||
self,
|
self,
|
||||||
offset: int | None,
|
offset: int | None,
|
||||||
timeout_s: int = 50,
|
timeout_s: int = 50,
|
||||||
allowed_updates: list[str] | None = None,
|
allowed_updates: list[str] | None = None,
|
||||||
) -> list[dict]:
|
) -> list[dict] | None:
|
||||||
params: dict[str, Any] = {"timeout": timeout_s}
|
params: dict[str, Any] = {"timeout": timeout_s}
|
||||||
if offset is not None:
|
if offset is not None:
|
||||||
params["offset"] = offset
|
params["offset"] = offset
|
||||||
@@ -89,7 +93,7 @@ class TelegramClient:
|
|||||||
disable_notification: bool | None = False,
|
disable_notification: bool | None = False,
|
||||||
entities: list[dict] | None = None,
|
entities: list[dict] | None = None,
|
||||||
parse_mode: str | None = None,
|
parse_mode: str | None = None,
|
||||||
) -> dict:
|
) -> dict | None:
|
||||||
params: dict[str, Any] = {
|
params: dict[str, Any] = {
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
"text": text,
|
"text": text,
|
||||||
@@ -111,7 +115,7 @@ class TelegramClient:
|
|||||||
text: str,
|
text: str,
|
||||||
entities: list[dict] | None = None,
|
entities: list[dict] | None = None,
|
||||||
parse_mode: str | None = None,
|
parse_mode: str | None = None,
|
||||||
) -> dict:
|
) -> dict | None:
|
||||||
params: dict[str, Any] = {
|
params: dict[str, Any] = {
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
|
|||||||
+129
-24
@@ -186,12 +186,30 @@ class _FakeRunner:
|
|||||||
class _FakeClock:
|
class _FakeClock:
|
||||||
def __init__(self, start: float = 0.0) -> None:
|
def __init__(self, start: float = 0.0) -> None:
|
||||||
self._now = start
|
self._now = start
|
||||||
|
self._sleep_until: float | None = None
|
||||||
|
self._sleep_event: asyncio.Event | None = None
|
||||||
|
self.sleep_calls = 0
|
||||||
|
|
||||||
def __call__(self) -> float:
|
def __call__(self) -> float:
|
||||||
return self._now
|
return self._now
|
||||||
|
|
||||||
def set(self, value: float) -> None:
|
def set(self, value: float) -> None:
|
||||||
self._now = value
|
self._now = value
|
||||||
|
if self._sleep_until is None or self._sleep_event is None:
|
||||||
|
return
|
||||||
|
if self._sleep_until <= self._now:
|
||||||
|
self._sleep_event.set()
|
||||||
|
self._sleep_until = None
|
||||||
|
self._sleep_event = None
|
||||||
|
|
||||||
|
async def sleep(self, delay: float) -> None:
|
||||||
|
self.sleep_calls += 1
|
||||||
|
if delay <= 0:
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
return
|
||||||
|
self._sleep_until = self._now + delay
|
||||||
|
self._sleep_event = asyncio.Event()
|
||||||
|
await self._sleep_event.wait()
|
||||||
|
|
||||||
|
|
||||||
class _FakeRunnerWithEvents:
|
class _FakeRunnerWithEvents:
|
||||||
@@ -203,12 +221,16 @@ class _FakeRunnerWithEvents:
|
|||||||
clock: _FakeClock,
|
clock: _FakeClock,
|
||||||
answer: str = "ok",
|
answer: str = "ok",
|
||||||
session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2",
|
session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2",
|
||||||
|
advance_after: float | None = None,
|
||||||
|
hold: asyncio.Event | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._events = events
|
self._events = events
|
||||||
self._times = times
|
self._times = times
|
||||||
self._clock = clock
|
self._clock = clock
|
||||||
self._answer = answer
|
self._answer = answer
|
||||||
self._session_id = session_id
|
self._session_id = session_id
|
||||||
|
self._advance_after = advance_after
|
||||||
|
self._hold = hold
|
||||||
|
|
||||||
async def run_serialized(self, *_args, **kwargs) -> tuple[str, str, bool]:
|
async def run_serialized(self, *_args, **kwargs) -> tuple[str, str, bool]:
|
||||||
on_event = kwargs.get("on_event")
|
on_event = kwargs.get("on_event")
|
||||||
@@ -217,11 +239,16 @@ class _FakeRunnerWithEvents:
|
|||||||
self._clock.set(when)
|
self._clock.set(when)
|
||||||
await on_event(event)
|
await on_event(event)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
if self._advance_after is not None:
|
||||||
|
self._clock.set(self._advance_after)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
if self._hold is not None:
|
||||||
|
await self._hold.wait()
|
||||||
return (self._session_id, self._answer, True)
|
return (self._session_id, self._answer, True)
|
||||||
|
|
||||||
|
|
||||||
def test_final_notify_sends_loud_final_message() -> None:
|
def test_final_notify_sends_loud_final_message() -> None:
|
||||||
from takopi.exec_bridge import BridgeConfig, _handle_message
|
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||||
|
|
||||||
bot = _FakeBot()
|
bot = _FakeBot()
|
||||||
runner = _FakeRunner(answer="ok")
|
runner = _FakeRunner(answer="ok")
|
||||||
@@ -235,7 +262,7 @@ def test_final_notify_sends_loud_final_message() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
_handle_message(
|
handle_message(
|
||||||
cfg,
|
cfg,
|
||||||
chat_id=123,
|
chat_id=123,
|
||||||
user_msg_id=10,
|
user_msg_id=10,
|
||||||
@@ -250,7 +277,7 @@ def test_final_notify_sends_loud_final_message() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
||||||
from takopi.exec_bridge import BridgeConfig, _handle_message
|
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||||
|
|
||||||
bot = _FakeBot()
|
bot = _FakeBot()
|
||||||
runner = _FakeRunner(answer="x" * 10_000)
|
runner = _FakeRunner(answer="x" * 10_000)
|
||||||
@@ -264,7 +291,7 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
_handle_message(
|
handle_message(
|
||||||
cfg,
|
cfg,
|
||||||
chat_id=123,
|
chat_id=123,
|
||||||
user_msg_id=10,
|
user_msg_id=10,
|
||||||
@@ -279,7 +306,7 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_progress_edits_are_rate_limited() -> None:
|
def test_progress_edits_are_rate_limited() -> None:
|
||||||
from takopi.exec_bridge import BridgeConfig, _handle_message
|
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||||
|
|
||||||
bot = _FakeBot()
|
bot = _FakeBot()
|
||||||
clock = _FakeClock()
|
clock = _FakeClock()
|
||||||
@@ -294,13 +321,61 @@ def test_progress_edits_are_rate_limited() -> None:
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "item.completed",
|
"type": "item.started",
|
||||||
|
"item": {
|
||||||
|
"id": "item_1",
|
||||||
|
"type": "command_execution",
|
||||||
|
"command": "echo 2",
|
||||||
|
"status": "in_progress",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
runner = _FakeRunnerWithEvents(
|
||||||
|
events=events,
|
||||||
|
times=[0.2, 0.4],
|
||||||
|
clock=clock,
|
||||||
|
advance_after=1.0,
|
||||||
|
)
|
||||||
|
cfg = BridgeConfig(
|
||||||
|
bot=bot, # type: ignore[arg-type]
|
||||||
|
runner=runner, # type: ignore[arg-type]
|
||||||
|
chat_id=123,
|
||||||
|
final_notify=True,
|
||||||
|
startup_msg="",
|
||||||
|
max_concurrency=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
handle_message(
|
||||||
|
cfg,
|
||||||
|
chat_id=123,
|
||||||
|
user_msg_id=10,
|
||||||
|
text="hi",
|
||||||
|
resume_session=None,
|
||||||
|
clock=clock,
|
||||||
|
sleep=clock.sleep,
|
||||||
|
progress_edit_every=1.0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(bot.edit_calls) == 1
|
||||||
|
assert "echo 2" in bot.edit_calls[0]["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
|
||||||
|
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||||
|
|
||||||
|
bot = _FakeBot()
|
||||||
|
clock = _FakeClock()
|
||||||
|
hold = asyncio.Event()
|
||||||
|
events = [
|
||||||
|
{
|
||||||
|
"type": "item.started",
|
||||||
"item": {
|
"item": {
|
||||||
"id": "item_0",
|
"id": "item_0",
|
||||||
"type": "command_execution",
|
"type": "command_execution",
|
||||||
"command": "echo 1",
|
"command": "echo 1",
|
||||||
"exit_code": 0,
|
"status": "in_progress",
|
||||||
"status": "completed",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -315,8 +390,10 @@ def test_progress_edits_are_rate_limited() -> None:
|
|||||||
]
|
]
|
||||||
runner = _FakeRunnerWithEvents(
|
runner = _FakeRunnerWithEvents(
|
||||||
events=events,
|
events=events,
|
||||||
times=[0.2, 0.4, 1.2],
|
times=[0.2, 0.4],
|
||||||
clock=clock,
|
clock=clock,
|
||||||
|
advance_after=None,
|
||||||
|
hold=hold,
|
||||||
)
|
)
|
||||||
cfg = BridgeConfig(
|
cfg = BridgeConfig(
|
||||||
bot=bot, # type: ignore[arg-type]
|
bot=bot, # type: ignore[arg-type]
|
||||||
@@ -327,23 +404,50 @@ def test_progress_edits_are_rate_limited() -> None:
|
|||||||
max_concurrency=1,
|
max_concurrency=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(
|
async def run_test() -> None:
|
||||||
_handle_message(
|
task = asyncio.create_task(
|
||||||
cfg,
|
handle_message(
|
||||||
chat_id=123,
|
cfg,
|
||||||
user_msg_id=10,
|
chat_id=123,
|
||||||
text="hi",
|
user_msg_id=10,
|
||||||
resume_session=None,
|
text="hi",
|
||||||
clock=clock,
|
resume_session=None,
|
||||||
progress_edit_every=1.0,
|
clock=clock,
|
||||||
|
sleep=clock.sleep,
|
||||||
|
progress_edit_every=1.0,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
assert len(bot.edit_calls) == 1
|
for _ in range(100):
|
||||||
|
if clock._sleep_until is not None:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert clock._sleep_until == pytest.approx(1.0)
|
||||||
|
|
||||||
|
clock.set(1.0)
|
||||||
|
|
||||||
|
for _ in range(100):
|
||||||
|
if bot.edit_calls:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert len(bot.edit_calls) == 1
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert clock.sleep_calls == 1
|
||||||
|
assert clock._sleep_until is None
|
||||||
|
|
||||||
|
hold.set()
|
||||||
|
await task
|
||||||
|
|
||||||
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
|
||||||
def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
|
def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
|
||||||
from takopi.exec_bridge import BridgeConfig, _handle_message
|
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||||
|
|
||||||
bot = _FakeBot()
|
bot = _FakeBot()
|
||||||
clock = _FakeClock()
|
clock = _FakeClock()
|
||||||
@@ -386,13 +490,14 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
_handle_message(
|
handle_message(
|
||||||
cfg,
|
cfg,
|
||||||
chat_id=123,
|
chat_id=123,
|
||||||
user_msg_id=42,
|
user_msg_id=42,
|
||||||
text="do it",
|
text="do it",
|
||||||
resume_session=None,
|
resume_session=None,
|
||||||
clock=clock,
|
clock=clock,
|
||||||
|
sleep=clock.sleep,
|
||||||
progress_edit_every=1.0,
|
progress_edit_every=1.0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -529,7 +634,7 @@ class _FakeRunnerCancellable:
|
|||||||
|
|
||||||
|
|
||||||
def test_handle_message_cancelled_renders_cancelled_state() -> None:
|
def test_handle_message_cancelled_renders_cancelled_state() -> None:
|
||||||
from takopi.exec_bridge import BridgeConfig, _handle_message
|
from takopi.exec_bridge import BridgeConfig, handle_message
|
||||||
|
|
||||||
bot = _FakeBot()
|
bot = _FakeBot()
|
||||||
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
|
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
|
||||||
@@ -546,7 +651,7 @@ def test_handle_message_cancelled_renders_cancelled_state() -> None:
|
|||||||
|
|
||||||
async def run_test():
|
async def run_test():
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
_handle_message(
|
handle_message(
|
||||||
cfg,
|
cfg,
|
||||||
chat_id=123,
|
chat_id=123,
|
||||||
user_msg_id=10,
|
user_msg_id=10,
|
||||||
|
|||||||
@@ -8,46 +8,35 @@ from takopi.logging import RedactTokenFilter
|
|||||||
from takopi.telegram import TelegramClient
|
from takopi.telegram import TelegramClient
|
||||||
|
|
||||||
|
|
||||||
def test_telegram_429_retry_after_calls_sleep() -> None:
|
def test_telegram_429_no_retry() -> None:
|
||||||
calls: list[int] = []
|
calls: list[int] = []
|
||||||
sleeps: list[float] = []
|
|
||||||
|
|
||||||
async def fake_sleep(seconds: float) -> None:
|
|
||||||
sleeps.append(seconds)
|
|
||||||
|
|
||||||
def handler(request: httpx.Request) -> httpx.Response:
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
calls.append(1)
|
calls.append(1)
|
||||||
if len(calls) == 1:
|
|
||||||
return httpx.Response(
|
|
||||||
429,
|
|
||||||
json={
|
|
||||||
"ok": False,
|
|
||||||
"description": "retry",
|
|
||||||
"parameters": {"retry_after": 3},
|
|
||||||
},
|
|
||||||
request=request,
|
|
||||||
)
|
|
||||||
return httpx.Response(
|
return httpx.Response(
|
||||||
200,
|
429,
|
||||||
json={"ok": True, "result": {"message_id": 1}},
|
json={
|
||||||
|
"ok": False,
|
||||||
|
"description": "retry",
|
||||||
|
"parameters": {"retry_after": 3},
|
||||||
|
},
|
||||||
request=request,
|
request=request,
|
||||||
)
|
)
|
||||||
|
|
||||||
transport = httpx.MockTransport(handler)
|
transport = httpx.MockTransport(handler)
|
||||||
|
|
||||||
async def run() -> dict:
|
async def run() -> dict | None:
|
||||||
client = httpx.AsyncClient(transport=transport)
|
client = httpx.AsyncClient(transport=transport)
|
||||||
try:
|
try:
|
||||||
tg = TelegramClient("123:abcDEF_ghij", client=client, sleep=fake_sleep)
|
tg = TelegramClient("123:abcDEF_ghij", client=client)
|
||||||
return await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
return await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||||
finally:
|
finally:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
result = asyncio.run(run())
|
result = asyncio.run(run())
|
||||||
|
|
||||||
assert result == {"message_id": 1}
|
assert result is None
|
||||||
assert sleeps == [3]
|
assert len(calls) == 1
|
||||||
assert len(calls) == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> None:
|
def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> None:
|
||||||
@@ -70,8 +59,7 @@ def test_no_token_in_logs_on_http_error(caplog: pytest.LogCaptureFixture) -> Non
|
|||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
caplog.set_level(logging.ERROR)
|
caplog.set_level(logging.ERROR)
|
||||||
with pytest.raises(httpx.HTTPStatusError):
|
asyncio.run(run())
|
||||||
asyncio.run(run())
|
|
||||||
|
|
||||||
root_logger.removeFilter(redactor)
|
root_logger.removeFilter(redactor)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user