refactor: add test hooks for timing and telegram client

This commit is contained in:
banteg
2025-12-29 14:49:31 +04:00
parent 19dc87061a
commit ff0115af92
3 changed files with 38 additions and 13 deletions
@@ -36,11 +36,20 @@ RESUME_LINE = re.compile(
def extract_session_id(text: str | None) -> str | None: def extract_session_id(text: str | None) -> str | None:
if not text: if not text:
return None return None
if m := RESUME_LINE.search(text): found = None
return m.group("id") for match in RESUME_LINE.finditer(text):
found = match.group("id")
if found:
return found
return None return None
def resolve_resume_session(
text: str | None, reply_text: str | None
) -> str | None:
return extract_session_id(text) or extract_session_id(reply_text)
async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -> None: async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -> None:
if stderr is None: if stderr is None:
return return
@@ -72,6 +81,7 @@ async def manage_subprocess(*args, **kwargs):
TELEGRAM_MARKDOWN_LIMIT = 3500 TELEGRAM_MARKDOWN_LIMIT = 3500
PROGRESS_EDIT_EVERY_S = 2.0
def _clamp_tg_text(text: str, limit: int = TELEGRAM_MARKDOWN_LIMIT) -> str: def _clamp_tg_text(text: str, limit: int = TELEGRAM_MARKDOWN_LIMIT) -> str:
@@ -436,6 +446,8 @@ async def _handle_message(
user_msg_id: int, user_msg_id: int,
text: str, text: str,
resume_session: str | None, resume_session: str | None,
clock: Callable[[], float] = time.monotonic,
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
) -> None: ) -> None:
logger.debug( logger.debug(
"[handle] incoming chat_id=%s message_id=%s resume=%r text=%s", "[handle] incoming chat_id=%s message_id=%s resume=%r text=%s",
@@ -444,7 +456,7 @@ async def _handle_message(
resume_session, resume_session,
text, text,
) )
started_at = time.monotonic() started_at = clock()
progress_renderer = ExecProgressRenderer(max_actions=5) progress_renderer = ExecProgressRenderer(max_actions=5)
progress_id: int | None = None progress_id: int | None = None
@@ -498,7 +510,7 @@ async def _handle_message(
disable_notification=True, disable_notification=True,
) )
progress_id = int(progress_msg["message_id"]) progress_id = int(progress_msg["message_id"])
last_edit_at = time.monotonic() last_edit_at = clock()
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id) logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
except Exception as e: except Exception as e:
logger.info( logger.info(
@@ -511,8 +523,8 @@ async def _handle_message(
return return
if not progress_renderer.note_event(evt): if not progress_renderer.note_event(evt):
return return
now = time.monotonic() now = clock()
if (now - last_edit_at) < 2.0: if (now - last_edit_at) < progress_edit_every:
return return
if edit_task is not None and not edit_task.done(): if edit_task is not None and not edit_task.done():
return return
@@ -547,7 +559,7 @@ async def _handle_message(
await asyncio.gather(edit_task, return_exceptions=True) await asyncio.gather(edit_task, return_exceptions=True)
answer = answer or "(No agent_message captured from JSON stream.)" answer = answer or "(No agent_message captured from JSON stream.)"
elapsed = time.monotonic() - started_at elapsed = clock() - started_at
status = "done" if saw_agent_message else "error" status = "done" if saw_agent_message else "error"
final_md = ( final_md = (
progress_renderer.render_final(elapsed, answer, status=status) progress_renderer.render_final(elapsed, answer, status=status)
@@ -647,9 +659,8 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
async for msg in poll_updates(cfg): async for msg in poll_updates(cfg):
text = msg["text"] text = msg["text"]
user_msg_id = msg["message_id"] user_msg_id = msg["message_id"]
resume_session = extract_session_id(text)
r = msg.get("reply_to_message") or {} r = msg.get("reply_to_message") or {}
resume_session = resume_session or extract_session_id(r.get("text")) resume_session = resolve_resume_session(text, r.get("text"))
await queue.put( await queue.put(
(msg["chat"]["id"], user_msg_id, text, resume_session) (msg["chat"]["id"], user_msg_id, text, resume_session)
@@ -37,4 +37,5 @@ def setup_logging(*, debug: bool = False) -> None:
console.setLevel(logging.DEBUG if debug else logging.INFO) console.setLevel(logging.DEBUG if debug else logging.INFO)
console.setFormatter(fmt) console.setFormatter(fmt)
console.addFilter(redactor) console.addFilter(redactor)
root_logger.addFilter(redactor)
root_logger.addHandler(console) root_logger.addHandler(console)
@@ -2,11 +2,15 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
import httpx import httpx
from .logging import RedactTokenFilter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.addFilter(RedactTokenFilter())
class TelegramAPIError(RuntimeError): class TelegramAPIError(RuntimeError):
@@ -24,13 +28,22 @@ class TelegramClient:
Minimal Telegram Bot API client. Minimal Telegram Bot API client.
""" """
def __init__(self, token: str, timeout_s: float = 120) -> None: def __init__(
self,
token: str,
timeout_s: float = 120,
client: httpx.AsyncClient | None = None,
sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
) -> None:
if not token: if not token:
raise ValueError("Telegram token is empty") raise ValueError("Telegram token is empty")
self._base = f"https://api.telegram.org/bot{token}" self._base = f"https://api.telegram.org/bot{token}"
self._client = httpx.AsyncClient(timeout=timeout_s) self._client = client or httpx.AsyncClient(timeout=timeout_s)
self._owns_client = client is None
self._sleep = sleep
async def close(self) -> None: async def close(self) -> None:
if self._owns_client:
await self._client.aclose() await self._client.aclose()
async def _post(self, method: str, json_data: dict[str, Any]) -> Any: async def _post(self, method: str, json_data: dict[str, Any]) -> Any:
@@ -50,7 +63,7 @@ class TelegramClient:
logger.warning( logger.warning(
"[telegram] 429 retry_after=%s method=%s", retry_after, method "[telegram] 429 retry_after=%s method=%s", retry_after, method
) )
await asyncio.sleep(retry_after) await self._sleep(retry_after)
return await self._post(method, json_data) return await self._post(method, json_data)
raise TelegramAPIError(method, payload, resp.status_code) raise TelegramAPIError(method, payload, resp.status_code)
logger.debug("[telegram] response %s: %s", method, payload) logger.debug("[telegram] response %s: %s", method, payload)