refactor: migrate exec bridge to asyncio
This commit is contained in:
@@ -5,9 +5,8 @@ description = "Telegram bridge tools for Codex."
|
|||||||
readme = "readme.md"
|
readme = "readme.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"httpx>=0.28.1",
|
||||||
"markdown-it-py",
|
"markdown-it-py",
|
||||||
"requests",
|
|
||||||
"sulguk",
|
|
||||||
"typer",
|
"typer",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -1,32 +1,30 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shlex
|
import shlex
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from weakref import WeakValueDictionary
|
from html import unescape
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from .config import load_telegram_config
|
from .config import load_telegram_config
|
||||||
from .constants import TELEGRAM_HARD_LIMIT
|
from .constants import TELEGRAM_HARD_LIMIT
|
||||||
from .exec_render import ExecProgressRenderer, render_event_cli
|
from .exec_render import ExecProgressRenderer, render_event_cli
|
||||||
from .rendering import render_markdown
|
from .rendering import render_to_html, strip_tags
|
||||||
from .telegram_client import TelegramClient
|
from .telegram_client import TelegramClient
|
||||||
|
|
||||||
# -------------------- Codex runner --------------------
|
|
||||||
|
|
||||||
logger = logging.getLogger("exec_bridge")
|
logger = logging.getLogger("exec_bridge")
|
||||||
UUID_PATTERN = re.compile(
|
UUID_PATTERN = re.compile(
|
||||||
r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
|
r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
|
||||||
@@ -41,11 +39,17 @@ def extract_session_id(text: str | None) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _drain_stderr(stderr, lines: list[str]) -> None:
|
async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -> None:
|
||||||
|
if stderr is None:
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
for line in stderr:
|
while True:
|
||||||
logger.info("[codex][stderr] %s", line.rstrip())
|
line = await stderr.readline()
|
||||||
lines.append(line)
|
if not line:
|
||||||
|
return
|
||||||
|
decoded = line.decode(errors="replace")
|
||||||
|
logger.info("[codex][stderr] %s", decoded.rstrip())
|
||||||
|
tail.append(decoded)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("[codex][stderr] drain error: %s", e)
|
logger.debug("[codex][stderr] drain error: %s", e)
|
||||||
|
|
||||||
@@ -85,7 +89,7 @@ def _clamp_tg_text(text: str, limit: int = TELEGRAM_TEXT_LIMIT) -> str:
|
|||||||
return text[: limit - 20] + "\n...(truncated)"
|
return text[: limit - 20] + "\n...(truncated)"
|
||||||
|
|
||||||
|
|
||||||
def _send_markdown(
|
async def _send_markdown(
|
||||||
bot: TelegramClient,
|
bot: TelegramClient,
|
||||||
*,
|
*,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
@@ -93,104 +97,30 @@ def _send_markdown(
|
|||||||
reply_to_message_id: int | None = None,
|
reply_to_message_id: int | None = None,
|
||||||
disable_notification: bool = False,
|
disable_notification: bool = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
rendered, entities = render_markdown(text)
|
md = text
|
||||||
if len(rendered) > TELEGRAM_MARKDOWN_LIMIT:
|
if len(md) > TELEGRAM_MARKDOWN_LIMIT:
|
||||||
rendered = rendered[: TELEGRAM_MARKDOWN_LIMIT - 20] + "\n…(truncated)"
|
md = md[: TELEGRAM_MARKDOWN_LIMIT - 20] + "\n…(truncated)"
|
||||||
entities = None
|
|
||||||
return bot.send_message(
|
rendered = render_to_html(md)
|
||||||
|
if len(rendered) > TELEGRAM_TEXT_LIMIT:
|
||||||
|
plain = _clamp_tg_text(unescape(strip_tags(rendered)))
|
||||||
|
return await bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=plain,
|
||||||
|
reply_to_message_id=reply_to_message_id,
|
||||||
|
disable_notification=disable_notification,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await bot.send_message(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
text=rendered,
|
text=rendered,
|
||||||
entities=entities,
|
parse_mode="HTML",
|
||||||
reply_to_message_id=reply_to_message_id,
|
reply_to_message_id=reply_to_message_id,
|
||||||
disable_notification=disable_notification,
|
disable_notification=disable_notification,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProgressEditor:
|
EventCallback = Callable[[dict[str, Any]], Awaitable[None] | None]
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
bot: TelegramClient,
|
|
||||||
chat_id: int,
|
|
||||||
message_id: int,
|
|
||||||
edit_every_s: float,
|
|
||||||
initial_text: str | None = None,
|
|
||||||
initial_entities: list[dict[str, Any]] | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.bot = bot
|
|
||||||
self.chat_id = chat_id
|
|
||||||
self.message_id = message_id
|
|
||||||
self.edit_every_s = edit_every_s
|
|
||||||
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
self._pending: tuple[str, list[dict[str, Any]] | None] | None = None
|
|
||||||
self._last_sent: tuple[str, list[dict[str, Any]] | None] | None = None
|
|
||||||
self._last_edit_at = 0.0
|
|
||||||
|
|
||||||
if initial_text is not None:
|
|
||||||
self._last_sent = (initial_text, initial_entities)
|
|
||||||
self._last_edit_at = time.monotonic()
|
|
||||||
|
|
||||||
self._stop = threading.Event()
|
|
||||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
|
|
||||||
def set(self, text: str, entities: list[dict[str, Any]] | None = None) -> None:
|
|
||||||
text = _clamp_tg_text(text)
|
|
||||||
with self._lock:
|
|
||||||
self._pending = (text, entities)
|
|
||||||
logger.debug(
|
|
||||||
"[progress] set pending len=%s entities=%s", len(text), bool(entities)
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_markdown(self, text: str) -> None:
|
|
||||||
rendered_text, entities = render_markdown(text)
|
|
||||||
self.set(rendered_text, entities or None)
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
self._stop.set()
|
|
||||||
self._thread.join(timeout=1.0)
|
|
||||||
|
|
||||||
def _edit(self, text: str, entities: list[dict[str, Any]] | None) -> None:
|
|
||||||
try:
|
|
||||||
self.bot.edit_message_text(
|
|
||||||
chat_id=self.chat_id,
|
|
||||||
message_id=self.message_id,
|
|
||||||
text=text,
|
|
||||||
entities=entities,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"[progress] edit ok chat_id=%s message_id=%s len=%s",
|
|
||||||
self.chat_id,
|
|
||||||
self.message_id,
|
|
||||||
len(text),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(
|
|
||||||
"[progress] edit failed chat_id=%s message_id=%s: %s",
|
|
||||||
self.chat_id,
|
|
||||||
self.message_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run(self) -> None:
|
|
||||||
while not self._stop.is_set():
|
|
||||||
to_send: tuple[str, list[dict[str, Any]] | None] | None = None
|
|
||||||
now = time.monotonic()
|
|
||||||
with self._lock:
|
|
||||||
if (
|
|
||||||
self._pending is not None
|
|
||||||
and (now - self._last_edit_at) >= self.edit_every_s
|
|
||||||
):
|
|
||||||
if self._pending != self._last_sent:
|
|
||||||
to_send = self._pending
|
|
||||||
self._last_sent = self._pending
|
|
||||||
self._last_edit_at = now
|
|
||||||
self._pending = None
|
|
||||||
|
|
||||||
if to_send is not None:
|
|
||||||
self._edit(to_send[0], to_send[1])
|
|
||||||
|
|
||||||
self._stop.wait(0.2)
|
|
||||||
|
|
||||||
|
|
||||||
class CodexExecRunner:
|
class CodexExecRunner:
|
||||||
@@ -207,27 +137,24 @@ class CodexExecRunner:
|
|||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.extra_args = extra_args
|
self.extra_args = extra_args
|
||||||
|
|
||||||
# per-session locks to prevent concurrent resumes to same session_id
|
# Per-session locks to prevent concurrent resumes to the same session_id.
|
||||||
self._session_locks: WeakValueDictionary[str, threading.Lock] = WeakValueDictionary()
|
self._session_locks: dict[str, asyncio.Lock] = {}
|
||||||
self._locks_guard = threading.Lock()
|
self._locks_guard = asyncio.Lock()
|
||||||
|
|
||||||
def _lock_for(self, session_id: str) -> threading.Lock:
|
async def _lock_for(self, session_id: str) -> asyncio.Lock:
|
||||||
with self._locks_guard:
|
async with self._locks_guard:
|
||||||
lock = self._session_locks.get(session_id)
|
lock = self._session_locks.get(session_id)
|
||||||
if lock is None:
|
if lock is None:
|
||||||
lock = threading.Lock()
|
lock = asyncio.Lock()
|
||||||
self._session_locks[session_id] = lock
|
self._session_locks[session_id] = lock
|
||||||
return lock
|
return lock
|
||||||
|
|
||||||
def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
session_id: str | None,
|
session_id: str | None,
|
||||||
on_event: Callable[[dict[str, Any]], None] | None = None,
|
on_event: EventCallback | None = None,
|
||||||
) -> tuple[str, str, bool]:
|
) -> tuple[str, str, bool]:
|
||||||
"""
|
|
||||||
Returns (session_id, final_agent_message_text)
|
|
||||||
"""
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"[codex] start run session_id=%r workspace=%r", session_id, self.workspace
|
"[codex] start run session_id=%r workspace=%r", session_id, self.workspace
|
||||||
)
|
)
|
||||||
@@ -242,67 +169,69 @@ class CodexExecRunner:
|
|||||||
else:
|
else:
|
||||||
args.append("-")
|
args.append("-")
|
||||||
|
|
||||||
# read both stdout+stderr without deadlock
|
proc = await asyncio.create_subprocess_exec(
|
||||||
proc = subprocess.Popen(
|
*args,
|
||||||
args,
|
stdin=asyncio.subprocess.PIPE,
|
||||||
stdin=subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stdout=subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
text=True,
|
|
||||||
bufsize=1,
|
|
||||||
)
|
)
|
||||||
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
|
||||||
assert proc.stdin and proc.stdout and proc.stderr
|
assert proc.stdin and proc.stdout and proc.stderr
|
||||||
|
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
||||||
|
|
||||||
# send prompt then close stdin
|
proc.stdin.write(prompt.encode())
|
||||||
proc.stdin.write(prompt)
|
await proc.stdin.drain()
|
||||||
proc.stdin.close()
|
proc.stdin.close()
|
||||||
|
|
||||||
stderr_lines: list[str] = []
|
stderr_tail: deque[str] = deque(maxlen=200)
|
||||||
t = threading.Thread(target=_drain_stderr, args=(proc.stderr, stderr_lines), daemon=True)
|
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
|
||||||
t.start()
|
|
||||||
|
|
||||||
found_session: str | None = session_id
|
found_session: str | None = session_id
|
||||||
last_agent_text: str | None = None
|
last_agent_text: str | None = None
|
||||||
saw_agent_message = False
|
saw_agent_message = False
|
||||||
|
cli_last_turn: int | None = None
|
||||||
|
|
||||||
cli_last_turn = None
|
try:
|
||||||
|
async for raw_line in proc.stdout:
|
||||||
for line in proc.stdout:
|
line = raw_line.decode(errors="replace").strip()
|
||||||
line = line.strip()
|
if not line:
|
||||||
if not line:
|
continue
|
||||||
continue
|
|
||||||
try:
|
|
||||||
evt = json.loads(line)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
cli_last_turn, out_lines = render_event_cli(evt, cli_last_turn)
|
|
||||||
for out in out_lines:
|
|
||||||
logger.info("[codex] %s", out)
|
|
||||||
if on_event is not None:
|
|
||||||
try:
|
try:
|
||||||
on_event(evt)
|
evt = json.loads(line)
|
||||||
except Exception as e:
|
except json.JSONDecodeError:
|
||||||
logger.info("[codex][on_event] callback error: %s", e)
|
continue
|
||||||
|
|
||||||
# From Codex JSONL event stream
|
cli_last_turn, out_lines = render_event_cli(evt, cli_last_turn)
|
||||||
if evt.get("type") == "thread.started":
|
for out in out_lines:
|
||||||
found_session = evt.get("thread_id") or found_session
|
logger.info("[codex] %s", out)
|
||||||
|
|
||||||
if evt.get("type") == "item.completed":
|
if on_event is not None:
|
||||||
item = evt.get("item") or {}
|
try:
|
||||||
if item.get("type") == "agent_message" and isinstance(
|
res = on_event(evt)
|
||||||
item.get("text"), str
|
if inspect.isawaitable(res):
|
||||||
):
|
await res
|
||||||
last_agent_text = item["text"]
|
except Exception as e:
|
||||||
saw_agent_message = True
|
logger.info("[codex][on_event] callback error: %s", e)
|
||||||
|
|
||||||
|
if evt.get("type") == "thread.started":
|
||||||
|
found_session = evt.get("thread_id") or found_session
|
||||||
|
|
||||||
|
if evt.get("type") == "item.completed":
|
||||||
|
item = evt.get("item") or {}
|
||||||
|
if item.get("type") == "agent_message" and isinstance(
|
||||||
|
item.get("text"), str
|
||||||
|
):
|
||||||
|
last_agent_text = item["text"]
|
||||||
|
saw_agent_message = True
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
proc.terminate()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
rc = await proc.wait()
|
||||||
|
await stderr_task
|
||||||
|
|
||||||
rc = proc.wait()
|
|
||||||
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
||||||
t.join(timeout=2.0)
|
|
||||||
|
|
||||||
if rc != 0:
|
if rc != 0:
|
||||||
tail = "".join(stderr_lines[-200:])
|
tail = "".join(stderr_tail)
|
||||||
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}")
|
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}")
|
||||||
|
|
||||||
if not found_session:
|
if not found_session:
|
||||||
@@ -317,35 +246,33 @@ class CodexExecRunner:
|
|||||||
saw_agent_message,
|
saw_agent_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_serialized(
|
async def run_serialized(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
session_id: str | None,
|
session_id: str | None,
|
||||||
on_event: Callable[[dict[str, Any]], None] | None = None,
|
on_event: EventCallback | None = None,
|
||||||
) -> tuple[str, str, bool]:
|
) -> tuple[str, str, bool]:
|
||||||
"""
|
"""
|
||||||
If resuming, serialize per-session.
|
If resuming, serialize per-session.
|
||||||
"""
|
"""
|
||||||
if not session_id:
|
if not session_id:
|
||||||
return self.run(prompt, session_id=None, on_event=on_event)
|
return await self.run(prompt, session_id=None, on_event=on_event)
|
||||||
lock = self._lock_for(session_id)
|
lock = await self._lock_for(session_id)
|
||||||
with lock:
|
async with lock:
|
||||||
return self.run(prompt, session_id=session_id, on_event=on_event)
|
return await self.run(prompt, session_id=session_id, on_event=on_event)
|
||||||
|
|
||||||
|
|
||||||
# -------------------- Telegram loop --------------------
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BridgeConfig:
|
class BridgeConfig:
|
||||||
bot: TelegramClient
|
bot: TelegramClient
|
||||||
runner: CodexExecRunner
|
runner: CodexExecRunner
|
||||||
chat_id: int
|
chat_id: int
|
||||||
pool: ThreadPoolExecutor
|
|
||||||
ignore_backlog: bool
|
ignore_backlog: bool
|
||||||
progress_edit_every_s: float
|
progress_edit_every_s: float
|
||||||
progress_silent: bool
|
progress_silent: bool
|
||||||
final_notify: bool
|
final_notify: bool
|
||||||
startup_msg: str
|
startup_msg: str
|
||||||
|
max_concurrency: int
|
||||||
|
|
||||||
|
|
||||||
def _parse_bridge_config(
|
def _parse_bridge_config(
|
||||||
@@ -395,34 +322,37 @@ def _parse_bridge_config(
|
|||||||
|
|
||||||
bot = TelegramClient(token)
|
bot = TelegramClient(token)
|
||||||
runner = CodexExecRunner(codex_cmd=codex_cmd, workspace=workspace, extra_args=extra_args)
|
runner = CodexExecRunner(codex_cmd=codex_cmd, workspace=workspace, extra_args=extra_args)
|
||||||
pool = ThreadPoolExecutor(max_workers=16)
|
|
||||||
|
|
||||||
return BridgeConfig(
|
return BridgeConfig(
|
||||||
bot=bot,
|
bot=bot,
|
||||||
runner=runner,
|
runner=runner,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
pool=pool,
|
|
||||||
ignore_backlog=bool(ignore_backlog),
|
ignore_backlog=bool(ignore_backlog),
|
||||||
progress_edit_every_s=progress_edit_every_s,
|
progress_edit_every_s=progress_edit_every_s,
|
||||||
progress_silent=progress_silent,
|
progress_silent=progress_silent,
|
||||||
final_notify=final_notify,
|
final_notify=final_notify,
|
||||||
startup_msg=startup_msg,
|
startup_msg=startup_msg,
|
||||||
|
max_concurrency=16,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _send_startup(cfg: BridgeConfig) -> None:
|
async def _send_startup(cfg: BridgeConfig) -> None:
|
||||||
try:
|
try:
|
||||||
cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg)
|
await cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg)
|
||||||
logger.info("[startup] sent startup message to chat_id=%s", cfg.chat_id)
|
logger.info("[startup] sent startup message to chat_id=%s", cfg.chat_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("[startup] failed to send startup message to chat_id=%s: %s", cfg.chat_id, e)
|
logger.info(
|
||||||
|
"[startup] failed to send startup message to chat_id=%s: %s", cfg.chat_id, e
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
||||||
if not cfg.ignore_backlog:
|
if not cfg.ignore_backlog:
|
||||||
return offset
|
return offset
|
||||||
try:
|
try:
|
||||||
updates = cfg.bot.get_updates(offset=offset, timeout_s=0, allowed_updates=["message"])
|
updates = await cfg.bot.get_updates(
|
||||||
|
offset=offset, timeout_s=0, allowed_updates=["message"]
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("[startup] backlog drain failed: %s", e)
|
logger.info("[startup] backlog drain failed: %s", e)
|
||||||
return offset
|
return offset
|
||||||
@@ -432,9 +362,10 @@ def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
|||||||
return offset
|
return offset
|
||||||
|
|
||||||
|
|
||||||
def _handle_message(
|
async def _handle_message(
|
||||||
cfg: BridgeConfig,
|
cfg: BridgeConfig,
|
||||||
*,
|
*,
|
||||||
|
semaphore: asyncio.Semaphore,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
user_msg_id: int,
|
user_msg_id: int,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -444,110 +375,193 @@ def _handle_message(
|
|||||||
progress_renderer = ExecProgressRenderer(max_actions=5)
|
progress_renderer = ExecProgressRenderer(max_actions=5)
|
||||||
|
|
||||||
progress_id: int | None = None
|
progress_id: int | None = None
|
||||||
progress: ProgressEditor | None = None
|
|
||||||
|
last_edit_at = 0.0
|
||||||
|
edit_task: asyncio.Task[None] | None = None
|
||||||
|
|
||||||
|
async def _edit_progress(md: str) -> None:
|
||||||
|
if progress_id is None:
|
||||||
|
return
|
||||||
|
parse_mode: str | None = "HTML"
|
||||||
|
rendered = render_to_html(md)
|
||||||
|
if len(rendered) > TELEGRAM_TEXT_LIMIT:
|
||||||
|
rendered = _clamp_tg_text(unescape(strip_tags(rendered)))
|
||||||
|
parse_mode = None
|
||||||
|
try:
|
||||||
|
await cfg.bot.edit_message_text(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=progress_id,
|
||||||
|
text=rendered,
|
||||||
|
parse_mode=parse_mode,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(
|
||||||
|
"[progress] edit failed chat_id=%s message_id=%s: %s",
|
||||||
|
chat_id,
|
||||||
|
progress_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
initial_text = progress_renderer.render_progress(0.0)
|
initial_md = progress_renderer.render_progress(0.0)
|
||||||
initial_rendered, initial_entities = render_markdown(initial_text)
|
initial_rendered = render_to_html(initial_md)
|
||||||
progress_msg = cfg.bot.send_message(
|
progress_msg = await cfg.bot.send_message(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
text=initial_rendered,
|
text=initial_rendered,
|
||||||
entities=initial_entities or None,
|
parse_mode="HTML",
|
||||||
reply_to_message_id=user_msg_id,
|
reply_to_message_id=user_msg_id,
|
||||||
disable_notification=cfg.progress_silent,
|
disable_notification=cfg.progress_silent,
|
||||||
)
|
)
|
||||||
progress_id = int(progress_msg["message_id"])
|
progress_id = int(progress_msg["message_id"])
|
||||||
|
last_edit_at = time.monotonic()
|
||||||
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
|
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("[handle] failed to send progress message chat_id=%s: %s", chat_id, e)
|
logger.info(
|
||||||
|
"[handle] failed to send progress message chat_id=%s: %s", chat_id, e
|
||||||
if progress_id is not None:
|
|
||||||
progress = ProgressEditor(
|
|
||||||
cfg.bot,
|
|
||||||
chat_id,
|
|
||||||
progress_id,
|
|
||||||
cfg.progress_edit_every_s,
|
|
||||||
initial_text=initial_rendered,
|
|
||||||
initial_entities=initial_entities or None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_event(evt: dict[str, Any]) -> None:
|
async def on_event(evt: dict[str, Any]) -> None:
|
||||||
if progress_renderer.note_event(evt) and progress is not None:
|
nonlocal last_edit_at, edit_task
|
||||||
elapsed = time.monotonic() - started_at
|
if progress_id is None:
|
||||||
progress.set_markdown(progress_renderer.render_progress(elapsed))
|
|
||||||
|
|
||||||
def _stop_background() -> None:
|
|
||||||
if progress is not None:
|
|
||||||
progress.stop()
|
|
||||||
|
|
||||||
try:
|
|
||||||
session_id, answer, saw_agent_message = cfg.runner.run_serialized(
|
|
||||||
text,
|
|
||||||
resume_session,
|
|
||||||
on_event=on_event,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
_stop_background()
|
|
||||||
err = _clamp_tg_text(f"Error:\n{e}")
|
|
||||||
if progress_id is not None and len(err) <= TELEGRAM_TEXT_LIMIT:
|
|
||||||
cfg.bot.edit_message_text(chat_id=chat_id, message_id=progress_id, text=err)
|
|
||||||
return
|
return
|
||||||
_send_markdown(cfg.bot, chat_id=chat_id, text=err, reply_to_message_id=user_msg_id)
|
if not progress_renderer.note_event(evt):
|
||||||
return
|
return
|
||||||
|
now = time.monotonic()
|
||||||
|
if (now - last_edit_at) < cfg.progress_edit_every_s:
|
||||||
|
return
|
||||||
|
if edit_task is not None and not edit_task.done():
|
||||||
|
return
|
||||||
|
last_edit_at = now
|
||||||
|
elapsed = now - started_at
|
||||||
|
edit_task = asyncio.create_task(
|
||||||
|
_edit_progress(progress_renderer.render_progress(elapsed))
|
||||||
|
)
|
||||||
|
|
||||||
_stop_background()
|
async with semaphore:
|
||||||
|
try:
|
||||||
|
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
|
||||||
|
text, resume_session, on_event=on_event
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if edit_task is not None:
|
||||||
|
await asyncio.gather(edit_task, return_exceptions=True)
|
||||||
|
|
||||||
|
err = _clamp_tg_text(f"Error:\n{e}")
|
||||||
|
if progress_id is not None and len(err) <= TELEGRAM_TEXT_LIMIT:
|
||||||
|
try:
|
||||||
|
await cfg.bot.edit_message_text(
|
||||||
|
chat_id=chat_id, message_id=progress_id, text=err
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await _send_markdown(
|
||||||
|
cfg.bot,
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=err,
|
||||||
|
reply_to_message_id=user_msg_id,
|
||||||
|
disable_notification=cfg.progress_silent,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if edit_task is not None:
|
||||||
|
await asyncio.gather(edit_task, return_exceptions=True)
|
||||||
|
|
||||||
answer = answer or "(No agent_message captured from JSON stream.)"
|
answer = answer or "(No agent_message captured from JSON stream.)"
|
||||||
elapsed = time.monotonic() - started_at
|
elapsed = time.monotonic() - started_at
|
||||||
status = "done" if saw_agent_message else "error"
|
status = "done" if saw_agent_message else "error"
|
||||||
final_md = progress_renderer.render_final(elapsed, answer, status=status) + f"\n\nresume: `{session_id}`"
|
final_md = (
|
||||||
final_text, final_entities = render_markdown(final_md)
|
progress_renderer.render_final(elapsed, answer, status=status)
|
||||||
can_edit_final = progress_id is not None and len(final_text) <= TELEGRAM_TEXT_LIMIT
|
+ f"\n\nresume: `{session_id}`"
|
||||||
|
)
|
||||||
|
final_rendered = render_to_html(final_md)
|
||||||
|
can_edit_final = progress_id is not None and len(final_rendered) <= TELEGRAM_TEXT_LIMIT
|
||||||
|
|
||||||
if cfg.final_notify or not can_edit_final:
|
if cfg.final_notify or not can_edit_final:
|
||||||
_send_markdown(cfg.bot, chat_id=chat_id, text=final_md, reply_to_message_id=user_msg_id)
|
await _send_markdown(
|
||||||
|
cfg.bot,
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=final_md,
|
||||||
|
reply_to_message_id=user_msg_id,
|
||||||
|
disable_notification=cfg.progress_silent,
|
||||||
|
)
|
||||||
if progress_id is not None:
|
if progress_id is not None:
|
||||||
cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
|
try:
|
||||||
|
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
cfg.bot.edit_message_text(
|
await cfg.bot.edit_message_text(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
message_id=progress_id,
|
message_id=progress_id,
|
||||||
text=final_text,
|
text=final_rendered,
|
||||||
entities=final_entities or None,
|
parse_mode="HTML",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _run_main_loop(cfg: BridgeConfig) -> None:
|
async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||||
|
semaphore = asyncio.Semaphore(cfg.max_concurrency)
|
||||||
|
|
||||||
|
tasks: set[asyncio.Task[None]] = set()
|
||||||
|
|
||||||
|
def _task_done(task: asyncio.Task[None]) -> None:
|
||||||
|
tasks.discard(task)
|
||||||
|
try:
|
||||||
|
task.result()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[handle] task failed")
|
||||||
|
|
||||||
offset: int | None = None
|
offset: int | None = None
|
||||||
offset = _drain_backlog(cfg, offset)
|
offset = await _drain_backlog(cfg, offset)
|
||||||
_send_startup(cfg)
|
await _send_startup(cfg)
|
||||||
|
|
||||||
while True:
|
try:
|
||||||
updates = cfg.bot.get_updates(offset=offset, timeout_s=50, allowed_updates=["message"])
|
while True:
|
||||||
for upd in updates:
|
try:
|
||||||
offset = upd["update_id"] + 1
|
updates = await cfg.bot.get_updates(
|
||||||
msg = upd.get("message") or {}
|
offset=offset, timeout_s=50, allowed_updates=["message"]
|
||||||
msg_chat_id = msg.get("chat", {}).get("id")
|
)
|
||||||
if "text" not in msg:
|
except Exception as e:
|
||||||
continue
|
logger.info("[loop] getUpdates failed: %s", e)
|
||||||
if int(msg_chat_id) != cfg.chat_id:
|
await asyncio.sleep(2)
|
||||||
continue
|
|
||||||
if msg.get("from", {}).get("is_bot"):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
text = msg["text"]
|
for upd in updates:
|
||||||
user_msg_id = msg["message_id"]
|
offset = upd["update_id"] + 1
|
||||||
resume_session = extract_session_id(text)
|
msg = upd.get("message") or {}
|
||||||
r = msg.get("reply_to_message") or {}
|
msg_chat_id = msg.get("chat", {}).get("id")
|
||||||
resume_session = resume_session or extract_session_id(r.get("text"))
|
if "text" not in msg:
|
||||||
|
continue
|
||||||
|
if int(msg_chat_id) != cfg.chat_id:
|
||||||
|
continue
|
||||||
|
if msg.get("from", {}).get("is_bot"):
|
||||||
|
continue
|
||||||
|
|
||||||
cfg.pool.submit(
|
text = msg["text"]
|
||||||
_handle_message,
|
user_msg_id = msg["message_id"]
|
||||||
cfg,
|
resume_session = extract_session_id(text)
|
||||||
chat_id=msg_chat_id,
|
r = msg.get("reply_to_message") or {}
|
||||||
user_msg_id=user_msg_id,
|
resume_session = resume_session or extract_session_id(r.get("text"))
|
||||||
text=text,
|
|
||||||
resume_session=resume_session,
|
task = asyncio.create_task(
|
||||||
)
|
_handle_message(
|
||||||
|
cfg,
|
||||||
|
semaphore=semaphore,
|
||||||
|
chat_id=msg_chat_id,
|
||||||
|
user_msg_id=user_msg_id,
|
||||||
|
text=text,
|
||||||
|
resume_session=resume_session,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tasks.add(task)
|
||||||
|
task.add_done_callback(_task_done)
|
||||||
|
finally:
|
||||||
|
for task in tasks:
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
await cfg.bot.close()
|
||||||
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
@@ -597,7 +611,7 @@ def run(
|
|||||||
cd=cd,
|
cd=cd,
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
_run_main_loop(cfg)
|
asyncio.run(_run_main_loop(cfg))
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
|||||||
@@ -1,23 +1,65 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from html import escape
|
||||||
|
|
||||||
from markdown_it import MarkdownIt
|
from markdown_it import MarkdownIt
|
||||||
from sulguk import transform_html
|
|
||||||
|
_md = MarkdownIt("commonmark", {"html": False, "breaks": True})
|
||||||
|
|
||||||
|
_CODE_CLASS_RE = re.compile(r'<code class="[^"]+">')
|
||||||
|
_IMG_ALT_RE = re.compile(r'<img[^>]*alt="([^"]*)"[^>]*/?>')
|
||||||
|
_IMG_RE = re.compile(r"<img[^>]*>")
|
||||||
|
_OL_OPEN_RE = re.compile(r'<ol(?: start="\d+")?>\s*')
|
||||||
|
_TAG_RE = re.compile(r"<[^>]+>")
|
||||||
|
|
||||||
|
|
||||||
def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]:
|
def strip_tags(html: str) -> str:
|
||||||
html = MarkdownIt("commonmark", {"html": False}).render(md or "")
|
return _TAG_RE.sub("", html)
|
||||||
rendered = transform_html(html)
|
|
||||||
|
|
||||||
text = re.sub("(?m)^(\\s*)\u2022", r"\1-", rendered.text)
|
|
||||||
|
|
||||||
# FIX: Telegram requires MessageEntity.language (if present) to be a String.
|
def render_to_html(text: str) -> str:
|
||||||
entities: list[dict[str, Any]] = []
|
"""
|
||||||
for e in rendered.entities:
|
Render Markdown to Telegram-compatible HTML.
|
||||||
d = dict(e)
|
|
||||||
if "language" in d and not isinstance(d["language"], str):
|
Telegram supports only a subset of HTML tags, so we post-process the
|
||||||
d.pop("language", None)
|
MarkdownIt output to flatten unsupported block tags (p/ul/li/etc) into
|
||||||
entities.append(d)
|
plain text with newlines and simple bullets.
|
||||||
return text, entities
|
"""
|
||||||
|
html = _md.render(text or "")
|
||||||
|
|
||||||
|
# Paragraphs and line breaks.
|
||||||
|
html = html.replace("<p>", "")
|
||||||
|
html = html.replace("<br />\n", "\n").replace("<br>\n", "\n")
|
||||||
|
html = html.replace("<br />", "\n").replace("<br>", "\n")
|
||||||
|
html = html.replace("</p>\n", "\n\n").replace("</p>", "\n\n")
|
||||||
|
|
||||||
|
# Lists -> "- " lines.
|
||||||
|
html = html.replace("<ul>\n", "").replace("</ul>\n", "")
|
||||||
|
html = _OL_OPEN_RE.sub("", html).replace("</ol>\n", "")
|
||||||
|
html = html.replace("<li>", "- ")
|
||||||
|
html = html.replace("</li>\n", "\n").replace("</li>", "\n")
|
||||||
|
|
||||||
|
# Headings -> bold line.
|
||||||
|
for level in range(1, 7):
|
||||||
|
html = html.replace(f"<h{level}>", "<b>")
|
||||||
|
html = html.replace(f"</h{level}>\n", "</b>\n\n").replace(
|
||||||
|
f"</h{level}>", "</b>\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Code fences may include language class; Telegram doesn't need it.
|
||||||
|
html = _CODE_CLASS_RE.sub("<code>", html)
|
||||||
|
|
||||||
|
# Images are not supported: keep alt text if present.
|
||||||
|
html = _IMG_ALT_RE.sub(lambda m: escape(m.group(1) or ""), html)
|
||||||
|
html = _IMG_RE.sub("", html)
|
||||||
|
|
||||||
|
# <hr> isn't supported; render a separator line.
|
||||||
|
html = html.replace("<hr />", "\n----\n\n").replace("<hr>", "\n----\n\n")
|
||||||
|
|
||||||
|
# Flatten blockquotes.
|
||||||
|
html = html.replace("<blockquote>\n", "")
|
||||||
|
html = html.replace("</blockquote>\n", "\n\n").replace("</blockquote>", "\n\n")
|
||||||
|
|
||||||
|
html = re.sub(r"\n{3,}", "\n\n", html)
|
||||||
|
return html.strip()
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import requests
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TelegramClient:
|
class TelegramClient:
|
||||||
@@ -12,42 +17,46 @@ class TelegramClient:
|
|||||||
if not token:
|
if not token:
|
||||||
raise ValueError("Telegram token is empty")
|
raise ValueError("Telegram token is empty")
|
||||||
self._base = f"https://api.telegram.org/bot{token}"
|
self._base = f"https://api.telegram.org/bot{token}"
|
||||||
self._timeout_s = timeout_s
|
self._client = httpx.AsyncClient(timeout=timeout_s)
|
||||||
|
|
||||||
def _call(self, method: str, params: dict) -> object:
|
async def close(self) -> None:
|
||||||
resp = requests.post(
|
await self._client.aclose()
|
||||||
f"{self._base}/{method}",
|
|
||||||
json=params,
|
|
||||||
timeout=self._timeout_s,
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
payload = resp.json()
|
|
||||||
if not payload.get("ok"):
|
|
||||||
raise RuntimeError(f"Telegram API error: {payload}")
|
|
||||||
return payload["result"]
|
|
||||||
|
|
||||||
def get_updates(
|
async def _post(self, method: str, json_data: dict[str, Any]) -> Any:
|
||||||
|
try:
|
||||||
|
resp = await self._client.post(f"{self._base}/{method}", json=json_data)
|
||||||
|
resp.raise_for_status()
|
||||||
|
payload = resp.json()
|
||||||
|
if not payload.get("ok"):
|
||||||
|
raise RuntimeError(f"Telegram API error: {payload}")
|
||||||
|
return payload["result"]
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.error("Telegram network error: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_updates(
|
||||||
self,
|
self,
|
||||||
offset: int | None,
|
offset: int | None,
|
||||||
timeout_s: int = 50,
|
timeout_s: int = 50,
|
||||||
allowed_updates: list[str] | None = None,
|
allowed_updates: list[str] | None = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
params: dict = {"timeout": timeout_s}
|
params: dict[str, Any] = {"timeout": timeout_s}
|
||||||
if offset is not None:
|
if offset is not None:
|
||||||
params["offset"] = offset
|
params["offset"] = offset
|
||||||
if allowed_updates is not None:
|
if allowed_updates is not None:
|
||||||
params["allowed_updates"] = allowed_updates
|
params["allowed_updates"] = allowed_updates
|
||||||
return self._call("getUpdates", params) # type: ignore[return-value]
|
return await self._post("getUpdates", params) # type: ignore[return-value]
|
||||||
|
|
||||||
def send_message(
|
async def send_message(
|
||||||
self,
|
self,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
text: str,
|
text: str,
|
||||||
reply_to_message_id: int | None = None,
|
reply_to_message_id: int | None = None,
|
||||||
disable_notification: bool | None = False,
|
disable_notification: bool | None = False,
|
||||||
entities: list[dict] | None = None,
|
entities: list[dict] | None = None,
|
||||||
|
parse_mode: str | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
params: dict = {
|
params: dict[str, Any] = {
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
"text": text,
|
"text": text,
|
||||||
}
|
}
|
||||||
@@ -57,26 +66,31 @@ class TelegramClient:
|
|||||||
params["reply_to_message_id"] = reply_to_message_id
|
params["reply_to_message_id"] = reply_to_message_id
|
||||||
if entities is not None:
|
if entities is not None:
|
||||||
params["entities"] = entities
|
params["entities"] = entities
|
||||||
return self._call("sendMessage", params) # type: ignore[return-value]
|
if parse_mode is not None:
|
||||||
|
params["parse_mode"] = parse_mode
|
||||||
|
return await self._post("sendMessage", params) # type: ignore[return-value]
|
||||||
|
|
||||||
def edit_message_text(
|
async def edit_message_text(
|
||||||
self,
|
self,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
message_id: int,
|
message_id: int,
|
||||||
text: str,
|
text: str,
|
||||||
entities: list[dict] | None = None,
|
entities: list[dict] | None = None,
|
||||||
|
parse_mode: str | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
params: dict = {
|
params: dict[str, Any] = {
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
"text": text,
|
"text": text,
|
||||||
}
|
}
|
||||||
if entities is not None:
|
if entities is not None:
|
||||||
params["entities"] = entities
|
params["entities"] = entities
|
||||||
return self._call("editMessageText", params) # type: ignore[return-value]
|
if parse_mode is not None:
|
||||||
|
params["parse_mode"] = parse_mode
|
||||||
|
return await self._post("editMessageText", params) # type: ignore[return-value]
|
||||||
|
|
||||||
def delete_message(self, chat_id: int, message_id: int) -> bool:
|
async def delete_message(self, chat_id: int, message_id: int) -> bool:
|
||||||
res = self._call(
|
res = await self._post(
|
||||||
"deleteMessage",
|
"deleteMessage",
|
||||||
{
|
{
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
@@ -84,4 +98,3 @@ class TelegramClient:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
return bool(res)
|
return bool(res)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user