diff --git a/readme.md b/readme.md index 973f713..036f926 100644 --- a/readme.md +++ b/readme.md @@ -94,6 +94,10 @@ Reply to a bot message (containing `resume: `), or include the resume line resume: `019b66fc-64c2-7a71-81cd-081c504cfeb2` ``` +### Cancel a Run + +Reply to a progress message with `/cancel` to stop the running execution. + ## Notes - **Startup**: Pending updates are drained (ignored) on startup diff --git a/src/takopi/exec_bridge.py b/src/takopi/exec_bridge.py index 01dfd39..0eefb8a 100644 --- a/src/takopi/exec_bridge.py +++ b/src/takopi/exec_bridge.py @@ -19,11 +19,16 @@ import typer from . import __version__ from .config import ConfigError, load_telegram_config -from .exec_render import ExecProgressRenderer, render_event_cli, render_markdown +from .exec_render import ( + ExecProgressRenderer, + render_event_cli, + render_markdown, +) from .logging import setup_logging from .onboarding import check_setup, render_setup_guide from .telegram import TelegramClient + logger = logging.getLogger(__name__) UUID_PATTERN_TEXT = r"\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b" UUID_PATTERN = re.compile(UUID_PATTERN_TEXT, re.IGNORECASE) @@ -277,10 +282,10 @@ class CodexExecRunner: except Exception as e: logger.info("[codex][on_event] callback error: %s", e) - if evt.get("type") == "thread.started": + if evt["type"] == "thread.started": found_session = evt.get("thread_id") or found_session - if evt.get("type") == "item.completed": + if evt["type"] == "item.completed": item = evt.get("item") or {} if item.get("type") == "agent_message" and isinstance( item.get("text"), str @@ -291,6 +296,8 @@ class CodexExecRunner: cancelled = True finally: if cancelled: + if not stderr_task.done(): + stderr_task.cancel() task = cast(asyncio.Task, asyncio.current_task()) while task.cancelling(): task.uncancel() @@ -431,6 +438,7 @@ async def _handle_message( user_msg_id: int, text: str, resume_session: str | None, + running_tasks: dict[str, asyncio.Task[Any]] | None = None, clock: Callable[[], float] = time.monotonic, progress_edit_every: float = PROGRESS_EDIT_EVERY_S, ) -> None: @@ -511,12 +519,25 @@ async def _handle_message( "[handle] failed to send progress message chat_id=%s: %s", chat_id, e ) + exec_task: asyncio.Task[tuple[str, str, bool]] | None = None + tracked_session_id: str | None = None + async def on_event(evt: dict[str, Any]) -> None: - nonlocal last_edit_at, edit_task, pending_rendered + nonlocal last_edit_at, edit_task, pending_rendered, tracked_session_id if progress_id is None: return if not progress_renderer.note_event(evt): return + + if ( + evt["type"] == "thread.started" + and running_tasks is not None + and exec_task is not None + ): + tracked_session_id = progress_renderer.resume_session + if tracked_session_id: + running_tasks[tracked_session_id] = exec_task + now = clock() if (now - last_edit_at) < progress_edit_every: return @@ -531,10 +552,16 @@ async def _handle_message( pending_rendered = rendered edit_task = asyncio.create_task(_edit_progress(md, rendered, entities)) + exec_task = asyncio.create_task( + cfg.runner.run_serialized(text, resume_session, on_event=on_event) + ) + + cancelled = False try: - session_id, answer, saw_agent_message = await cfg.runner.run_serialized( - text, resume_session, on_event=on_event - ) + session_id, answer, saw_agent_message = await exec_task + except asyncio.CancelledError: + cancelled = True + session_id = tracked_session_id or resume_session except Exception as e: if edit_task is not None: await asyncio.gather(edit_task, return_exceptions=True) @@ -551,11 +578,34 @@ async def _handle_message( limit=TELEGRAM_MARKDOWN_LIMIT, ) return + finally: + if tracked_session_id and running_tasks is not None and exec_task is not None: + # Avoid removing a newer task for the same session_id if another run + # registered while this one was finishing. + if running_tasks.get(tracked_session_id) is exec_task: + running_tasks.pop(tracked_session_id, None) if edit_task is not None: await asyncio.gather(edit_task, return_exceptions=True) elapsed = clock() - started_at + if cancelled: + logger.info( + "[handle] cancelled session_id=%s elapsed=%.1fs", session_id, elapsed + ) + progress_renderer.resume_session = session_id + final_md = progress_renderer.render_progress(elapsed, label="`cancelled`") + await _send_or_edit_markdown( + cfg.bot, + chat_id=chat_id, + text=final_md, + edit_message_id=progress_id, + reply_to_message_id=user_msg_id, + disable_notification=True, + limit=TELEGRAM_MARKDOWN_LIMIT, + ) + return + status = "done" if saw_agent_message else "error" progress_renderer.resume_session = session_id final_md = progress_renderer.render_final(elapsed, answer, status=status) @@ -624,11 +674,51 @@ async def poll_updates(cfg: BridgeConfig): yield msg +async def _handle_cancel( + cfg: BridgeConfig, + msg: dict[str, Any], + running_tasks: dict[str, asyncio.Task[Any]], +) -> None: + chat_id = msg["chat"]["id"] + user_msg_id = msg["message_id"] + reply = msg.get("reply_to_message") + + if not reply: + await cfg.bot.send_message( + chat_id=chat_id, + text="reply to the progress message to cancel.", + reply_to_message_id=user_msg_id, + ) + return + + session_id = extract_session_id(reply.get("text")) + if not session_id: + await cfg.bot.send_message( + chat_id=chat_id, + text="nothing is currently running for that message.", + reply_to_message_id=user_msg_id, + ) + return + + task = running_tasks.get(session_id) + if not task or task.done(): + await cfg.bot.send_message( + chat_id=chat_id, + text="nothing is currently running for that message.", + reply_to_message_id=user_msg_id, + ) + return + + logger.info("[cancel] cancelling session_id=%s", session_id) + task.cancel() + + async def _run_main_loop(cfg: BridgeConfig) -> None: worker_count = max(1, min(cfg.max_concurrency, 16)) queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue( maxsize=worker_count * 2 ) + running_tasks: dict[str, asyncio.Task[Any]] = {} async def worker() -> None: while True: @@ -640,6 +730,7 @@ async def _run_main_loop(cfg: BridgeConfig) -> None: user_msg_id=user_msg_id, text=text, resume_session=resume_session, + running_tasks=running_tasks, ) except Exception: logger.exception("[handle] worker failed") @@ -653,6 +744,11 @@ async def _run_main_loop(cfg: BridgeConfig) -> None: async for msg in poll_updates(cfg): text = msg["text"] user_msg_id = msg["message_id"] + + if text == "/cancel": + tg.create_task(_handle_cancel(cfg, msg, running_tasks)) + continue + r = msg.get("reply_to_message") or {} resume_session = resolve_resume_session(text, r.get("text")) diff --git a/src/takopi/exec_render.py b/src/takopi/exec_render.py index 6164915..b14b563 100644 --- a/src/takopi/exec_render.py +++ b/src/takopi/exec_render.py @@ -240,8 +240,8 @@ class ExecProgressRenderer: self.recent_actions.append(progress_line) return True - def render_progress(self, elapsed_s: float) -> str: - header = format_header(elapsed_s, self.last_item, label="working") + def render_progress(self, elapsed_s: float, label: str = "working") -> str: + header = format_header(elapsed_s, self.last_item, label=label) message = self._assemble(header, list(self.recent_actions)) return self._append_resume(message) diff --git a/tests/test_exec_bridge.py b/tests/test_exec_bridge.py index fccb304..6d2a397 100644 --- a/tests/test_exec_bridge.py +++ b/tests/test_exec_bridge.py @@ -403,3 +403,167 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None: assert session_id in bot.send_calls[-1]["text"] assert "resume:" in bot.send_calls[-1]["text"].lower() assert len(bot.delete_calls) == 1 + + +def test_handle_cancel_without_reply_prompts_user() -> None: + from takopi.exec_bridge import BridgeConfig, _handle_cancel + + bot = _FakeBot() + runner = _FakeRunner(answer="ok") + cfg = BridgeConfig( + bot=bot, # type: ignore[arg-type] + runner=runner, # type: ignore[arg-type] + chat_id=123, + final_notify=True, + startup_msg="", + max_concurrency=1, + ) + msg = {"chat": {"id": 123}, "message_id": 10} + running_tasks: dict = {} + + asyncio.run(_handle_cancel(cfg, msg, running_tasks)) + + assert len(bot.send_calls) == 1 + assert "reply to the progress message" in bot.send_calls[0]["text"] + + +def test_handle_cancel_with_no_session_id_says_nothing_running() -> None: + from takopi.exec_bridge import BridgeConfig, _handle_cancel + + bot = _FakeBot() + runner = _FakeRunner(answer="ok") + cfg = BridgeConfig( + bot=bot, # type: ignore[arg-type] + runner=runner, # type: ignore[arg-type] + chat_id=123, + final_notify=True, + startup_msg="", + max_concurrency=1, + ) + msg = { + "chat": {"id": 123}, + "message_id": 10, + "reply_to_message": {"text": "no uuid here"}, + } + running_tasks: dict = {} + + asyncio.run(_handle_cancel(cfg, msg, running_tasks)) + + assert len(bot.send_calls) == 1 + assert "nothing is currently running" in bot.send_calls[0]["text"] + + +def test_handle_cancel_with_finished_task_says_nothing_running() -> None: + from takopi.exec_bridge import BridgeConfig, _handle_cancel + + bot = _FakeBot() + runner = _FakeRunner(answer="ok") + cfg = BridgeConfig( + bot=bot, # type: ignore[arg-type] + runner=runner, # type: ignore[arg-type] + chat_id=123, + final_notify=True, + startup_msg="", + max_concurrency=1, + ) + session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2" + msg = { + "chat": {"id": 123}, + "message_id": 10, + "reply_to_message": {"text": f"resume: `{session_id}`"}, + } + running_tasks: dict = {} # Session not in running_tasks + + asyncio.run(_handle_cancel(cfg, msg, running_tasks)) + + assert len(bot.send_calls) == 1 + assert "nothing is currently running" in bot.send_calls[0]["text"] + + +def test_handle_cancel_cancels_running_task() -> None: + from takopi.exec_bridge import BridgeConfig, _handle_cancel + + bot = _FakeBot() + runner = _FakeRunner(answer="ok") + cfg = BridgeConfig( + bot=bot, # type: ignore[arg-type] + runner=runner, # type: ignore[arg-type] + chat_id=123, + final_notify=True, + startup_msg="", + max_concurrency=1, + ) + session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2" + msg = { + "chat": {"id": 123}, + "message_id": 10, + "reply_to_message": {"text": f"resume: `{session_id}`"}, + } + + async def run_test(): + task = asyncio.create_task(asyncio.sleep(10)) + running_tasks = {session_id: task} + await _handle_cancel(cfg, msg, running_tasks) + try: + await task + except asyncio.CancelledError: + return True + return False + + cancelled = asyncio.run(run_test()) + + assert cancelled is True + assert len(bot.send_calls) == 0 # No error message sent + + +class _FakeRunnerCancellable: + def __init__(self, session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2"): + self._session_id = session_id + + async def run_serialized(self, *_args, **kwargs) -> tuple[str, str, bool]: + on_event = kwargs.get("on_event") + if on_event: + await on_event({"type": "thread.started", "thread_id": self._session_id}) + await asyncio.sleep(10) # Will be cancelled + return (self._session_id, "ok", True) + + +def test_handle_message_cancelled_renders_cancelled_state() -> None: + from takopi.exec_bridge import BridgeConfig, _handle_message + + bot = _FakeBot() + session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2" + runner = _FakeRunnerCancellable(session_id=session_id) + cfg = BridgeConfig( + bot=bot, # type: ignore[arg-type] + runner=runner, # type: ignore[arg-type] + chat_id=123, + final_notify=True, + startup_msg="", + max_concurrency=1, + ) + running_tasks: dict = {} + + async def run_test(): + task = asyncio.create_task( + _handle_message( + cfg, + chat_id=123, + user_msg_id=10, + text="do something", + resume_session=None, + running_tasks=running_tasks, + ) + ) + await asyncio.sleep(0.01) # Let task start and register + assert session_id in running_tasks + running_tasks[session_id].cancel() + await task + + asyncio.run(run_test()) + + assert len(bot.send_calls) == 1 # Progress message + assert len(bot.edit_calls) >= 1 + last_edit = bot.edit_calls[-1]["text"] + assert "cancelled" in last_edit.lower() + assert session_id in last_edit