chore: move to top level
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Takopi — Telegram Codex bridge package."""
|
||||
@@ -0,0 +1,9 @@
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
|
||||
from .constants import TELEGRAM_CONFIG_PATH
|
||||
|
||||
|
||||
def load_telegram_config(path=None):
|
||||
cfg_path = Path(path) if path else TELEGRAM_CONFIG_PATH
|
||||
return tomllib.loads(cfg_path.read_text(encoding="utf-8"))
|
||||
@@ -0,0 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
TELEGRAM_HARD_LIMIT = 4096
|
||||
TELEGRAM_CONFIG_PATH = Path.home() / ".codex" / "telegram.toml"
|
||||
@@ -0,0 +1,708 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import typer
|
||||
|
||||
from .config import load_telegram_config
|
||||
from .exec_render import ExecProgressRenderer, render_event_cli
|
||||
from .logging import setup_logging
|
||||
from .rendering import render_markdown
|
||||
from .telegram_client import TelegramClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
UUID_PATTERN_TEXT = r"\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
|
||||
UUID_PATTERN = re.compile(UUID_PATTERN_TEXT, re.IGNORECASE)
|
||||
RESUME_LINE = re.compile(
|
||||
rf"^\s*resume\s*:\s*`?(?P<id>{UUID_PATTERN_TEXT})`?\s*$",
|
||||
re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
|
||||
|
||||
def extract_session_id(text: str | None) -> str | None:
|
||||
if not text:
|
||||
return None
|
||||
found = None
|
||||
for match in RESUME_LINE.finditer(text):
|
||||
found = match.group("id")
|
||||
if found:
|
||||
return found
|
||||
return None
|
||||
|
||||
|
||||
def resolve_resume_session(
|
||||
text: str | None, reply_text: str | None
|
||||
) -> str | None:
|
||||
return extract_session_id(text) or extract_session_id(reply_text)
|
||||
|
||||
|
||||
async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -> None:
|
||||
if stderr is None:
|
||||
return
|
||||
try:
|
||||
while True:
|
||||
line = await stderr.readline()
|
||||
if not line:
|
||||
return
|
||||
decoded = line.decode(errors="replace")
|
||||
logger.info("[codex][stderr] %s", decoded.rstrip())
|
||||
tail.append(decoded)
|
||||
except Exception as e:
|
||||
logger.debug("[codex][stderr] drain error: %s", e)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def manage_subprocess(*args, **kwargs):
|
||||
proc = await asyncio.create_subprocess_exec(*args, **kwargs)
|
||||
try:
|
||||
yield proc
|
||||
finally:
|
||||
if proc.returncode is None:
|
||||
proc.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(proc.wait(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
|
||||
|
||||
TELEGRAM_MARKDOWN_LIMIT = 3500
|
||||
PROGRESS_EDIT_EVERY_S = 2.0
|
||||
|
||||
|
||||
def _clamp_tg_text(text: str, limit: int = TELEGRAM_MARKDOWN_LIMIT) -> str:
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return text[: limit - 20] + "\n...(truncated)"
|
||||
|
||||
|
||||
def truncate_for_telegram(text: str, limit: int) -> str:
|
||||
"""
|
||||
Truncate text to fit Telegram limits while preserving the trailing `resume: ...`
|
||||
line (if present), otherwise preserving the last non-empty line.
|
||||
"""
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
|
||||
lines = text.splitlines()
|
||||
|
||||
tail_lines: list[str] | None = None
|
||||
is_resume_tail = False
|
||||
for i in range(len(lines) - 1, -1, -1):
|
||||
line = lines[i]
|
||||
if "resume" in line and UUID_PATTERN.search(line):
|
||||
tail_lines = lines[i:]
|
||||
is_resume_tail = True
|
||||
break
|
||||
|
||||
if tail_lines is None:
|
||||
for i in range(len(lines) - 1, -1, -1):
|
||||
if lines[i].strip():
|
||||
tail_lines = [lines[i]]
|
||||
break
|
||||
|
||||
tail = "\n".join(tail_lines or []).strip("\n")
|
||||
sep = "\n…\n"
|
||||
|
||||
max_tail = limit if is_resume_tail else (limit // 4)
|
||||
tail = tail[-max_tail:] if max_tail > 0 else ""
|
||||
|
||||
head_budget = limit - len(sep) - len(tail)
|
||||
if head_budget <= 0:
|
||||
return tail[-limit:] if tail else text[:limit]
|
||||
|
||||
head = text[:head_budget].rstrip()
|
||||
return (head + sep + tail)[:limit]
|
||||
|
||||
|
||||
def prepare_telegram(
|
||||
md: str, *, limit: int
|
||||
) -> tuple[str, list[dict[str, Any]] | None]:
|
||||
rendered, entities = render_markdown(md)
|
||||
if len(rendered) > limit:
|
||||
rendered = truncate_for_telegram(rendered, limit)
|
||||
return rendered, None
|
||||
return rendered, entities or None
|
||||
|
||||
|
||||
async def _send_or_edit_markdown(
|
||||
bot: TelegramClient,
|
||||
*,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
edit_message_id: int | None = None,
|
||||
reply_to_message_id: int | None = None,
|
||||
disable_notification: bool = False,
|
||||
limit: int = TELEGRAM_MARKDOWN_LIMIT,
|
||||
) -> tuple[dict[str, Any], bool]:
|
||||
if edit_message_id is not None:
|
||||
rendered, entities = prepare_telegram(text, limit=limit)
|
||||
try:
|
||||
return (
|
||||
await bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=edit_message_id,
|
||||
text=rendered,
|
||||
entities=entities,
|
||||
),
|
||||
True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"[tg] edit failed chat_id=%s message_id=%s: %s",
|
||||
chat_id,
|
||||
edit_message_id,
|
||||
e,
|
||||
)
|
||||
|
||||
rendered, entities = prepare_telegram(text, limit=limit)
|
||||
return (
|
||||
await bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=rendered,
|
||||
entities=entities,
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
disable_notification=disable_notification,
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
EventCallback = Callable[[dict[str, Any]], Awaitable[None] | None]
|
||||
|
||||
|
||||
class CodexExecRunner:
|
||||
"""
|
||||
Runs Codex in non-interactive mode:
|
||||
- new: codex exec --json ... -
|
||||
- resume: codex exec --json ... resume <SESSION_ID> -
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
codex_cmd: str,
|
||||
workspace: str | None,
|
||||
extra_args: list[str],
|
||||
) -> None:
|
||||
self.codex_cmd = codex_cmd
|
||||
self.workspace = workspace
|
||||
self.extra_args = extra_args
|
||||
|
||||
# Per-session locks to prevent concurrent resumes to the same session_id.
|
||||
self._session_locks: WeakValueDictionary[str, asyncio.Lock] = (
|
||||
WeakValueDictionary()
|
||||
)
|
||||
|
||||
async def _lock_for(self, session_id: str) -> asyncio.Lock:
|
||||
lock = self._session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self._session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
async def run(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str | None,
|
||||
on_event: EventCallback | None = None,
|
||||
) -> tuple[str, str, bool]:
|
||||
logger.info(
|
||||
"[codex] start run session_id=%r workspace=%r", session_id, self.workspace
|
||||
)
|
||||
logger.debug("[codex] prompt: %s", prompt)
|
||||
args = [self.codex_cmd, "exec", "--json"]
|
||||
args.extend(self.extra_args)
|
||||
if self.workspace:
|
||||
args.extend(["--cd", self.workspace])
|
||||
|
||||
# Always pipe prompt via stdin ("-") to avoid quoting issues.
|
||||
if session_id:
|
||||
args.extend(["resume", session_id, "-"])
|
||||
else:
|
||||
args.append("-")
|
||||
|
||||
async with manage_subprocess(
|
||||
*args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
) as proc:
|
||||
assert proc.stdin and proc.stdout and proc.stderr
|
||||
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
||||
|
||||
stderr_tail: deque[str] = deque(maxlen=200)
|
||||
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
|
||||
|
||||
found_session: str | None = session_id
|
||||
last_agent_text: str | None = None
|
||||
saw_agent_message = False
|
||||
cli_last_item: int | None = None
|
||||
|
||||
cancelled = False
|
||||
rc: int | None = None
|
||||
|
||||
try:
|
||||
proc.stdin.write(prompt.encode())
|
||||
await proc.stdin.drain()
|
||||
proc.stdin.close()
|
||||
|
||||
async for raw_line in proc.stdout:
|
||||
raw = raw_line.decode(errors="replace")
|
||||
logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
evt = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("[codex][jsonl] invalid line: %r", line)
|
||||
continue
|
||||
|
||||
cli_last_item, out_lines = render_event_cli(evt, cli_last_item)
|
||||
for out in out_lines:
|
||||
logger.info("[codex] %s", out)
|
||||
|
||||
if on_event is not None:
|
||||
try:
|
||||
res = on_event(evt)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
except Exception as e:
|
||||
logger.info("[codex][on_event] callback error: %s", e)
|
||||
|
||||
if evt.get("type") == "thread.started":
|
||||
found_session = evt.get("thread_id") or found_session
|
||||
|
||||
if evt.get("type") == "item.completed":
|
||||
item = evt.get("item") or {}
|
||||
if item.get("type") == "agent_message" and isinstance(
|
||||
item.get("text"), str
|
||||
):
|
||||
last_agent_text = item["text"]
|
||||
saw_agent_message = True
|
||||
except asyncio.CancelledError:
|
||||
cancelled = True
|
||||
finally:
|
||||
if cancelled:
|
||||
task = asyncio.current_task()
|
||||
if task is not None:
|
||||
while task.cancelling():
|
||||
task.uncancel()
|
||||
if not cancelled:
|
||||
rc = await proc.wait()
|
||||
await asyncio.gather(stderr_task, return_exceptions=True)
|
||||
|
||||
if cancelled:
|
||||
raise asyncio.CancelledError
|
||||
|
||||
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
||||
if rc != 0:
|
||||
tail = "".join(stderr_tail)
|
||||
raise RuntimeError(f"codex exec failed (rc={rc}). stderr tail:\n{tail}")
|
||||
|
||||
if not found_session:
|
||||
raise RuntimeError(
|
||||
"codex exec finished but no session_id/thread_id was captured"
|
||||
)
|
||||
|
||||
logger.info("[codex] done run session_id=%r", found_session)
|
||||
return (
|
||||
found_session,
|
||||
(last_agent_text or "(No agent_message captured from JSON stream.)"),
|
||||
saw_agent_message,
|
||||
)
|
||||
|
||||
async def run_serialized(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str | None,
|
||||
on_event: EventCallback | None = None,
|
||||
) -> tuple[str, str, bool]:
|
||||
"""
|
||||
If resuming, serialize per-session.
|
||||
"""
|
||||
if not session_id:
|
||||
return await self.run(prompt, session_id=None, on_event=on_event)
|
||||
lock = await self._lock_for(session_id)
|
||||
async with lock:
|
||||
return await self.run(prompt, session_id=session_id, on_event=on_event)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BridgeConfig:
|
||||
bot: TelegramClient
|
||||
runner: CodexExecRunner
|
||||
chat_id: int
|
||||
final_notify: bool
|
||||
startup_msg: str
|
||||
max_concurrency: int
|
||||
|
||||
|
||||
def _parse_bridge_config(
|
||||
*,
|
||||
final_notify: bool,
|
||||
cd: str | None,
|
||||
model: str | None,
|
||||
) -> BridgeConfig:
|
||||
config = load_telegram_config()
|
||||
token = config["bot_token"]
|
||||
chat_id = int(config["chat_id"])
|
||||
|
||||
codex_cmd = shutil.which("codex")
|
||||
if not codex_cmd:
|
||||
raise RuntimeError("codex not found on PATH")
|
||||
|
||||
startup_pwd = os.getcwd()
|
||||
workspace = None
|
||||
if cd is not None:
|
||||
expanded_cd = os.path.expanduser(cd)
|
||||
if not os.path.isdir(expanded_cd):
|
||||
raise RuntimeError(f"--cd must be an existing directory: {expanded_cd}")
|
||||
workspace = expanded_cd
|
||||
startup_pwd = expanded_cd
|
||||
|
||||
startup_msg = f"codex exec bridge has started\npwd: {startup_pwd}"
|
||||
raw_exec_args = config.get("codex_exec_args", "")
|
||||
if isinstance(raw_exec_args, list):
|
||||
extra_args = [str(v) for v in raw_exec_args]
|
||||
else:
|
||||
extra_args = shlex.split(str(raw_exec_args)) # e.g. "--full-auto --search"
|
||||
|
||||
if model:
|
||||
extra_args.extend(["--model", model])
|
||||
|
||||
def _has_notify_override(args: list[str]) -> bool:
|
||||
for i, arg in enumerate(args):
|
||||
if arg in ("-c", "--config"):
|
||||
key = args[i + 1].split("=", 1)[0].strip()
|
||||
if key == "notify" or key.endswith(".notify"):
|
||||
return True
|
||||
elif arg.startswith(("--config=", "-c=")):
|
||||
key = arg.split("=", 1)[1].split("=", 1)[0].strip()
|
||||
if key == "notify" or key.endswith(".notify"):
|
||||
return True
|
||||
return False
|
||||
|
||||
if not _has_notify_override(extra_args):
|
||||
extra_args.extend(["-c", "notify=[]"])
|
||||
|
||||
bot = TelegramClient(token)
|
||||
runner = CodexExecRunner(codex_cmd=codex_cmd, workspace=workspace, extra_args=extra_args)
|
||||
|
||||
return BridgeConfig(
|
||||
bot=bot,
|
||||
runner=runner,
|
||||
chat_id=chat_id,
|
||||
final_notify=final_notify,
|
||||
startup_msg=startup_msg,
|
||||
max_concurrency=16,
|
||||
)
|
||||
|
||||
|
||||
async def _send_startup(cfg: BridgeConfig) -> None:
|
||||
try:
|
||||
logger.debug("[startup] message: %s", cfg.startup_msg)
|
||||
await cfg.bot.send_message(chat_id=cfg.chat_id, text=cfg.startup_msg)
|
||||
logger.info("[startup] sent startup message to chat_id=%s", cfg.chat_id)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"[startup] failed to send startup message to chat_id=%s: %s", cfg.chat_id, e
|
||||
)
|
||||
|
||||
|
||||
async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
||||
try:
|
||||
updates = await cfg.bot.get_updates(
|
||||
offset=offset, timeout_s=0, allowed_updates=["message"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info("[startup] backlog drain failed: %s", e)
|
||||
return offset
|
||||
logger.debug("[startup] backlog updates: %s", updates)
|
||||
if updates:
|
||||
offset = updates[-1]["update_id"] + 1
|
||||
logger.info("[startup] drained %s pending update(s)", len(updates))
|
||||
return offset
|
||||
|
||||
|
||||
async def _handle_message(
|
||||
cfg: BridgeConfig,
|
||||
*,
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
resume_session: str | None,
|
||||
clock: Callable[[], float] = time.monotonic,
|
||||
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
"[handle] incoming chat_id=%s message_id=%s resume=%r text=%s",
|
||||
chat_id,
|
||||
user_msg_id,
|
||||
resume_session,
|
||||
text,
|
||||
)
|
||||
started_at = clock()
|
||||
progress_renderer = ExecProgressRenderer(max_actions=5)
|
||||
|
||||
progress_id: int | None = None
|
||||
|
||||
last_edit_at = 0.0
|
||||
edit_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def _edit_progress(md: str) -> None:
|
||||
if progress_id is None:
|
||||
return
|
||||
rendered, entities = prepare_telegram(md, limit=TELEGRAM_MARKDOWN_LIMIT)
|
||||
logger.debug(
|
||||
"[progress] edit message_id=%s md=%s rendered=%s entities=%s",
|
||||
progress_id,
|
||||
md,
|
||||
rendered,
|
||||
entities,
|
||||
)
|
||||
try:
|
||||
await cfg.bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=progress_id,
|
||||
text=rendered,
|
||||
entities=entities,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"[progress] edit failed chat_id=%s message_id=%s: %s",
|
||||
chat_id,
|
||||
progress_id,
|
||||
e,
|
||||
)
|
||||
|
||||
try:
|
||||
initial_md = progress_renderer.render_progress(0.0)
|
||||
initial_rendered, initial_entities = prepare_telegram(
|
||||
initial_md, limit=TELEGRAM_MARKDOWN_LIMIT
|
||||
)
|
||||
logger.debug(
|
||||
"[progress] send reply_to=%s md=%s rendered=%s entities=%s",
|
||||
user_msg_id,
|
||||
initial_md,
|
||||
initial_rendered,
|
||||
initial_entities,
|
||||
)
|
||||
progress_msg = await cfg.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=initial_rendered,
|
||||
entities=initial_entities,
|
||||
reply_to_message_id=user_msg_id,
|
||||
disable_notification=True,
|
||||
)
|
||||
progress_id = int(progress_msg["message_id"])
|
||||
last_edit_at = clock()
|
||||
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"[handle] failed to send progress message chat_id=%s: %s", chat_id, e
|
||||
)
|
||||
|
||||
async def on_event(evt: dict[str, Any]) -> None:
|
||||
nonlocal last_edit_at, edit_task
|
||||
if progress_id is None:
|
||||
return
|
||||
if not progress_renderer.note_event(evt):
|
||||
return
|
||||
now = clock()
|
||||
if (now - last_edit_at) < progress_edit_every:
|
||||
return
|
||||
if edit_task is not None and not edit_task.done():
|
||||
return
|
||||
last_edit_at = now
|
||||
elapsed = now - started_at
|
||||
edit_task = asyncio.create_task(
|
||||
_edit_progress(progress_renderer.render_progress(elapsed))
|
||||
)
|
||||
|
||||
try:
|
||||
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
|
||||
text, resume_session, on_event=on_event
|
||||
)
|
||||
except Exception as e:
|
||||
if edit_task is not None:
|
||||
await asyncio.gather(edit_task, return_exceptions=True)
|
||||
|
||||
err = _clamp_tg_text(f"Error:\n{e}")
|
||||
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
|
||||
await _send_or_edit_markdown(
|
||||
cfg.bot,
|
||||
chat_id=chat_id,
|
||||
text=err,
|
||||
edit_message_id=progress_id,
|
||||
reply_to_message_id=user_msg_id,
|
||||
disable_notification=True,
|
||||
limit=TELEGRAM_MARKDOWN_LIMIT,
|
||||
)
|
||||
return
|
||||
|
||||
if edit_task is not None:
|
||||
await asyncio.gather(edit_task, return_exceptions=True)
|
||||
|
||||
answer = answer or "(No agent_message captured from JSON stream.)"
|
||||
elapsed = clock() - started_at
|
||||
status = "done" if saw_agent_message else "error"
|
||||
final_md = (
|
||||
progress_renderer.render_final(elapsed, answer, status=status)
|
||||
+ f"\n\nresume: `{session_id}`"
|
||||
)
|
||||
logger.debug("[final] markdown: %s", final_md)
|
||||
final_rendered, final_entities = render_markdown(final_md)
|
||||
can_edit_final = (
|
||||
progress_id is not None and len(final_rendered) <= TELEGRAM_MARKDOWN_LIMIT
|
||||
)
|
||||
edit_message_id = None if cfg.final_notify or not can_edit_final else progress_id
|
||||
|
||||
if edit_message_id is None:
|
||||
logger.debug(
|
||||
"[final] send reply_to=%s rendered=%s entities=%s",
|
||||
user_msg_id,
|
||||
final_rendered,
|
||||
final_entities,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"[final] edit message_id=%s rendered=%s entities=%s",
|
||||
edit_message_id,
|
||||
final_rendered,
|
||||
final_entities,
|
||||
)
|
||||
|
||||
_, edited = await _send_or_edit_markdown(
|
||||
cfg.bot,
|
||||
chat_id=chat_id,
|
||||
text=final_md,
|
||||
edit_message_id=edit_message_id,
|
||||
reply_to_message_id=user_msg_id,
|
||||
disable_notification=False,
|
||||
limit=TELEGRAM_MARKDOWN_LIMIT,
|
||||
)
|
||||
if progress_id is not None and (edit_message_id is None or not edited):
|
||||
try:
|
||||
logger.debug("[final] delete progress message_id=%s", progress_id)
|
||||
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def poll_updates(cfg: BridgeConfig):
|
||||
offset: int | None = None
|
||||
offset = await _drain_backlog(cfg, offset)
|
||||
await _send_startup(cfg)
|
||||
|
||||
while True:
|
||||
try:
|
||||
updates = await cfg.bot.get_updates(
|
||||
offset=offset, timeout_s=50, allowed_updates=["message"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info("[loop] getUpdates failed: %s", e)
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
logger.debug("[loop] updates: %s", updates)
|
||||
|
||||
for upd in updates:
|
||||
offset = upd["update_id"] + 1
|
||||
msg = upd["message"]
|
||||
if "text" not in msg:
|
||||
continue
|
||||
if not (msg["chat"]["id"] == msg["from"]["id"] == cfg.chat_id):
|
||||
continue
|
||||
yield msg
|
||||
|
||||
|
||||
async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
worker_count = max(1, min(cfg.max_concurrency, 16))
|
||||
queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue(
|
||||
maxsize=worker_count * 2
|
||||
)
|
||||
|
||||
async def worker() -> None:
|
||||
while True:
|
||||
chat_id, user_msg_id, text, resume_session = await queue.get()
|
||||
try:
|
||||
await _handle_message(
|
||||
cfg,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
text=text,
|
||||
resume_session=resume_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("[handle] worker failed")
|
||||
finally:
|
||||
queue.task_done()
|
||||
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for _ in range(worker_count):
|
||||
tg.create_task(worker())
|
||||
async for msg in poll_updates(cfg):
|
||||
text = msg["text"]
|
||||
user_msg_id = msg["message_id"]
|
||||
r = msg.get("reply_to_message") or {}
|
||||
resume_session = resolve_resume_session(text, r.get("text"))
|
||||
|
||||
await queue.put(
|
||||
(msg["chat"]["id"], user_msg_id, text, resume_session)
|
||||
)
|
||||
finally:
|
||||
await cfg.bot.close()
|
||||
|
||||
|
||||
def run(
|
||||
final_notify: bool = typer.Option(
|
||||
True,
|
||||
"--final-notify/--no-final-notify",
|
||||
help="Send the final response as a new message (not an edit).",
|
||||
),
|
||||
debug: bool = typer.Option(
|
||||
False,
|
||||
"--debug/--no-debug",
|
||||
help="Log codex JSONL, Telegram requests, and rendered messages.",
|
||||
),
|
||||
cd: str | None = typer.Option(
|
||||
None,
|
||||
"--cd",
|
||||
help="Pass through to `codex --cd`.",
|
||||
),
|
||||
model: str | None = typer.Option(
|
||||
None,
|
||||
"--model",
|
||||
help="Codex model to pass to `codex exec`.",
|
||||
),
|
||||
) -> None:
|
||||
setup_logging(debug=debug)
|
||||
cfg = _parse_bridge_config(
|
||||
final_notify=final_notify,
|
||||
cd=cd,
|
||||
model=model,
|
||||
)
|
||||
asyncio.run(_run_main_loop(cfg))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
typer.run(run)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,208 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import textwrap
|
||||
from collections import deque
|
||||
from textwrap import indent
|
||||
from typing import Any
|
||||
|
||||
STATUS_RUNNING = "▸"
|
||||
STATUS_DONE = "✓"
|
||||
STATUS_FAIL = "✗"
|
||||
HEADER_SEP = " · "
|
||||
HARD_BREAK = " \n"
|
||||
|
||||
MAX_PROGRESS_CMD_LEN = 300
|
||||
MAX_QUERY_LEN = 60
|
||||
MAX_PATH_LEN = 40
|
||||
|
||||
|
||||
def format_elapsed(elapsed_s: float) -> str:
|
||||
total = max(0, int(elapsed_s))
|
||||
minutes, seconds = divmod(total, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
if hours:
|
||||
return f"{hours}h {minutes:02d}m"
|
||||
if minutes:
|
||||
return f"{minutes}m {seconds:02d}s"
|
||||
return f"{seconds}s"
|
||||
|
||||
|
||||
def format_header(elapsed_s: float, item: int | None, label: str) -> str:
|
||||
elapsed = format_elapsed(elapsed_s)
|
||||
parts = [label, elapsed]
|
||||
if item is not None:
|
||||
parts.append(f"item {item}")
|
||||
return HEADER_SEP.join(parts)
|
||||
|
||||
|
||||
def is_command_log_line(line: str) -> bool:
|
||||
return (
|
||||
f"{STATUS_RUNNING} " in line
|
||||
or f"{STATUS_DONE} " in line
|
||||
or f"{STATUS_FAIL} " in line
|
||||
)
|
||||
|
||||
|
||||
def extract_numeric_id(item_id: object, fallback: int | None = None) -> int | None:
|
||||
if isinstance(item_id, int):
|
||||
return item_id
|
||||
if isinstance(item_id, str):
|
||||
match = re.search(r"(?:item_)?(\d+)", item_id)
|
||||
if match:
|
||||
return int(match.group(1))
|
||||
return fallback
|
||||
|
||||
|
||||
def _shorten(text: str, width: int) -> str:
|
||||
return textwrap.shorten(text, width=width, placeholder="…")
|
||||
|
||||
|
||||
def _shorten_path(path: str, width: int) -> str:
|
||||
# Encourage word-boundary truncation for paths (since they may have no spaces).
|
||||
return _shorten(path.replace("/", " /"), width).replace(" /", "/")
|
||||
|
||||
|
||||
def format_event(
|
||||
event: dict[str, Any],
|
||||
last_item: int | None,
|
||||
*,
|
||||
command_width: int | None = None,
|
||||
) -> tuple[int | None, list[str], str | None, str | None]:
|
||||
"""
|
||||
Returns (new_last_item, cli_lines, progress_line, progress_prefix).
|
||||
progress_prefix is only set when progress_line is set, and is used for
|
||||
replacing a preceding "running" line on completion.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
|
||||
match event["type"]:
|
||||
case "thread.started":
|
||||
return last_item, ["thread started"], None, None
|
||||
case "turn.started":
|
||||
return last_item, ["turn started"], None, None
|
||||
case "turn.completed":
|
||||
return last_item, ["turn completed"], None, None
|
||||
case "turn.failed":
|
||||
return last_item, [f"turn failed: {event['error']['message']}"], None, None
|
||||
case "error":
|
||||
return last_item, [f"stream error: {event['message']}"], None, None
|
||||
case "item.started" | "item.updated" | "item.completed" as etype:
|
||||
item = event["item"]
|
||||
item_num = extract_numeric_id(item["id"], last_item)
|
||||
last_item = item_num if item_num is not None else last_item
|
||||
prefix = f"{item_num}. "
|
||||
|
||||
match (item["type"], etype):
|
||||
case ("agent_message", "item.completed"):
|
||||
lines.append("assistant:")
|
||||
lines.extend(indent(item["text"], " ").splitlines())
|
||||
return last_item, lines, None, None
|
||||
case ("reasoning", "item.completed"):
|
||||
text = item.get("text") or ""
|
||||
first_line = text.splitlines()[0] if text else ""
|
||||
line = prefix + first_line
|
||||
return last_item, [line], line, prefix
|
||||
case ("command_execution", "item.started"):
|
||||
command = item["command"]
|
||||
if command_width is not None:
|
||||
command = _shorten(command, command_width)
|
||||
command = f"`{command}`"
|
||||
line = prefix + f"{STATUS_RUNNING} {command}"
|
||||
return last_item, [line], line, prefix
|
||||
case ("command_execution", "item.completed"):
|
||||
command = item["command"]
|
||||
if command_width is not None:
|
||||
command = _shorten(command, command_width)
|
||||
command = f"`{command}`"
|
||||
exit_code = item["exit_code"]
|
||||
if exit_code == 0:
|
||||
status = STATUS_DONE
|
||||
exit_part = ""
|
||||
else:
|
||||
status = STATUS_FAIL if exit_code is not None else STATUS_DONE
|
||||
exit_part = f" (exit {exit_code})" if exit_code is not None else ""
|
||||
line = prefix + f"{status} {command}{exit_part}"
|
||||
return last_item, [line], line, prefix
|
||||
case ("mcp_tool_call", "item.started"):
|
||||
name = ".".join(part for part in (item["server"], item["tool"]) if part) or "tool"
|
||||
line = prefix + f"{STATUS_RUNNING} tool: {name}"
|
||||
return last_item, [line], line, prefix
|
||||
case ("mcp_tool_call", "item.completed"):
|
||||
name = ".".join(part for part in (item["server"], item["tool"]) if part) or "tool"
|
||||
line = prefix + f"{STATUS_DONE} tool: {name}"
|
||||
return last_item, [line], line, prefix
|
||||
case ("web_search", "item.completed"):
|
||||
query = _shorten(item["query"], MAX_QUERY_LEN)
|
||||
line = prefix + f"{STATUS_DONE} searched: {query}"
|
||||
return last_item, [line], line, prefix
|
||||
case ("file_change", "item.completed"):
|
||||
paths = [change["path"] for change in item["changes"] if change.get("path")]
|
||||
if not paths:
|
||||
total = len(item["changes"])
|
||||
desc = "updated files" if total == 0 else f"updated {total} files"
|
||||
elif len(paths) <= 3:
|
||||
desc = "updated " + ", ".join(f"`{_shorten_path(p, MAX_PATH_LEN)}`" for p in paths)
|
||||
else:
|
||||
desc = f"updated {len(paths)} files"
|
||||
line = prefix + f"{STATUS_DONE} {desc}"
|
||||
return last_item, [line], line, prefix
|
||||
case ("error", "item.completed"):
|
||||
warning = _shorten(item["message"], 120)
|
||||
line = prefix + f"{STATUS_DONE} warning: {warning}"
|
||||
return last_item, [line], line, prefix
|
||||
case _:
|
||||
return last_item, [], None, None
|
||||
case _:
|
||||
return last_item, [], None, None
|
||||
|
||||
|
||||
def render_event_cli(
|
||||
event: dict[str, Any], last_item: int | None = None
|
||||
) -> tuple[int | None, list[str]]:
|
||||
last_item, cli_lines, _, _ = format_event(event, last_item, command_width=None)
|
||||
return last_item, cli_lines
|
||||
|
||||
|
||||
class ExecProgressRenderer:
|
||||
def __init__(
|
||||
self,
|
||||
max_actions: int = 5,
|
||||
command_width: int | None = MAX_PROGRESS_CMD_LEN,
|
||||
) -> None:
|
||||
self.max_actions = max_actions
|
||||
self.command_width = command_width
|
||||
self.recent_actions: deque[str] = deque(maxlen=max_actions)
|
||||
self.last_item: int | None = None
|
||||
|
||||
def note_event(self, event: dict[str, Any]) -> bool:
|
||||
if event["type"] == "thread.started":
|
||||
return True
|
||||
|
||||
self.last_item, _, progress_line, progress_prefix = format_event(
|
||||
event, self.last_item, command_width=self.command_width
|
||||
)
|
||||
if progress_line is None:
|
||||
return False
|
||||
|
||||
# Replace the preceding "running" line for the same item on completion.
|
||||
if event["type"] == "item.completed" and progress_prefix and self.recent_actions:
|
||||
last = self.recent_actions[-1]
|
||||
if last.startswith(progress_prefix + f"{STATUS_RUNNING} "):
|
||||
self.recent_actions.pop()
|
||||
|
||||
self.recent_actions.append(progress_line)
|
||||
return True
|
||||
|
||||
def render_progress(self, elapsed_s: float) -> str:
|
||||
header = format_header(elapsed_s, self.last_item, label="working")
|
||||
return self._assemble(header, list(self.recent_actions))
|
||||
|
||||
def render_final(self, elapsed_s: float, answer: str, status: str = "done") -> str:
|
||||
header = format_header(elapsed_s, self.last_item, label=status)
|
||||
answer = (answer or "").strip()
|
||||
return header + ("\n\n" + answer if answer else "")
|
||||
|
||||
@staticmethod
|
||||
def _assemble(header: str, lines: list[str]) -> str:
|
||||
return header if not lines else header + "\n\n" + HARD_BREAK.join(lines)
|
||||
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
|
||||
TELEGRAM_TOKEN_RE = re.compile(r"bot\d+:[A-Za-z0-9_-]+")
|
||||
TELEGRAM_BARE_TOKEN_RE = re.compile(r"\b\d+:[A-Za-z0-9_-]{10,}\b")
|
||||
|
||||
|
||||
class RedactTokenFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
try:
|
||||
message = record.getMessage()
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
redacted = TELEGRAM_TOKEN_RE.sub("bot[REDACTED]", message)
|
||||
redacted = TELEGRAM_BARE_TOKEN_RE.sub("[REDACTED_TOKEN]", redacted)
|
||||
if redacted != message:
|
||||
record.msg = redacted
|
||||
record.args = ()
|
||||
return True
|
||||
|
||||
|
||||
def setup_logging(*, debug: bool = False) -> None:
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.DEBUG)
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
handler.close()
|
||||
|
||||
fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
redactor = RedactTokenFilter()
|
||||
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.DEBUG if debug else logging.INFO)
|
||||
console.setFormatter(fmt)
|
||||
console.addFilter(redactor)
|
||||
root_logger.addFilter(redactor)
|
||||
root_logger.addHandler(console)
|
||||
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
from sulguk import transform_html
|
||||
|
||||
_md = MarkdownIt("commonmark", {"html": False})
|
||||
|
||||
|
||||
def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]:
|
||||
html = _md.render(md or "")
|
||||
rendered = transform_html(html)
|
||||
|
||||
text = re.sub(r"(?m)^(\s*)•", r"\1-", rendered.text)
|
||||
|
||||
# FIX: Telegram requires MessageEntity.language (if present) to be a String.
|
||||
entities: list[dict[str, Any]] = []
|
||||
for e in rendered.entities:
|
||||
d = dict(e)
|
||||
if "language" in d and not isinstance(d["language"], str):
|
||||
d.pop("language", None)
|
||||
entities.append(d)
|
||||
return text, entities
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .logging import RedactTokenFilter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.addFilter(RedactTokenFilter())
|
||||
|
||||
|
||||
class TelegramAPIError(RuntimeError):
|
||||
def __init__(
|
||||
self, method: str, payload: dict[str, Any], status_code: int | None
|
||||
) -> None:
|
||||
desc = payload.get("description") or str(payload)
|
||||
super().__init__(f"{method} failed: {desc}")
|
||||
self.payload = payload
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class TelegramClient:
|
||||
"""
|
||||
Minimal Telegram Bot API client.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token: str,
|
||||
timeout_s: float = 120,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
|
||||
) -> None:
|
||||
if not token:
|
||||
raise ValueError("Telegram token is empty")
|
||||
self._base = f"https://api.telegram.org/bot{token}"
|
||||
self._client = client or httpx.AsyncClient(timeout=timeout_s)
|
||||
self._owns_client = client is None
|
||||
self._sleep = sleep
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._owns_client:
|
||||
await self._client.aclose()
|
||||
|
||||
async def _post(self, method: str, json_data: dict[str, Any]) -> Any:
|
||||
try:
|
||||
logger.debug("[telegram] request %s: %s", method, json_data)
|
||||
resp = await self._client.post(f"{self._base}/{method}", json=json_data)
|
||||
payload: dict[str, Any] | None = None
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception:
|
||||
resp.raise_for_status()
|
||||
raise
|
||||
if not payload.get("ok"):
|
||||
params = payload.get("parameters") or {}
|
||||
retry_after = params.get("retry_after")
|
||||
if resp.status_code == 429 and isinstance(retry_after, int):
|
||||
logger.warning(
|
||||
"[telegram] 429 retry_after=%s method=%s", retry_after, method
|
||||
)
|
||||
await self._sleep(retry_after)
|
||||
return await self._post(method, json_data)
|
||||
raise TelegramAPIError(method, payload, resp.status_code)
|
||||
logger.debug("[telegram] response %s: %s", method, payload)
|
||||
return payload["result"]
|
||||
except httpx.HTTPError as e:
|
||||
logger.error("Telegram network error: %s", e)
|
||||
raise
|
||||
|
||||
async def get_updates(
|
||||
self,
|
||||
offset: int | None,
|
||||
timeout_s: int = 50,
|
||||
allowed_updates: list[str] | None = None,
|
||||
) -> list[dict]:
|
||||
params: dict[str, Any] = {"timeout": timeout_s}
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
if allowed_updates is not None:
|
||||
params["allowed_updates"] = allowed_updates
|
||||
return await self._post("getUpdates", params) # type: ignore[return-value]
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_to_message_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
entities: list[dict] | None = None,
|
||||
parse_mode: str | None = None,
|
||||
) -> dict:
|
||||
params: dict[str, Any] = {
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
}
|
||||
if disable_notification is not None:
|
||||
params["disable_notification"] = disable_notification
|
||||
if reply_to_message_id is not None:
|
||||
params["reply_to_message_id"] = reply_to_message_id
|
||||
if entities is not None:
|
||||
params["entities"] = entities
|
||||
if parse_mode is not None:
|
||||
params["parse_mode"] = parse_mode
|
||||
return await self._post("sendMessage", params) # type: ignore[return-value]
|
||||
|
||||
async def edit_message_text(
|
||||
self,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
text: str,
|
||||
entities: list[dict] | None = None,
|
||||
parse_mode: str | None = None,
|
||||
) -> dict:
|
||||
params: dict[str, Any] = {
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"text": text,
|
||||
}
|
||||
if entities is not None:
|
||||
params["entities"] = entities
|
||||
if parse_mode is not None:
|
||||
params["parse_mode"] = parse_mode
|
||||
return await self._post("editMessageText", params) # type: ignore[return-value]
|
||||
|
||||
async def delete_message(self, chat_id: int, message_id: int) -> bool:
|
||||
res = await self._post(
|
||||
"deleteMessage",
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
},
|
||||
)
|
||||
return bool(res)
|
||||
Reference in New Issue
Block a user