feat: /cancel (#4)

This commit is contained in:
banteg
2025-12-30 16:20:07 +04:00
committed by GitHub
parent 25c3eccca8
commit 1231c9dc48
4 changed files with 273 additions and 9 deletions
+103 -7
View File
@@ -19,11 +19,16 @@ import typer
from . import __version__
from .config import ConfigError, load_telegram_config
from .exec_render import ExecProgressRenderer, render_event_cli, render_markdown
from .exec_render import (
ExecProgressRenderer,
render_event_cli,
render_markdown,
)
from .logging import setup_logging
from .onboarding import check_setup, render_setup_guide
from .telegram import TelegramClient
logger = logging.getLogger(__name__)
UUID_PATTERN_TEXT = r"\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
UUID_PATTERN = re.compile(UUID_PATTERN_TEXT, re.IGNORECASE)
@@ -277,10 +282,10 @@ class CodexExecRunner:
except Exception as e:
logger.info("[codex][on_event] callback error: %s", e)
if evt.get("type") == "thread.started":
if evt["type"] == "thread.started":
found_session = evt.get("thread_id") or found_session
if evt.get("type") == "item.completed":
if evt["type"] == "item.completed":
item = evt.get("item") or {}
if item.get("type") == "agent_message" and isinstance(
item.get("text"), str
@@ -291,6 +296,8 @@ class CodexExecRunner:
cancelled = True
finally:
if cancelled:
if not stderr_task.done():
stderr_task.cancel()
task = cast(asyncio.Task, asyncio.current_task())
while task.cancelling():
task.uncancel()
@@ -431,6 +438,7 @@ async def _handle_message(
user_msg_id: int,
text: str,
resume_session: str | None,
running_tasks: dict[str, asyncio.Task[Any]] | None = None,
clock: Callable[[], float] = time.monotonic,
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
) -> None:
@@ -511,12 +519,25 @@ async def _handle_message(
"[handle] failed to send progress message chat_id=%s: %s", chat_id, e
)
exec_task: asyncio.Task[tuple[str, str, bool]] | None = None
tracked_session_id: str | None = None
async def on_event(evt: dict[str, Any]) -> None:
nonlocal last_edit_at, edit_task, pending_rendered
nonlocal last_edit_at, edit_task, pending_rendered, tracked_session_id
if progress_id is None:
return
if not progress_renderer.note_event(evt):
return
if (
evt["type"] == "thread.started"
and running_tasks is not None
and exec_task is not None
):
tracked_session_id = progress_renderer.resume_session
if tracked_session_id:
running_tasks[tracked_session_id] = exec_task
now = clock()
if (now - last_edit_at) < progress_edit_every:
return
@@ -531,10 +552,16 @@ async def _handle_message(
pending_rendered = rendered
edit_task = asyncio.create_task(_edit_progress(md, rendered, entities))
exec_task = asyncio.create_task(
cfg.runner.run_serialized(text, resume_session, on_event=on_event)
)
cancelled = False
try:
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
text, resume_session, on_event=on_event
)
session_id, answer, saw_agent_message = await exec_task
except asyncio.CancelledError:
cancelled = True
session_id = tracked_session_id or resume_session
except Exception as e:
if edit_task is not None:
await asyncio.gather(edit_task, return_exceptions=True)
@@ -551,11 +578,34 @@ async def _handle_message(
limit=TELEGRAM_MARKDOWN_LIMIT,
)
return
finally:
if tracked_session_id and running_tasks is not None and exec_task is not None:
# Avoid removing a newer task for the same session_id if another run
# registered while this one was finishing.
if running_tasks.get(tracked_session_id) is exec_task:
running_tasks.pop(tracked_session_id, None)
if edit_task is not None:
await asyncio.gather(edit_task, return_exceptions=True)
elapsed = clock() - started_at
if cancelled:
logger.info(
"[handle] cancelled session_id=%s elapsed=%.1fs", session_id, elapsed
)
progress_renderer.resume_session = session_id
final_md = progress_renderer.render_progress(elapsed, label="`cancelled`")
await _send_or_edit_markdown(
cfg.bot,
chat_id=chat_id,
text=final_md,
edit_message_id=progress_id,
reply_to_message_id=user_msg_id,
disable_notification=True,
limit=TELEGRAM_MARKDOWN_LIMIT,
)
return
status = "done" if saw_agent_message else "error"
progress_renderer.resume_session = session_id
final_md = progress_renderer.render_final(elapsed, answer, status=status)
@@ -624,11 +674,51 @@ async def poll_updates(cfg: BridgeConfig):
yield msg
async def _handle_cancel(
cfg: BridgeConfig,
msg: dict[str, Any],
running_tasks: dict[str, asyncio.Task[Any]],
) -> None:
chat_id = msg["chat"]["id"]
user_msg_id = msg["message_id"]
reply = msg.get("reply_to_message")
if not reply:
await cfg.bot.send_message(
chat_id=chat_id,
text="reply to the progress message to cancel.",
reply_to_message_id=user_msg_id,
)
return
session_id = extract_session_id(reply.get("text"))
if not session_id:
await cfg.bot.send_message(
chat_id=chat_id,
text="nothing is currently running for that message.",
reply_to_message_id=user_msg_id,
)
return
task = running_tasks.get(session_id)
if not task or task.done():
await cfg.bot.send_message(
chat_id=chat_id,
text="nothing is currently running for that message.",
reply_to_message_id=user_msg_id,
)
return
logger.info("[cancel] cancelling session_id=%s", session_id)
task.cancel()
async def _run_main_loop(cfg: BridgeConfig) -> None:
worker_count = max(1, min(cfg.max_concurrency, 16))
queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue(
maxsize=worker_count * 2
)
running_tasks: dict[str, asyncio.Task[Any]] = {}
async def worker() -> None:
while True:
@@ -640,6 +730,7 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
user_msg_id=user_msg_id,
text=text,
resume_session=resume_session,
running_tasks=running_tasks,
)
except Exception:
logger.exception("[handle] worker failed")
@@ -653,6 +744,11 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
async for msg in poll_updates(cfg):
text = msg["text"]
user_msg_id = msg["message_id"]
if text == "/cancel":
tg.create_task(_handle_cancel(cfg, msg, running_tasks))
continue
r = msg.get("reply_to_message") or {}
resume_session = resolve_resume_session(text, r.get("text"))
+2 -2
View File
@@ -240,8 +240,8 @@ class ExecProgressRenderer:
self.recent_actions.append(progress_line)
return True
def render_progress(self, elapsed_s: float) -> str:
header = format_header(elapsed_s, self.last_item, label="working")
def render_progress(self, elapsed_s: float, label: str = "working") -> str:
header = format_header(elapsed_s, self.last_item, label=label)
message = self._assemble(header, list(self.recent_actions))
return self._append_resume(message)