refactor: migrate exec bridge to asyncio

This commit is contained in:
banteg
2025-12-29 11:35:19 +04:00
parent 9037c67328
commit 489a50aec6
4 changed files with 362 additions and 294 deletions
+1 -2
View File
@@ -5,9 +5,8 @@ description = "Telegram bridge tools for Codex."
readme = "readme.md" readme = "readme.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
"httpx>=0.28.1",
"markdown-it-py", "markdown-it-py",
"requests",
"sulguk",
"typer", "typer",
] ]
@@ -1,32 +1,30 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import inspect
import json import json
import logging
import os import os
import re import re
import shlex import shlex
import shutil import shutil
import subprocess
import threading
import time
import logging
import sys import sys
import time
from collections import deque
from collections.abc import Awaitable, Callable
from dataclasses import dataclass from dataclasses import dataclass
from weakref import WeakValueDictionary from html import unescape
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
from concurrent.futures import ThreadPoolExecutor
from typing import Any from typing import Any
from collections.abc import Callable
import typer import typer
from .config import load_telegram_config from .config import load_telegram_config
from .constants import TELEGRAM_HARD_LIMIT from .constants import TELEGRAM_HARD_LIMIT
from .exec_render import ExecProgressRenderer, render_event_cli from .exec_render import ExecProgressRenderer, render_event_cli
from .rendering import render_markdown from .rendering import render_to_html, strip_tags
from .telegram_client import TelegramClient from .telegram_client import TelegramClient
# -------------------- Codex runner --------------------
logger = logging.getLogger("exec_bridge") logger = logging.getLogger("exec_bridge")
UUID_PATTERN = re.compile( UUID_PATTERN = re.compile(
r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b" r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
@@ -41,11 +39,17 @@ def extract_session_id(text: str | None) -> str | None:
return None return None
def _drain_stderr(stderr, lines: list[str]) -> None: async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -> None:
if stderr is None:
return
try: try:
for line in stderr: while True:
logger.info("[codex][stderr] %s", line.rstrip()) line = await stderr.readline()
lines.append(line) if not line:
return
decoded = line.decode(errors="replace")
logger.info("[codex][stderr] %s", decoded.rstrip())
tail.append(decoded)
except Exception as e: except Exception as e:
logger.debug("[codex][stderr] drain error: %s", e) logger.debug("[codex][stderr] drain error: %s", e)
@@ -85,7 +89,7 @@ def _clamp_tg_text(text: str, limit: int = TELEGRAM_TEXT_LIMIT) -> str:
return text[: limit - 20] + "\n...(truncated)" return text[: limit - 20] + "\n...(truncated)"
def _send_markdown( async def _send_markdown(
bot: TelegramClient, bot: TelegramClient,
*, *,
chat_id: int, chat_id: int,
@@ -93,104 +97,30 @@ def _send_markdown(
reply_to_message_id: int | None = None, reply_to_message_id: int | None = None,
disable_notification: bool = False, disable_notification: bool = False,
) -> dict[str, Any]: ) -> dict[str, Any]:
rendered, entities = render_markdown(text) md = text
if len(rendered) > TELEGRAM_MARKDOWN_LIMIT: if len(md) > TELEGRAM_MARKDOWN_LIMIT:
rendered = rendered[: TELEGRAM_MARKDOWN_LIMIT - 20] + "\n…(truncated)" md = md[: TELEGRAM_MARKDOWN_LIMIT - 20] + "\n…(truncated)"
entities = None
return bot.send_message( rendered = render_to_html(md)
if len(rendered) > TELEGRAM_TEXT_LIMIT:
plain = _clamp_tg_text(unescape(strip_tags(rendered)))
return await bot.send_message(
chat_id=chat_id,
text=plain,
reply_to_message_id=reply_to_message_id,
disable_notification=disable_notification,
)
return await bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=rendered, text=rendered,
entities=entities, parse_mode="HTML",
reply_to_message_id=reply_to_message_id, reply_to_message_id=reply_to_message_id,
disable_notification=disable_notification, disable_notification=disable_notification,
) )
class ProgressEditor: EventCallback = Callable[[dict[str, Any]], Awaitable[None] | None]
def __init__(
self,
bot: TelegramClient,
chat_id: int,
message_id: int,
edit_every_s: float,
initial_text: str | None = None,
initial_entities: list[dict[str, Any]] | None = None,
) -> None:
self.bot = bot
self.chat_id = chat_id
self.message_id = message_id
self.edit_every_s = edit_every_s
self._lock = threading.Lock()
self._pending: tuple[str, list[dict[str, Any]] | None] | None = None
self._last_sent: tuple[str, list[dict[str, Any]] | None] | None = None
self._last_edit_at = 0.0
if initial_text is not None:
self._last_sent = (initial_text, initial_entities)
self._last_edit_at = time.monotonic()
self._stop = threading.Event()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
def set(self, text: str, entities: list[dict[str, Any]] | None = None) -> None:
text = _clamp_tg_text(text)
with self._lock:
self._pending = (text, entities)
logger.debug(
"[progress] set pending len=%s entities=%s", len(text), bool(entities)
)
def set_markdown(self, text: str) -> None:
rendered_text, entities = render_markdown(text)
self.set(rendered_text, entities or None)
def stop(self) -> None:
self._stop.set()
self._thread.join(timeout=1.0)
def _edit(self, text: str, entities: list[dict[str, Any]] | None) -> None:
try:
self.bot.edit_message_text(
chat_id=self.chat_id,
message_id=self.message_id,
text=text,
entities=entities,
)
logger.debug(
"[progress] edit ok chat_id=%s message_id=%s len=%s",
self.chat_id,
self.message_id,
len(text),
)
except Exception as e:
logger.info(
"[progress] edit failed chat_id=%s message_id=%s: %s",
self.chat_id,
self.message_id,
e,
)
def _run(self) -> None:
while not self._stop.is_set():
to_send: tuple[str, list[dict[str, Any]] | None] | None = None
now = time.monotonic()
with self._lock:
if (
self._pending is not None
and (now - self._last_edit_at) >= self.edit_every_s
):
if self._pending != self._last_sent:
to_send = self._pending
self._last_sent = self._pending
self._last_edit_at = now
self._pending = None
if to_send is not None:
self._edit(to_send[0], to_send[1])
self._stop.wait(0.2)
class CodexExecRunner: class CodexExecRunner:
@@ -207,27 +137,24 @@ class CodexExecRunner:
self.workspace = workspace self.workspace = workspace
self.extra_args = extra_args self.extra_args = extra_args
# per-session locks to prevent concurrent resumes to same session_id # Per-session locks to prevent concurrent resumes to the same session_id.
self._session_locks: WeakValueDictionary[str, threading.Lock] = WeakValueDictionary() self._session_locks: dict[str, asyncio.Lock] = {}
self._locks_guard = threading.Lock() self._locks_guard = asyncio.Lock()
def _lock_for(self, session_id: str) -> threading.Lock: async def _lock_for(self, session_id: str) -> asyncio.Lock:
with self._locks_guard: async with self._locks_guard:
lock = self._session_locks.get(session_id) lock = self._session_locks.get(session_id)
if lock is None: if lock is None:
lock = threading.Lock() lock = asyncio.Lock()
self._session_locks[session_id] = lock self._session_locks[session_id] = lock
return lock return lock
def run( async def run(
self, self,
prompt: str, prompt: str,
session_id: str | None, session_id: str | None,
on_event: Callable[[dict[str, Any]], None] | None = None, on_event: EventCallback | None = None,
) -> tuple[str, str, bool]: ) -> tuple[str, str, bool]:
"""
Returns (session_id, final_agent_message_text)
"""
logger.info( logger.info(
"[codex] start run session_id=%r workspace=%r", session_id, self.workspace "[codex] start run session_id=%r workspace=%r", session_id, self.workspace
) )
@@ -242,67 +169,69 @@ class CodexExecRunner:
else: else:
args.append("-") args.append("-")
# read both stdout+stderr without deadlock proc = await asyncio.create_subprocess_exec(
proc = subprocess.Popen( *args,
args, stdin=asyncio.subprocess.PIPE,
stdin=subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stdout=subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1,
) )
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
assert proc.stdin and proc.stdout and proc.stderr assert proc.stdin and proc.stdout and proc.stderr
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
# send prompt then close stdin proc.stdin.write(prompt.encode())
proc.stdin.write(prompt) await proc.stdin.drain()
proc.stdin.close() proc.stdin.close()
stderr_lines: list[str] = [] stderr_tail: deque[str] = deque(maxlen=200)
t = threading.Thread(target=_drain_stderr, args=(proc.stderr, stderr_lines), daemon=True) stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
t.start()
found_session: str | None = session_id found_session: str | None = session_id
last_agent_text: str | None = None last_agent_text: str | None = None
saw_agent_message = False saw_agent_message = False
cli_last_turn: int | None = None
cli_last_turn = None try:
async for raw_line in proc.stdout:
for line in proc.stdout: line = raw_line.decode(errors="replace").strip()
line = line.strip() if not line:
if not line: continue
continue
try:
evt = json.loads(line)
except json.JSONDecodeError:
continue
cli_last_turn, out_lines = render_event_cli(evt, cli_last_turn)
for out in out_lines:
logger.info("[codex] %s", out)
if on_event is not None:
try: try:
on_event(evt) evt = json.loads(line)
except Exception as e: except json.JSONDecodeError:
logger.info("[codex][on_event] callback error: %s", e) continue
# From Codex JSONL event stream cli_last_turn, out_lines = render_event_cli(evt, cli_last_turn)
if evt.get("type") == "thread.started": for out in out_lines:
found_session = evt.get("thread_id") or found_session logger.info("[codex] %s", out)
if evt.get("type") == "item.completed": if on_event is not None:
item = evt.get("item") or {} try:
if item.get("type") == "agent_message" and isinstance( res = on_event(evt)
item.get("text"), str if inspect.isawaitable(res):
): await res
last_agent_text = item["text"] except Exception as e:
saw_agent_message = True logger.info("[codex][on_event] callback error: %s", e)
if evt.get("type") == "thread.started":
found_session = evt.get("thread_id") or found_session
if evt.get("type") == "item.completed":
item = evt.get("item") or {}
if item.get("type") == "agent_message" and isinstance(
item.get("text"), str
):
last_agent_text = item["text"]
saw_agent_message = True
except asyncio.CancelledError:
proc.terminate()
raise
finally:
rc = await proc.wait()
await stderr_task
rc = proc.wait()
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc) logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
t.join(timeout=2.0)
if rc != 0: if rc != 0:
tail = "".join(stderr_lines[-200:]) tail = "".join(stderr_tail)
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}") raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}")
if not found_session: if not found_session:
@@ -317,35 +246,33 @@ class CodexExecRunner:
saw_agent_message, saw_agent_message,
) )
def run_serialized( async def run_serialized(
self, self,
prompt: str, prompt: str,
session_id: str | None, session_id: str | None,
on_event: Callable[[dict[str, Any]], None] | None = None, on_event: EventCallback | None = None,
) -> tuple[str, str, bool]: ) -> tuple[str, str, bool]:
""" """
If resuming, serialize per-session. If resuming, serialize per-session.
""" """
if not session_id: if not session_id:
return self.run(prompt, session_id=None, on_event=on_event) return await self.run(prompt, session_id=None, on_event=on_event)
lock = self._lock_for(session_id) lock = await self._lock_for(session_id)
with lock: async with lock:
return self.run(prompt, session_id=session_id, on_event=on_event) return await self.run(prompt, session_id=session_id, on_event=on_event)
# -------------------- Telegram loop --------------------
@dataclass(frozen=True) @dataclass(frozen=True)
class BridgeConfig: class BridgeConfig:
bot: TelegramClient bot: TelegramClient
runner: CodexExecRunner runner: CodexExecRunner
chat_id: int chat_id: int
pool: ThreadPoolExecutor
ignore_backlog: bool ignore_backlog: bool
progress_edit_every_s: float progress_edit_every_s: float
progress_silent: bool progress_silent: bool
final_notify: bool final_notify: bool
startup_msg: str startup_msg: str
max_concurrency: int
def _parse_bridge_config( def _parse_bridge_config(
@@ -395,34 +322,37 @@ def _parse_bridge_config(
bot = TelegramClient(token) bot = TelegramClient(token)
runner = CodexExecRunner(codex_cmd=codex_cmd, workspace=workspace, extra_args=extra_args) runner = CodexExecRunner(codex_cmd=codex_cmd, workspace=workspace, extra_args=extra_args)
pool = ThreadPoolExecutor(max_workers=16)
return BridgeConfig( return BridgeConfig(
bot=bot, bot=bot,
runner=runner, runner=runner,
chat_id=chat_id, chat_id=chat_id,
pool=pool,
ignore_backlog=bool(ignore_backlog), ignore_backlog=bool(ignore_backlog),
progress_edit_every_s=progress_edit_every_s, progress_edit_every_s=progress_edit_every_s,
progress_silent=progress_silent, progress_silent=progress_silent,
final_notify=final_notify, final_notify=final_notify,
startup_msg=startup_msg, startup_msg=startup_msg,
max_concurrency=16,
) )
def _send_startup(cfg: BridgeConfig) -> None: async def _send_startup(cfg: BridgeConfig) -> None:
try: try:
cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg) await cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg)
logger.info("[startup] sent startup message to chat_id=%s", cfg.chat_id) logger.info("[startup] sent startup message to chat_id=%s", cfg.chat_id)
except Exception as e: except Exception as e:
logger.info("[startup] failed to send startup message to chat_id=%s: %s", cfg.chat_id, e) logger.info(
"[startup] failed to send startup message to chat_id=%s: %s", cfg.chat_id, e
)
def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None: async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
if not cfg.ignore_backlog: if not cfg.ignore_backlog:
return offset return offset
try: try:
updates = cfg.bot.get_updates(offset=offset, timeout_s=0, allowed_updates=["message"]) updates = await cfg.bot.get_updates(
offset=offset, timeout_s=0, allowed_updates=["message"]
)
except Exception as e: except Exception as e:
logger.info("[startup] backlog drain failed: %s", e) logger.info("[startup] backlog drain failed: %s", e)
return offset return offset
@@ -432,9 +362,10 @@ def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
return offset return offset
def _handle_message( async def _handle_message(
cfg: BridgeConfig, cfg: BridgeConfig,
*, *,
semaphore: asyncio.Semaphore,
chat_id: int, chat_id: int,
user_msg_id: int, user_msg_id: int,
text: str, text: str,
@@ -444,110 +375,193 @@ def _handle_message(
progress_renderer = ExecProgressRenderer(max_actions=5) progress_renderer = ExecProgressRenderer(max_actions=5)
progress_id: int | None = None progress_id: int | None = None
progress: ProgressEditor | None = None
last_edit_at = 0.0
edit_task: asyncio.Task[None] | None = None
async def _edit_progress(md: str) -> None:
if progress_id is None:
return
parse_mode: str | None = "HTML"
rendered = render_to_html(md)
if len(rendered) > TELEGRAM_TEXT_LIMIT:
rendered = _clamp_tg_text(unescape(strip_tags(rendered)))
parse_mode = None
try:
await cfg.bot.edit_message_text(
chat_id=chat_id,
message_id=progress_id,
text=rendered,
parse_mode=parse_mode,
)
except Exception as e:
logger.info(
"[progress] edit failed chat_id=%s message_id=%s: %s",
chat_id,
progress_id,
e,
)
try: try:
initial_text = progress_renderer.render_progress(0.0) initial_md = progress_renderer.render_progress(0.0)
initial_rendered, initial_entities = render_markdown(initial_text) initial_rendered = render_to_html(initial_md)
progress_msg = cfg.bot.send_message( progress_msg = await cfg.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=initial_rendered, text=initial_rendered,
entities=initial_entities or None, parse_mode="HTML",
reply_to_message_id=user_msg_id, reply_to_message_id=user_msg_id,
disable_notification=cfg.progress_silent, disable_notification=cfg.progress_silent,
) )
progress_id = int(progress_msg["message_id"]) progress_id = int(progress_msg["message_id"])
last_edit_at = time.monotonic()
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id) logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
except Exception as e: except Exception as e:
logger.info("[handle] failed to send progress message chat_id=%s: %s", chat_id, e) logger.info(
"[handle] failed to send progress message chat_id=%s: %s", chat_id, e
if progress_id is not None:
progress = ProgressEditor(
cfg.bot,
chat_id,
progress_id,
cfg.progress_edit_every_s,
initial_text=initial_rendered,
initial_entities=initial_entities or None,
) )
def on_event(evt: dict[str, Any]) -> None: async def on_event(evt: dict[str, Any]) -> None:
if progress_renderer.note_event(evt) and progress is not None: nonlocal last_edit_at, edit_task
elapsed = time.monotonic() - started_at if progress_id is None:
progress.set_markdown(progress_renderer.render_progress(elapsed))
def _stop_background() -> None:
if progress is not None:
progress.stop()
try:
session_id, answer, saw_agent_message = cfg.runner.run_serialized(
text,
resume_session,
on_event=on_event,
)
except Exception as e:
_stop_background()
err = _clamp_tg_text(f"Error:\n{e}")
if progress_id is not None and len(err) <= TELEGRAM_TEXT_LIMIT:
cfg.bot.edit_message_text(chat_id=chat_id, message_id=progress_id, text=err)
return return
_send_markdown(cfg.bot, chat_id=chat_id, text=err, reply_to_message_id=user_msg_id) if not progress_renderer.note_event(evt):
return return
now = time.monotonic()
if (now - last_edit_at) < cfg.progress_edit_every_s:
return
if edit_task is not None and not edit_task.done():
return
last_edit_at = now
elapsed = now - started_at
edit_task = asyncio.create_task(
_edit_progress(progress_renderer.render_progress(elapsed))
)
_stop_background() async with semaphore:
try:
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
text, resume_session, on_event=on_event
)
except Exception as e:
if edit_task is not None:
await asyncio.gather(edit_task, return_exceptions=True)
err = _clamp_tg_text(f"Error:\n{e}")
if progress_id is not None and len(err) <= TELEGRAM_TEXT_LIMIT:
try:
await cfg.bot.edit_message_text(
chat_id=chat_id, message_id=progress_id, text=err
)
return
except Exception:
pass
await _send_markdown(
cfg.bot,
chat_id=chat_id,
text=err,
reply_to_message_id=user_msg_id,
disable_notification=cfg.progress_silent,
)
return
if edit_task is not None:
await asyncio.gather(edit_task, return_exceptions=True)
answer = answer or "(No agent_message captured from JSON stream.)" answer = answer or "(No agent_message captured from JSON stream.)"
elapsed = time.monotonic() - started_at elapsed = time.monotonic() - started_at
status = "done" if saw_agent_message else "error" status = "done" if saw_agent_message else "error"
final_md = progress_renderer.render_final(elapsed, answer, status=status) + f"\n\nresume: `{session_id}`" final_md = (
final_text, final_entities = render_markdown(final_md) progress_renderer.render_final(elapsed, answer, status=status)
can_edit_final = progress_id is not None and len(final_text) <= TELEGRAM_TEXT_LIMIT + f"\n\nresume: `{session_id}`"
)
final_rendered = render_to_html(final_md)
can_edit_final = progress_id is not None and len(final_rendered) <= TELEGRAM_TEXT_LIMIT
if cfg.final_notify or not can_edit_final: if cfg.final_notify or not can_edit_final:
_send_markdown(cfg.bot, chat_id=chat_id, text=final_md, reply_to_message_id=user_msg_id) await _send_markdown(
cfg.bot,
chat_id=chat_id,
text=final_md,
reply_to_message_id=user_msg_id,
disable_notification=cfg.progress_silent,
)
if progress_id is not None: if progress_id is not None:
cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id) try:
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
except Exception:
pass
else: else:
cfg.bot.edit_message_text( await cfg.bot.edit_message_text(
chat_id=chat_id, chat_id=chat_id,
message_id=progress_id, message_id=progress_id,
text=final_text, text=final_rendered,
entities=final_entities or None, parse_mode="HTML",
) )
def _run_main_loop(cfg: BridgeConfig) -> None: async def _run_main_loop(cfg: BridgeConfig) -> None:
semaphore = asyncio.Semaphore(cfg.max_concurrency)
tasks: set[asyncio.Task[None]] = set()
def _task_done(task: asyncio.Task[None]) -> None:
tasks.discard(task)
try:
task.result()
except asyncio.CancelledError:
pass
except Exception:
logger.exception("[handle] task failed")
offset: int | None = None offset: int | None = None
offset = _drain_backlog(cfg, offset) offset = await _drain_backlog(cfg, offset)
_send_startup(cfg) await _send_startup(cfg)
while True: try:
updates = cfg.bot.get_updates(offset=offset, timeout_s=50, allowed_updates=["message"]) while True:
for upd in updates: try:
offset = upd["update_id"] + 1 updates = await cfg.bot.get_updates(
msg = upd.get("message") or {} offset=offset, timeout_s=50, allowed_updates=["message"]
msg_chat_id = msg.get("chat", {}).get("id") )
if "text" not in msg: except Exception as e:
continue logger.info("[loop] getUpdates failed: %s", e)
if int(msg_chat_id) != cfg.chat_id: await asyncio.sleep(2)
continue
if msg.get("from", {}).get("is_bot"):
continue continue
text = msg["text"] for upd in updates:
user_msg_id = msg["message_id"] offset = upd["update_id"] + 1
resume_session = extract_session_id(text) msg = upd.get("message") or {}
r = msg.get("reply_to_message") or {} msg_chat_id = msg.get("chat", {}).get("id")
resume_session = resume_session or extract_session_id(r.get("text")) if "text" not in msg:
continue
if int(msg_chat_id) != cfg.chat_id:
continue
if msg.get("from", {}).get("is_bot"):
continue
cfg.pool.submit( text = msg["text"]
_handle_message, user_msg_id = msg["message_id"]
cfg, resume_session = extract_session_id(text)
chat_id=msg_chat_id, r = msg.get("reply_to_message") or {}
user_msg_id=user_msg_id, resume_session = resume_session or extract_session_id(r.get("text"))
text=text,
resume_session=resume_session, task = asyncio.create_task(
) _handle_message(
cfg,
semaphore=semaphore,
chat_id=msg_chat_id,
user_msg_id=user_msg_id,
text=text,
resume_session=resume_session,
)
)
tasks.add(task)
task.add_done_callback(_task_done)
finally:
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
await cfg.bot.close()
def run( def run(
@@ -597,7 +611,7 @@ def run(
cd=cd, cd=cd,
model=model, model=model,
) )
_run_main_loop(cfg) asyncio.run(_run_main_loop(cfg))
def main() -> None: def main() -> None:
@@ -1,23 +1,65 @@
from __future__ import annotations from __future__ import annotations
import re import re
from typing import Any from html import escape
from markdown_it import MarkdownIt from markdown_it import MarkdownIt
from sulguk import transform_html
_md = MarkdownIt("commonmark", {"html": False, "breaks": True})
_CODE_CLASS_RE = re.compile(r'<code class="[^"]+">')
_IMG_ALT_RE = re.compile(r'<img[^>]*alt="([^"]*)"[^>]*/?>')
_IMG_RE = re.compile(r"<img[^>]*>")
_OL_OPEN_RE = re.compile(r'<ol(?: start="\d+")?>\s*')
_TAG_RE = re.compile(r"<[^>]+>")
def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]: def strip_tags(html: str) -> str:
html = MarkdownIt("commonmark", {"html": False}).render(md or "") return _TAG_RE.sub("", html)
rendered = transform_html(html)
text = re.sub("(?m)^(\\s*)\u2022", r"\1-", rendered.text)
# FIX: Telegram requires MessageEntity.language (if present) to be a String. def render_to_html(text: str) -> str:
entities: list[dict[str, Any]] = [] """
for e in rendered.entities: Render Markdown to Telegram-compatible HTML.
d = dict(e)
if "language" in d and not isinstance(d["language"], str): Telegram supports only a subset of HTML tags, so we post-process the
d.pop("language", None) MarkdownIt output to flatten unsupported block tags (p/ul/li/etc) into
entities.append(d) plain text with newlines and simple bullets.
return text, entities """
html = _md.render(text or "")
# Paragraphs and line breaks.
html = html.replace("<p>", "")
html = html.replace("<br />\n", "\n").replace("<br>\n", "\n")
html = html.replace("<br />", "\n").replace("<br>", "\n")
html = html.replace("</p>\n", "\n\n").replace("</p>", "\n\n")
# Lists -> "- " lines.
html = html.replace("<ul>\n", "").replace("</ul>\n", "")
html = _OL_OPEN_RE.sub("", html).replace("</ol>\n", "")
html = html.replace("<li>", "- ")
html = html.replace("</li>\n", "\n").replace("</li>", "\n")
# Headings -> bold line.
for level in range(1, 7):
html = html.replace(f"<h{level}>", "<b>")
html = html.replace(f"</h{level}>\n", "</b>\n\n").replace(
f"</h{level}>", "</b>\n\n"
)
# Code fences may include language class; Telegram doesn't need it.
html = _CODE_CLASS_RE.sub("<code>", html)
# Images are not supported: keep alt text if present.
html = _IMG_ALT_RE.sub(lambda m: escape(m.group(1) or ""), html)
html = _IMG_RE.sub("", html)
# <hr> isn't supported; render a separator line.
html = html.replace("<hr />", "\n----\n\n").replace("<hr>", "\n----\n\n")
# Flatten blockquotes.
html = html.replace("<blockquote>\n", "")
html = html.replace("</blockquote>\n", "\n\n").replace("</blockquote>", "\n\n")
html = re.sub(r"\n{3,}", "\n\n", html)
return html.strip()
@@ -1,6 +1,11 @@
from __future__ import annotations from __future__ import annotations
import requests import logging
from typing import Any
import httpx
logger = logging.getLogger(__name__)
class TelegramClient: class TelegramClient:
@@ -12,42 +17,46 @@ class TelegramClient:
if not token: if not token:
raise ValueError("Telegram token is empty") raise ValueError("Telegram token is empty")
self._base = f"https://api.telegram.org/bot{token}" self._base = f"https://api.telegram.org/bot{token}"
self._timeout_s = timeout_s self._client = httpx.AsyncClient(timeout=timeout_s)
def _call(self, method: str, params: dict) -> object: async def close(self) -> None:
resp = requests.post( await self._client.aclose()
f"{self._base}/{method}",
json=params,
timeout=self._timeout_s,
)
resp.raise_for_status()
payload = resp.json()
if not payload.get("ok"):
raise RuntimeError(f"Telegram API error: {payload}")
return payload["result"]
def get_updates( async def _post(self, method: str, json_data: dict[str, Any]) -> Any:
try:
resp = await self._client.post(f"{self._base}/{method}", json=json_data)
resp.raise_for_status()
payload = resp.json()
if not payload.get("ok"):
raise RuntimeError(f"Telegram API error: {payload}")
return payload["result"]
except httpx.HTTPError as e:
logger.error("Telegram network error: %s", e)
raise
async def get_updates(
self, self,
offset: int | None, offset: int | None,
timeout_s: int = 50, timeout_s: int = 50,
allowed_updates: list[str] | None = None, allowed_updates: list[str] | None = None,
) -> list[dict]: ) -> list[dict]:
params: dict = {"timeout": timeout_s} params: dict[str, Any] = {"timeout": timeout_s}
if offset is not None: if offset is not None:
params["offset"] = offset params["offset"] = offset
if allowed_updates is not None: if allowed_updates is not None:
params["allowed_updates"] = allowed_updates params["allowed_updates"] = allowed_updates
return self._call("getUpdates", params) # type: ignore[return-value] return await self._post("getUpdates", params) # type: ignore[return-value]
def send_message( async def send_message(
self, self,
chat_id: int, chat_id: int,
text: str, text: str,
reply_to_message_id: int | None = None, reply_to_message_id: int | None = None,
disable_notification: bool | None = False, disable_notification: bool | None = False,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None,
) -> dict: ) -> dict:
params: dict = { params: dict[str, Any] = {
"chat_id": chat_id, "chat_id": chat_id,
"text": text, "text": text,
} }
@@ -57,26 +66,31 @@ class TelegramClient:
params["reply_to_message_id"] = reply_to_message_id params["reply_to_message_id"] = reply_to_message_id
if entities is not None: if entities is not None:
params["entities"] = entities params["entities"] = entities
return self._call("sendMessage", params) # type: ignore[return-value] if parse_mode is not None:
params["parse_mode"] = parse_mode
return await self._post("sendMessage", params) # type: ignore[return-value]
def edit_message_text( async def edit_message_text(
self, self,
chat_id: int, chat_id: int,
message_id: int, message_id: int,
text: str, text: str,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None,
) -> dict: ) -> dict:
params: dict = { params: dict[str, Any] = {
"chat_id": chat_id, "chat_id": chat_id,
"message_id": message_id, "message_id": message_id,
"text": text, "text": text,
} }
if entities is not None: if entities is not None:
params["entities"] = entities params["entities"] = entities
return self._call("editMessageText", params) # type: ignore[return-value] if parse_mode is not None:
params["parse_mode"] = parse_mode
return await self._post("editMessageText", params) # type: ignore[return-value]
def delete_message(self, chat_id: int, message_id: int) -> bool: async def delete_message(self, chat_id: int, message_id: int) -> bool:
res = self._call( res = await self._post(
"deleteMessage", "deleteMessage",
{ {
"chat_id": chat_id, "chat_id": chat_id,
@@ -84,4 +98,3 @@ class TelegramClient:
}, },
) )
return bool(res) return bool(res)