refactor: migrate exec bridge to asyncio
This commit is contained in:
@@ -5,9 +5,8 @@ description = "Telegram bridge tools for Codex."
|
||||
readme = "readme.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"httpx>=0.28.1",
|
||||
"markdown-it-py",
|
||||
"requests",
|
||||
"sulguk",
|
||||
"typer",
|
||||
]
|
||||
|
||||
|
||||
@@ -1,32 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from weakref import WeakValueDictionary
|
||||
from html import unescape
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
|
||||
import typer
|
||||
|
||||
from .config import load_telegram_config
|
||||
from .constants import TELEGRAM_HARD_LIMIT
|
||||
from .exec_render import ExecProgressRenderer, render_event_cli
|
||||
from .rendering import render_markdown
|
||||
from .rendering import render_to_html, strip_tags
|
||||
from .telegram_client import TelegramClient
|
||||
|
||||
# -------------------- Codex runner --------------------
|
||||
|
||||
logger = logging.getLogger("exec_bridge")
|
||||
UUID_PATTERN = re.compile(
|
||||
r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
|
||||
@@ -41,11 +39,17 @@ def extract_session_id(text: str | None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _drain_stderr(stderr, lines: list[str]) -> None:
|
||||
async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -> None:
|
||||
if stderr is None:
|
||||
return
|
||||
try:
|
||||
for line in stderr:
|
||||
logger.info("[codex][stderr] %s", line.rstrip())
|
||||
lines.append(line)
|
||||
while True:
|
||||
line = await stderr.readline()
|
||||
if not line:
|
||||
return
|
||||
decoded = line.decode(errors="replace")
|
||||
logger.info("[codex][stderr] %s", decoded.rstrip())
|
||||
tail.append(decoded)
|
||||
except Exception as e:
|
||||
logger.debug("[codex][stderr] drain error: %s", e)
|
||||
|
||||
@@ -85,7 +89,7 @@ def _clamp_tg_text(text: str, limit: int = TELEGRAM_TEXT_LIMIT) -> str:
|
||||
return text[: limit - 20] + "\n...(truncated)"
|
||||
|
||||
|
||||
def _send_markdown(
|
||||
async def _send_markdown(
|
||||
bot: TelegramClient,
|
||||
*,
|
||||
chat_id: int,
|
||||
@@ -93,104 +97,30 @@ def _send_markdown(
|
||||
reply_to_message_id: int | None = None,
|
||||
disable_notification: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
rendered, entities = render_markdown(text)
|
||||
if len(rendered) > TELEGRAM_MARKDOWN_LIMIT:
|
||||
rendered = rendered[: TELEGRAM_MARKDOWN_LIMIT - 20] + "\n…(truncated)"
|
||||
entities = None
|
||||
return bot.send_message(
|
||||
md = text
|
||||
if len(md) > TELEGRAM_MARKDOWN_LIMIT:
|
||||
md = md[: TELEGRAM_MARKDOWN_LIMIT - 20] + "\n…(truncated)"
|
||||
|
||||
rendered = render_to_html(md)
|
||||
if len(rendered) > TELEGRAM_TEXT_LIMIT:
|
||||
plain = _clamp_tg_text(unescape(strip_tags(rendered)))
|
||||
return await bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=plain,
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
disable_notification=disable_notification,
|
||||
)
|
||||
|
||||
return await bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=rendered,
|
||||
entities=entities,
|
||||
parse_mode="HTML",
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
disable_notification=disable_notification,
|
||||
)
|
||||
|
||||
|
||||
class ProgressEditor:
|
||||
def __init__(
|
||||
self,
|
||||
bot: TelegramClient,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
edit_every_s: float,
|
||||
initial_text: str | None = None,
|
||||
initial_entities: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
self.bot = bot
|
||||
self.chat_id = chat_id
|
||||
self.message_id = message_id
|
||||
self.edit_every_s = edit_every_s
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._pending: tuple[str, list[dict[str, Any]] | None] | None = None
|
||||
self._last_sent: tuple[str, list[dict[str, Any]] | None] | None = None
|
||||
self._last_edit_at = 0.0
|
||||
|
||||
if initial_text is not None:
|
||||
self._last_sent = (initial_text, initial_entities)
|
||||
self._last_edit_at = time.monotonic()
|
||||
|
||||
self._stop = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def set(self, text: str, entities: list[dict[str, Any]] | None = None) -> None:
|
||||
text = _clamp_tg_text(text)
|
||||
with self._lock:
|
||||
self._pending = (text, entities)
|
||||
logger.debug(
|
||||
"[progress] set pending len=%s entities=%s", len(text), bool(entities)
|
||||
)
|
||||
|
||||
def set_markdown(self, text: str) -> None:
|
||||
rendered_text, entities = render_markdown(text)
|
||||
self.set(rendered_text, entities or None)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
self._thread.join(timeout=1.0)
|
||||
|
||||
def _edit(self, text: str, entities: list[dict[str, Any]] | None) -> None:
|
||||
try:
|
||||
self.bot.edit_message_text(
|
||||
chat_id=self.chat_id,
|
||||
message_id=self.message_id,
|
||||
text=text,
|
||||
entities=entities,
|
||||
)
|
||||
logger.debug(
|
||||
"[progress] edit ok chat_id=%s message_id=%s len=%s",
|
||||
self.chat_id,
|
||||
self.message_id,
|
||||
len(text),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"[progress] edit failed chat_id=%s message_id=%s: %s",
|
||||
self.chat_id,
|
||||
self.message_id,
|
||||
e,
|
||||
)
|
||||
|
||||
def _run(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
to_send: tuple[str, list[dict[str, Any]] | None] | None = None
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
if (
|
||||
self._pending is not None
|
||||
and (now - self._last_edit_at) >= self.edit_every_s
|
||||
):
|
||||
if self._pending != self._last_sent:
|
||||
to_send = self._pending
|
||||
self._last_sent = self._pending
|
||||
self._last_edit_at = now
|
||||
self._pending = None
|
||||
|
||||
if to_send is not None:
|
||||
self._edit(to_send[0], to_send[1])
|
||||
|
||||
self._stop.wait(0.2)
|
||||
EventCallback = Callable[[dict[str, Any]], Awaitable[None] | None]
|
||||
|
||||
|
||||
class CodexExecRunner:
|
||||
@@ -207,27 +137,24 @@ class CodexExecRunner:
|
||||
self.workspace = workspace
|
||||
self.extra_args = extra_args
|
||||
|
||||
# per-session locks to prevent concurrent resumes to same session_id
|
||||
self._session_locks: WeakValueDictionary[str, threading.Lock] = WeakValueDictionary()
|
||||
self._locks_guard = threading.Lock()
|
||||
# Per-session locks to prevent concurrent resumes to the same session_id.
|
||||
self._session_locks: dict[str, asyncio.Lock] = {}
|
||||
self._locks_guard = asyncio.Lock()
|
||||
|
||||
def _lock_for(self, session_id: str) -> threading.Lock:
|
||||
with self._locks_guard:
|
||||
async def _lock_for(self, session_id: str) -> asyncio.Lock:
|
||||
async with self._locks_guard:
|
||||
lock = self._session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = threading.Lock()
|
||||
lock = asyncio.Lock()
|
||||
self._session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
def run(
|
||||
async def run(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str | None,
|
||||
on_event: Callable[[dict[str, Any]], None] | None = None,
|
||||
on_event: EventCallback | None = None,
|
||||
) -> tuple[str, str, bool]:
|
||||
"""
|
||||
Returns (session_id, final_agent_message_text)
|
||||
"""
|
||||
logger.info(
|
||||
"[codex] start run session_id=%r workspace=%r", session_id, self.workspace
|
||||
)
|
||||
@@ -242,50 +169,49 @@ class CodexExecRunner:
|
||||
else:
|
||||
args.append("-")
|
||||
|
||||
# read both stdout+stderr without deadlock
|
||||
proc = subprocess.Popen(
|
||||
args,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
||||
assert proc.stdin and proc.stdout and proc.stderr
|
||||
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
||||
|
||||
# send prompt then close stdin
|
||||
proc.stdin.write(prompt)
|
||||
proc.stdin.write(prompt.encode())
|
||||
await proc.stdin.drain()
|
||||
proc.stdin.close()
|
||||
|
||||
stderr_lines: list[str] = []
|
||||
t = threading.Thread(target=_drain_stderr, args=(proc.stderr, stderr_lines), daemon=True)
|
||||
t.start()
|
||||
stderr_tail: deque[str] = deque(maxlen=200)
|
||||
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
|
||||
|
||||
found_session: str | None = session_id
|
||||
last_agent_text: str | None = None
|
||||
saw_agent_message = False
|
||||
cli_last_turn: int | None = None
|
||||
|
||||
cli_last_turn = None
|
||||
|
||||
for line in proc.stdout:
|
||||
line = line.strip()
|
||||
try:
|
||||
async for raw_line in proc.stdout:
|
||||
line = raw_line.decode(errors="replace").strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
evt = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
cli_last_turn, out_lines = render_event_cli(evt, cli_last_turn)
|
||||
for out in out_lines:
|
||||
logger.info("[codex] %s", out)
|
||||
|
||||
if on_event is not None:
|
||||
try:
|
||||
on_event(evt)
|
||||
res = on_event(evt)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
except Exception as e:
|
||||
logger.info("[codex][on_event] callback error: %s", e)
|
||||
|
||||
# From Codex JSONL event stream
|
||||
if evt.get("type") == "thread.started":
|
||||
found_session = evt.get("thread_id") or found_session
|
||||
|
||||
@@ -296,13 +222,16 @@ class CodexExecRunner:
|
||||
):
|
||||
last_agent_text = item["text"]
|
||||
saw_agent_message = True
|
||||
except asyncio.CancelledError:
|
||||
proc.terminate()
|
||||
raise
|
||||
finally:
|
||||
rc = await proc.wait()
|
||||
await stderr_task
|
||||
|
||||
rc = proc.wait()
|
||||
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
||||
t.join(timeout=2.0)
|
||||
|
||||
if rc != 0:
|
||||
tail = "".join(stderr_lines[-200:])
|
||||
tail = "".join(stderr_tail)
|
||||
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}")
|
||||
|
||||
if not found_session:
|
||||
@@ -317,35 +246,33 @@ class CodexExecRunner:
|
||||
saw_agent_message,
|
||||
)
|
||||
|
||||
def run_serialized(
|
||||
async def run_serialized(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str | None,
|
||||
on_event: Callable[[dict[str, Any]], None] | None = None,
|
||||
on_event: EventCallback | None = None,
|
||||
) -> tuple[str, str, bool]:
|
||||
"""
|
||||
If resuming, serialize per-session.
|
||||
"""
|
||||
if not session_id:
|
||||
return self.run(prompt, session_id=None, on_event=on_event)
|
||||
lock = self._lock_for(session_id)
|
||||
with lock:
|
||||
return self.run(prompt, session_id=session_id, on_event=on_event)
|
||||
return await self.run(prompt, session_id=None, on_event=on_event)
|
||||
lock = await self._lock_for(session_id)
|
||||
async with lock:
|
||||
return await self.run(prompt, session_id=session_id, on_event=on_event)
|
||||
|
||||
|
||||
# -------------------- Telegram loop --------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BridgeConfig:
|
||||
bot: TelegramClient
|
||||
runner: CodexExecRunner
|
||||
chat_id: int
|
||||
pool: ThreadPoolExecutor
|
||||
ignore_backlog: bool
|
||||
progress_edit_every_s: float
|
||||
progress_silent: bool
|
||||
final_notify: bool
|
||||
startup_msg: str
|
||||
max_concurrency: int
|
||||
|
||||
|
||||
def _parse_bridge_config(
|
||||
@@ -395,34 +322,37 @@ def _parse_bridge_config(
|
||||
|
||||
bot = TelegramClient(token)
|
||||
runner = CodexExecRunner(codex_cmd=codex_cmd, workspace=workspace, extra_args=extra_args)
|
||||
pool = ThreadPoolExecutor(max_workers=16)
|
||||
|
||||
return BridgeConfig(
|
||||
bot=bot,
|
||||
runner=runner,
|
||||
chat_id=chat_id,
|
||||
pool=pool,
|
||||
ignore_backlog=bool(ignore_backlog),
|
||||
progress_edit_every_s=progress_edit_every_s,
|
||||
progress_silent=progress_silent,
|
||||
final_notify=final_notify,
|
||||
startup_msg=startup_msg,
|
||||
max_concurrency=16,
|
||||
)
|
||||
|
||||
|
||||
def _send_startup(cfg: BridgeConfig) -> None:
|
||||
async def _send_startup(cfg: BridgeConfig) -> None:
|
||||
try:
|
||||
cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg)
|
||||
await cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg)
|
||||
logger.info("[startup] sent startup message to chat_id=%s", cfg.chat_id)
|
||||
except Exception as e:
|
||||
logger.info("[startup] failed to send startup message to chat_id=%s: %s", cfg.chat_id, e)
|
||||
logger.info(
|
||||
"[startup] failed to send startup message to chat_id=%s: %s", cfg.chat_id, e
|
||||
)
|
||||
|
||||
|
||||
def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
||||
async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
||||
if not cfg.ignore_backlog:
|
||||
return offset
|
||||
try:
|
||||
updates = cfg.bot.get_updates(offset=offset, timeout_s=0, allowed_updates=["message"])
|
||||
updates = await cfg.bot.get_updates(
|
||||
offset=offset, timeout_s=0, allowed_updates=["message"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info("[startup] backlog drain failed: %s", e)
|
||||
return offset
|
||||
@@ -432,9 +362,10 @@ def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
||||
return offset
|
||||
|
||||
|
||||
def _handle_message(
|
||||
async def _handle_message(
|
||||
cfg: BridgeConfig,
|
||||
*,
|
||||
semaphore: asyncio.Semaphore,
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
@@ -444,85 +375,159 @@ def _handle_message(
|
||||
progress_renderer = ExecProgressRenderer(max_actions=5)
|
||||
|
||||
progress_id: int | None = None
|
||||
progress: ProgressEditor | None = None
|
||||
|
||||
last_edit_at = 0.0
|
||||
edit_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def _edit_progress(md: str) -> None:
|
||||
if progress_id is None:
|
||||
return
|
||||
parse_mode: str | None = "HTML"
|
||||
rendered = render_to_html(md)
|
||||
if len(rendered) > TELEGRAM_TEXT_LIMIT:
|
||||
rendered = _clamp_tg_text(unescape(strip_tags(rendered)))
|
||||
parse_mode = None
|
||||
try:
|
||||
initial_text = progress_renderer.render_progress(0.0)
|
||||
initial_rendered, initial_entities = render_markdown(initial_text)
|
||||
progress_msg = cfg.bot.send_message(
|
||||
await cfg.bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=progress_id,
|
||||
text=rendered,
|
||||
parse_mode=parse_mode,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"[progress] edit failed chat_id=%s message_id=%s: %s",
|
||||
chat_id,
|
||||
progress_id,
|
||||
e,
|
||||
)
|
||||
|
||||
try:
|
||||
initial_md = progress_renderer.render_progress(0.0)
|
||||
initial_rendered = render_to_html(initial_md)
|
||||
progress_msg = await cfg.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=initial_rendered,
|
||||
entities=initial_entities or None,
|
||||
parse_mode="HTML",
|
||||
reply_to_message_id=user_msg_id,
|
||||
disable_notification=cfg.progress_silent,
|
||||
)
|
||||
progress_id = int(progress_msg["message_id"])
|
||||
last_edit_at = time.monotonic()
|
||||
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
|
||||
except Exception as e:
|
||||
logger.info("[handle] failed to send progress message chat_id=%s: %s", chat_id, e)
|
||||
|
||||
if progress_id is not None:
|
||||
progress = ProgressEditor(
|
||||
cfg.bot,
|
||||
chat_id,
|
||||
progress_id,
|
||||
cfg.progress_edit_every_s,
|
||||
initial_text=initial_rendered,
|
||||
initial_entities=initial_entities or None,
|
||||
logger.info(
|
||||
"[handle] failed to send progress message chat_id=%s: %s", chat_id, e
|
||||
)
|
||||
|
||||
def on_event(evt: dict[str, Any]) -> None:
|
||||
if progress_renderer.note_event(evt) and progress is not None:
|
||||
elapsed = time.monotonic() - started_at
|
||||
progress.set_markdown(progress_renderer.render_progress(elapsed))
|
||||
|
||||
def _stop_background() -> None:
|
||||
if progress is not None:
|
||||
progress.stop()
|
||||
async def on_event(evt: dict[str, Any]) -> None:
|
||||
nonlocal last_edit_at, edit_task
|
||||
if progress_id is None:
|
||||
return
|
||||
if not progress_renderer.note_event(evt):
|
||||
return
|
||||
now = time.monotonic()
|
||||
if (now - last_edit_at) < cfg.progress_edit_every_s:
|
||||
return
|
||||
if edit_task is not None and not edit_task.done():
|
||||
return
|
||||
last_edit_at = now
|
||||
elapsed = now - started_at
|
||||
edit_task = asyncio.create_task(
|
||||
_edit_progress(progress_renderer.render_progress(elapsed))
|
||||
)
|
||||
|
||||
async with semaphore:
|
||||
try:
|
||||
session_id, answer, saw_agent_message = cfg.runner.run_serialized(
|
||||
text,
|
||||
resume_session,
|
||||
on_event=on_event,
|
||||
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
|
||||
text, resume_session, on_event=on_event
|
||||
)
|
||||
except Exception as e:
|
||||
_stop_background()
|
||||
if edit_task is not None:
|
||||
await asyncio.gather(edit_task, return_exceptions=True)
|
||||
|
||||
err = _clamp_tg_text(f"Error:\n{e}")
|
||||
if progress_id is not None and len(err) <= TELEGRAM_TEXT_LIMIT:
|
||||
cfg.bot.edit_message_text(chat_id=chat_id, message_id=progress_id, text=err)
|
||||
try:
|
||||
await cfg.bot.edit_message_text(
|
||||
chat_id=chat_id, message_id=progress_id, text=err
|
||||
)
|
||||
return
|
||||
_send_markdown(cfg.bot, chat_id=chat_id, text=err, reply_to_message_id=user_msg_id)
|
||||
except Exception:
|
||||
pass
|
||||
await _send_markdown(
|
||||
cfg.bot,
|
||||
chat_id=chat_id,
|
||||
text=err,
|
||||
reply_to_message_id=user_msg_id,
|
||||
disable_notification=cfg.progress_silent,
|
||||
)
|
||||
return
|
||||
|
||||
_stop_background()
|
||||
if edit_task is not None:
|
||||
await asyncio.gather(edit_task, return_exceptions=True)
|
||||
|
||||
answer = answer or "(No agent_message captured from JSON stream.)"
|
||||
elapsed = time.monotonic() - started_at
|
||||
status = "done" if saw_agent_message else "error"
|
||||
final_md = progress_renderer.render_final(elapsed, answer, status=status) + f"\n\nresume: `{session_id}`"
|
||||
final_text, final_entities = render_markdown(final_md)
|
||||
can_edit_final = progress_id is not None and len(final_text) <= TELEGRAM_TEXT_LIMIT
|
||||
final_md = (
|
||||
progress_renderer.render_final(elapsed, answer, status=status)
|
||||
+ f"\n\nresume: `{session_id}`"
|
||||
)
|
||||
final_rendered = render_to_html(final_md)
|
||||
can_edit_final = progress_id is not None and len(final_rendered) <= TELEGRAM_TEXT_LIMIT
|
||||
|
||||
if cfg.final_notify or not can_edit_final:
|
||||
_send_markdown(cfg.bot, chat_id=chat_id, text=final_md, reply_to_message_id=user_msg_id)
|
||||
await _send_markdown(
|
||||
cfg.bot,
|
||||
chat_id=chat_id,
|
||||
text=final_md,
|
||||
reply_to_message_id=user_msg_id,
|
||||
disable_notification=cfg.progress_silent,
|
||||
)
|
||||
if progress_id is not None:
|
||||
cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
|
||||
try:
|
||||
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
cfg.bot.edit_message_text(
|
||||
await cfg.bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=progress_id,
|
||||
text=final_text,
|
||||
entities=final_entities or None,
|
||||
text=final_rendered,
|
||||
parse_mode="HTML",
|
||||
)
|
||||
|
||||
|
||||
def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
offset: int | None = None
|
||||
offset = _drain_backlog(cfg, offset)
|
||||
_send_startup(cfg)
|
||||
async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
semaphore = asyncio.Semaphore(cfg.max_concurrency)
|
||||
|
||||
tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
def _task_done(task: asyncio.Task[None]) -> None:
|
||||
tasks.discard(task)
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("[handle] task failed")
|
||||
|
||||
offset: int | None = None
|
||||
offset = await _drain_backlog(cfg, offset)
|
||||
await _send_startup(cfg)
|
||||
|
||||
try:
|
||||
while True:
|
||||
updates = cfg.bot.get_updates(offset=offset, timeout_s=50, allowed_updates=["message"])
|
||||
try:
|
||||
updates = await cfg.bot.get_updates(
|
||||
offset=offset, timeout_s=50, allowed_updates=["message"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info("[loop] getUpdates failed: %s", e)
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
for upd in updates:
|
||||
offset = upd["update_id"] + 1
|
||||
msg = upd.get("message") or {}
|
||||
@@ -540,14 +545,23 @@ def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
r = msg.get("reply_to_message") or {}
|
||||
resume_session = resume_session or extract_session_id(r.get("text"))
|
||||
|
||||
cfg.pool.submit(
|
||||
_handle_message,
|
||||
task = asyncio.create_task(
|
||||
_handle_message(
|
||||
cfg,
|
||||
semaphore=semaphore,
|
||||
chat_id=msg_chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
text=text,
|
||||
resume_session=resume_session,
|
||||
)
|
||||
)
|
||||
tasks.add(task)
|
||||
task.add_done_callback(_task_done)
|
||||
finally:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
await cfg.bot.close()
|
||||
|
||||
|
||||
def run(
|
||||
@@ -597,7 +611,7 @@ def run(
|
||||
cd=cd,
|
||||
model=model,
|
||||
)
|
||||
_run_main_loop(cfg)
|
||||
asyncio.run(_run_main_loop(cfg))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
@@ -1,23 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from html import escape
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
from sulguk import transform_html
|
||||
|
||||
_md = MarkdownIt("commonmark", {"html": False, "breaks": True})
|
||||
|
||||
_CODE_CLASS_RE = re.compile(r'<code class="[^"]+">')
|
||||
_IMG_ALT_RE = re.compile(r'<img[^>]*alt="([^"]*)"[^>]*/?>')
|
||||
_IMG_RE = re.compile(r"<img[^>]*>")
|
||||
_OL_OPEN_RE = re.compile(r'<ol(?: start="\d+")?>\s*')
|
||||
_TAG_RE = re.compile(r"<[^>]+>")
|
||||
|
||||
|
||||
def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]:
|
||||
html = MarkdownIt("commonmark", {"html": False}).render(md or "")
|
||||
rendered = transform_html(html)
|
||||
def strip_tags(html: str) -> str:
|
||||
return _TAG_RE.sub("", html)
|
||||
|
||||
text = re.sub("(?m)^(\\s*)\u2022", r"\1-", rendered.text)
|
||||
|
||||
# FIX: Telegram requires MessageEntity.language (if present) to be a String.
|
||||
entities: list[dict[str, Any]] = []
|
||||
for e in rendered.entities:
|
||||
d = dict(e)
|
||||
if "language" in d and not isinstance(d["language"], str):
|
||||
d.pop("language", None)
|
||||
entities.append(d)
|
||||
return text, entities
|
||||
def render_to_html(text: str) -> str:
|
||||
"""
|
||||
Render Markdown to Telegram-compatible HTML.
|
||||
|
||||
Telegram supports only a subset of HTML tags, so we post-process the
|
||||
MarkdownIt output to flatten unsupported block tags (p/ul/li/etc) into
|
||||
plain text with newlines and simple bullets.
|
||||
"""
|
||||
html = _md.render(text or "")
|
||||
|
||||
# Paragraphs and line breaks.
|
||||
html = html.replace("<p>", "")
|
||||
html = html.replace("<br />\n", "\n").replace("<br>\n", "\n")
|
||||
html = html.replace("<br />", "\n").replace("<br>", "\n")
|
||||
html = html.replace("</p>\n", "\n\n").replace("</p>", "\n\n")
|
||||
|
||||
# Lists -> "- " lines.
|
||||
html = html.replace("<ul>\n", "").replace("</ul>\n", "")
|
||||
html = _OL_OPEN_RE.sub("", html).replace("</ol>\n", "")
|
||||
html = html.replace("<li>", "- ")
|
||||
html = html.replace("</li>\n", "\n").replace("</li>", "\n")
|
||||
|
||||
# Headings -> bold line.
|
||||
for level in range(1, 7):
|
||||
html = html.replace(f"<h{level}>", "<b>")
|
||||
html = html.replace(f"</h{level}>\n", "</b>\n\n").replace(
|
||||
f"</h{level}>", "</b>\n\n"
|
||||
)
|
||||
|
||||
# Code fences may include language class; Telegram doesn't need it.
|
||||
html = _CODE_CLASS_RE.sub("<code>", html)
|
||||
|
||||
# Images are not supported: keep alt text if present.
|
||||
html = _IMG_ALT_RE.sub(lambda m: escape(m.group(1) or ""), html)
|
||||
html = _IMG_RE.sub("", html)
|
||||
|
||||
# <hr> isn't supported; render a separator line.
|
||||
html = html.replace("<hr />", "\n----\n\n").replace("<hr>", "\n----\n\n")
|
||||
|
||||
# Flatten blockquotes.
|
||||
html = html.replace("<blockquote>\n", "")
|
||||
html = html.replace("</blockquote>\n", "\n\n").replace("</blockquote>", "\n\n")
|
||||
|
||||
html = re.sub(r"\n{3,}", "\n\n", html)
|
||||
return html.strip()
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import requests
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TelegramClient:
|
||||
@@ -12,42 +17,46 @@ class TelegramClient:
|
||||
if not token:
|
||||
raise ValueError("Telegram token is empty")
|
||||
self._base = f"https://api.telegram.org/bot{token}"
|
||||
self._timeout_s = timeout_s
|
||||
self._client = httpx.AsyncClient(timeout=timeout_s)
|
||||
|
||||
def _call(self, method: str, params: dict) -> object:
|
||||
resp = requests.post(
|
||||
f"{self._base}/{method}",
|
||||
json=params,
|
||||
timeout=self._timeout_s,
|
||||
)
|
||||
async def close(self) -> None:
|
||||
await self._client.aclose()
|
||||
|
||||
async def _post(self, method: str, json_data: dict[str, Any]) -> Any:
|
||||
try:
|
||||
resp = await self._client.post(f"{self._base}/{method}", json=json_data)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not payload.get("ok"):
|
||||
raise RuntimeError(f"Telegram API error: {payload}")
|
||||
return payload["result"]
|
||||
except httpx.HTTPError as e:
|
||||
logger.error("Telegram network error: %s", e)
|
||||
raise
|
||||
|
||||
def get_updates(
|
||||
async def get_updates(
|
||||
self,
|
||||
offset: int | None,
|
||||
timeout_s: int = 50,
|
||||
allowed_updates: list[str] | None = None,
|
||||
) -> list[dict]:
|
||||
params: dict = {"timeout": timeout_s}
|
||||
params: dict[str, Any] = {"timeout": timeout_s}
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
if allowed_updates is not None:
|
||||
params["allowed_updates"] = allowed_updates
|
||||
return self._call("getUpdates", params) # type: ignore[return-value]
|
||||
return await self._post("getUpdates", params) # type: ignore[return-value]
|
||||
|
||||
def send_message(
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_to_message_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
entities: list[dict] | None = None,
|
||||
parse_mode: str | None = None,
|
||||
) -> dict:
|
||||
params: dict = {
|
||||
params: dict[str, Any] = {
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
}
|
||||
@@ -57,26 +66,31 @@ class TelegramClient:
|
||||
params["reply_to_message_id"] = reply_to_message_id
|
||||
if entities is not None:
|
||||
params["entities"] = entities
|
||||
return self._call("sendMessage", params) # type: ignore[return-value]
|
||||
if parse_mode is not None:
|
||||
params["parse_mode"] = parse_mode
|
||||
return await self._post("sendMessage", params) # type: ignore[return-value]
|
||||
|
||||
def edit_message_text(
|
||||
async def edit_message_text(
|
||||
self,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
text: str,
|
||||
entities: list[dict] | None = None,
|
||||
parse_mode: str | None = None,
|
||||
) -> dict:
|
||||
params: dict = {
|
||||
params: dict[str, Any] = {
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"text": text,
|
||||
}
|
||||
if entities is not None:
|
||||
params["entities"] = entities
|
||||
return self._call("editMessageText", params) # type: ignore[return-value]
|
||||
if parse_mode is not None:
|
||||
params["parse_mode"] = parse_mode
|
||||
return await self._post("editMessageText", params) # type: ignore[return-value]
|
||||
|
||||
def delete_message(self, chat_id: int, message_id: int) -> bool:
|
||||
res = self._call(
|
||||
async def delete_message(self, chat_id: int, message_id: int) -> bool:
|
||||
res = await self._post(
|
||||
"deleteMessage",
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
@@ -84,4 +98,3 @@ class TelegramClient:
|
||||
},
|
||||
)
|
||||
return bool(res)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user