feat: /cancel (#4)
This commit is contained in:
@@ -94,6 +94,10 @@ Reply to a bot message (containing `resume: <uuid>`), or include the resume line
|
|||||||
resume: `019b66fc-64c2-7a71-81cd-081c504cfeb2`
|
resume: `019b66fc-64c2-7a71-81cd-081c504cfeb2`
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Cancel a Run
|
||||||
|
|
||||||
|
Reply to a progress message with `/cancel` to stop the running execution.
|
||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
- **Startup**: Pending updates are drained (ignored) on startup
|
- **Startup**: Pending updates are drained (ignored) on startup
|
||||||
|
|||||||
+103
-7
@@ -19,11 +19,16 @@ import typer
|
|||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .config import ConfigError, load_telegram_config
|
from .config import ConfigError, load_telegram_config
|
||||||
from .exec_render import ExecProgressRenderer, render_event_cli, render_markdown
|
from .exec_render import (
|
||||||
|
ExecProgressRenderer,
|
||||||
|
render_event_cli,
|
||||||
|
render_markdown,
|
||||||
|
)
|
||||||
from .logging import setup_logging
|
from .logging import setup_logging
|
||||||
from .onboarding import check_setup, render_setup_guide
|
from .onboarding import check_setup, render_setup_guide
|
||||||
from .telegram import TelegramClient
|
from .telegram import TelegramClient
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
UUID_PATTERN_TEXT = r"\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
|
UUID_PATTERN_TEXT = r"\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
|
||||||
UUID_PATTERN = re.compile(UUID_PATTERN_TEXT, re.IGNORECASE)
|
UUID_PATTERN = re.compile(UUID_PATTERN_TEXT, re.IGNORECASE)
|
||||||
@@ -277,10 +282,10 @@ class CodexExecRunner:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("[codex][on_event] callback error: %s", e)
|
logger.info("[codex][on_event] callback error: %s", e)
|
||||||
|
|
||||||
if evt.get("type") == "thread.started":
|
if evt["type"] == "thread.started":
|
||||||
found_session = evt.get("thread_id") or found_session
|
found_session = evt.get("thread_id") or found_session
|
||||||
|
|
||||||
if evt.get("type") == "item.completed":
|
if evt["type"] == "item.completed":
|
||||||
item = evt.get("item") or {}
|
item = evt.get("item") or {}
|
||||||
if item.get("type") == "agent_message" and isinstance(
|
if item.get("type") == "agent_message" and isinstance(
|
||||||
item.get("text"), str
|
item.get("text"), str
|
||||||
@@ -291,6 +296,8 @@ class CodexExecRunner:
|
|||||||
cancelled = True
|
cancelled = True
|
||||||
finally:
|
finally:
|
||||||
if cancelled:
|
if cancelled:
|
||||||
|
if not stderr_task.done():
|
||||||
|
stderr_task.cancel()
|
||||||
task = cast(asyncio.Task, asyncio.current_task())
|
task = cast(asyncio.Task, asyncio.current_task())
|
||||||
while task.cancelling():
|
while task.cancelling():
|
||||||
task.uncancel()
|
task.uncancel()
|
||||||
@@ -431,6 +438,7 @@ async def _handle_message(
|
|||||||
user_msg_id: int,
|
user_msg_id: int,
|
||||||
text: str,
|
text: str,
|
||||||
resume_session: str | None,
|
resume_session: str | None,
|
||||||
|
running_tasks: dict[str, asyncio.Task[Any]] | None = None,
|
||||||
clock: Callable[[], float] = time.monotonic,
|
clock: Callable[[], float] = time.monotonic,
|
||||||
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
|
progress_edit_every: float = PROGRESS_EDIT_EVERY_S,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -511,12 +519,25 @@ async def _handle_message(
|
|||||||
"[handle] failed to send progress message chat_id=%s: %s", chat_id, e
|
"[handle] failed to send progress message chat_id=%s: %s", chat_id, e
|
||||||
)
|
)
|
||||||
|
|
||||||
|
exec_task: asyncio.Task[tuple[str, str, bool]] | None = None
|
||||||
|
tracked_session_id: str | None = None
|
||||||
|
|
||||||
async def on_event(evt: dict[str, Any]) -> None:
|
async def on_event(evt: dict[str, Any]) -> None:
|
||||||
nonlocal last_edit_at, edit_task, pending_rendered
|
nonlocal last_edit_at, edit_task, pending_rendered, tracked_session_id
|
||||||
if progress_id is None:
|
if progress_id is None:
|
||||||
return
|
return
|
||||||
if not progress_renderer.note_event(evt):
|
if not progress_renderer.note_event(evt):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
evt["type"] == "thread.started"
|
||||||
|
and running_tasks is not None
|
||||||
|
and exec_task is not None
|
||||||
|
):
|
||||||
|
tracked_session_id = progress_renderer.resume_session
|
||||||
|
if tracked_session_id:
|
||||||
|
running_tasks[tracked_session_id] = exec_task
|
||||||
|
|
||||||
now = clock()
|
now = clock()
|
||||||
if (now - last_edit_at) < progress_edit_every:
|
if (now - last_edit_at) < progress_edit_every:
|
||||||
return
|
return
|
||||||
@@ -531,10 +552,16 @@ async def _handle_message(
|
|||||||
pending_rendered = rendered
|
pending_rendered = rendered
|
||||||
edit_task = asyncio.create_task(_edit_progress(md, rendered, entities))
|
edit_task = asyncio.create_task(_edit_progress(md, rendered, entities))
|
||||||
|
|
||||||
|
exec_task = asyncio.create_task(
|
||||||
|
cfg.runner.run_serialized(text, resume_session, on_event=on_event)
|
||||||
|
)
|
||||||
|
|
||||||
|
cancelled = False
|
||||||
try:
|
try:
|
||||||
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
|
session_id, answer, saw_agent_message = await exec_task
|
||||||
text, resume_session, on_event=on_event
|
except asyncio.CancelledError:
|
||||||
)
|
cancelled = True
|
||||||
|
session_id = tracked_session_id or resume_session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if edit_task is not None:
|
if edit_task is not None:
|
||||||
await asyncio.gather(edit_task, return_exceptions=True)
|
await asyncio.gather(edit_task, return_exceptions=True)
|
||||||
@@ -551,11 +578,34 @@ async def _handle_message(
|
|||||||
limit=TELEGRAM_MARKDOWN_LIMIT,
|
limit=TELEGRAM_MARKDOWN_LIMIT,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
finally:
|
||||||
|
if tracked_session_id and running_tasks is not None and exec_task is not None:
|
||||||
|
# Avoid removing a newer task for the same session_id if another run
|
||||||
|
# registered while this one was finishing.
|
||||||
|
if running_tasks.get(tracked_session_id) is exec_task:
|
||||||
|
running_tasks.pop(tracked_session_id, None)
|
||||||
|
|
||||||
if edit_task is not None:
|
if edit_task is not None:
|
||||||
await asyncio.gather(edit_task, return_exceptions=True)
|
await asyncio.gather(edit_task, return_exceptions=True)
|
||||||
|
|
||||||
elapsed = clock() - started_at
|
elapsed = clock() - started_at
|
||||||
|
if cancelled:
|
||||||
|
logger.info(
|
||||||
|
"[handle] cancelled session_id=%s elapsed=%.1fs", session_id, elapsed
|
||||||
|
)
|
||||||
|
progress_renderer.resume_session = session_id
|
||||||
|
final_md = progress_renderer.render_progress(elapsed, label="`cancelled`")
|
||||||
|
await _send_or_edit_markdown(
|
||||||
|
cfg.bot,
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=final_md,
|
||||||
|
edit_message_id=progress_id,
|
||||||
|
reply_to_message_id=user_msg_id,
|
||||||
|
disable_notification=True,
|
||||||
|
limit=TELEGRAM_MARKDOWN_LIMIT,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
status = "done" if saw_agent_message else "error"
|
status = "done" if saw_agent_message else "error"
|
||||||
progress_renderer.resume_session = session_id
|
progress_renderer.resume_session = session_id
|
||||||
final_md = progress_renderer.render_final(elapsed, answer, status=status)
|
final_md = progress_renderer.render_final(elapsed, answer, status=status)
|
||||||
@@ -624,11 +674,51 @@ async def poll_updates(cfg: BridgeConfig):
|
|||||||
yield msg
|
yield msg
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_cancel(
|
||||||
|
cfg: BridgeConfig,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
running_tasks: dict[str, asyncio.Task[Any]],
|
||||||
|
) -> None:
|
||||||
|
chat_id = msg["chat"]["id"]
|
||||||
|
user_msg_id = msg["message_id"]
|
||||||
|
reply = msg.get("reply_to_message")
|
||||||
|
|
||||||
|
if not reply:
|
||||||
|
await cfg.bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text="reply to the progress message to cancel.",
|
||||||
|
reply_to_message_id=user_msg_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
session_id = extract_session_id(reply.get("text"))
|
||||||
|
if not session_id:
|
||||||
|
await cfg.bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text="nothing is currently running for that message.",
|
||||||
|
reply_to_message_id=user_msg_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
task = running_tasks.get(session_id)
|
||||||
|
if not task or task.done():
|
||||||
|
await cfg.bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text="nothing is currently running for that message.",
|
||||||
|
reply_to_message_id=user_msg_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("[cancel] cancelling session_id=%s", session_id)
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
|
||||||
async def _run_main_loop(cfg: BridgeConfig) -> None:
|
async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||||
worker_count = max(1, min(cfg.max_concurrency, 16))
|
worker_count = max(1, min(cfg.max_concurrency, 16))
|
||||||
queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue(
|
queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue(
|
||||||
maxsize=worker_count * 2
|
maxsize=worker_count * 2
|
||||||
)
|
)
|
||||||
|
running_tasks: dict[str, asyncio.Task[Any]] = {}
|
||||||
|
|
||||||
async def worker() -> None:
|
async def worker() -> None:
|
||||||
while True:
|
while True:
|
||||||
@@ -640,6 +730,7 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
|
|||||||
user_msg_id=user_msg_id,
|
user_msg_id=user_msg_id,
|
||||||
text=text,
|
text=text,
|
||||||
resume_session=resume_session,
|
resume_session=resume_session,
|
||||||
|
running_tasks=running_tasks,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[handle] worker failed")
|
logger.exception("[handle] worker failed")
|
||||||
@@ -653,6 +744,11 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
|
|||||||
async for msg in poll_updates(cfg):
|
async for msg in poll_updates(cfg):
|
||||||
text = msg["text"]
|
text = msg["text"]
|
||||||
user_msg_id = msg["message_id"]
|
user_msg_id = msg["message_id"]
|
||||||
|
|
||||||
|
if text == "/cancel":
|
||||||
|
tg.create_task(_handle_cancel(cfg, msg, running_tasks))
|
||||||
|
continue
|
||||||
|
|
||||||
r = msg.get("reply_to_message") or {}
|
r = msg.get("reply_to_message") or {}
|
||||||
resume_session = resolve_resume_session(text, r.get("text"))
|
resume_session = resolve_resume_session(text, r.get("text"))
|
||||||
|
|
||||||
|
|||||||
@@ -240,8 +240,8 @@ class ExecProgressRenderer:
|
|||||||
self.recent_actions.append(progress_line)
|
self.recent_actions.append(progress_line)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def render_progress(self, elapsed_s: float) -> str:
|
def render_progress(self, elapsed_s: float, label: str = "working") -> str:
|
||||||
header = format_header(elapsed_s, self.last_item, label="working")
|
header = format_header(elapsed_s, self.last_item, label=label)
|
||||||
message = self._assemble(header, list(self.recent_actions))
|
message = self._assemble(header, list(self.recent_actions))
|
||||||
return self._append_resume(message)
|
return self._append_resume(message)
|
||||||
|
|
||||||
|
|||||||
@@ -403,3 +403,167 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
|
|||||||
assert session_id in bot.send_calls[-1]["text"]
|
assert session_id in bot.send_calls[-1]["text"]
|
||||||
assert "resume:" in bot.send_calls[-1]["text"].lower()
|
assert "resume:" in bot.send_calls[-1]["text"].lower()
|
||||||
assert len(bot.delete_calls) == 1
|
assert len(bot.delete_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_cancel_without_reply_prompts_user() -> None:
|
||||||
|
from takopi.exec_bridge import BridgeConfig, _handle_cancel
|
||||||
|
|
||||||
|
bot = _FakeBot()
|
||||||
|
runner = _FakeRunner(answer="ok")
|
||||||
|
cfg = BridgeConfig(
|
||||||
|
bot=bot, # type: ignore[arg-type]
|
||||||
|
runner=runner, # type: ignore[arg-type]
|
||||||
|
chat_id=123,
|
||||||
|
final_notify=True,
|
||||||
|
startup_msg="",
|
||||||
|
max_concurrency=1,
|
||||||
|
)
|
||||||
|
msg = {"chat": {"id": 123}, "message_id": 10}
|
||||||
|
running_tasks: dict = {}
|
||||||
|
|
||||||
|
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
|
||||||
|
|
||||||
|
assert len(bot.send_calls) == 1
|
||||||
|
assert "reply to the progress message" in bot.send_calls[0]["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_cancel_with_no_session_id_says_nothing_running() -> None:
|
||||||
|
from takopi.exec_bridge import BridgeConfig, _handle_cancel
|
||||||
|
|
||||||
|
bot = _FakeBot()
|
||||||
|
runner = _FakeRunner(answer="ok")
|
||||||
|
cfg = BridgeConfig(
|
||||||
|
bot=bot, # type: ignore[arg-type]
|
||||||
|
runner=runner, # type: ignore[arg-type]
|
||||||
|
chat_id=123,
|
||||||
|
final_notify=True,
|
||||||
|
startup_msg="",
|
||||||
|
max_concurrency=1,
|
||||||
|
)
|
||||||
|
msg = {
|
||||||
|
"chat": {"id": 123},
|
||||||
|
"message_id": 10,
|
||||||
|
"reply_to_message": {"text": "no uuid here"},
|
||||||
|
}
|
||||||
|
running_tasks: dict = {}
|
||||||
|
|
||||||
|
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
|
||||||
|
|
||||||
|
assert len(bot.send_calls) == 1
|
||||||
|
assert "nothing is currently running" in bot.send_calls[0]["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
|
||||||
|
from takopi.exec_bridge import BridgeConfig, _handle_cancel
|
||||||
|
|
||||||
|
bot = _FakeBot()
|
||||||
|
runner = _FakeRunner(answer="ok")
|
||||||
|
cfg = BridgeConfig(
|
||||||
|
bot=bot, # type: ignore[arg-type]
|
||||||
|
runner=runner, # type: ignore[arg-type]
|
||||||
|
chat_id=123,
|
||||||
|
final_notify=True,
|
||||||
|
startup_msg="",
|
||||||
|
max_concurrency=1,
|
||||||
|
)
|
||||||
|
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
|
||||||
|
msg = {
|
||||||
|
"chat": {"id": 123},
|
||||||
|
"message_id": 10,
|
||||||
|
"reply_to_message": {"text": f"resume: `{session_id}`"},
|
||||||
|
}
|
||||||
|
running_tasks: dict = {} # Session not in running_tasks
|
||||||
|
|
||||||
|
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
|
||||||
|
|
||||||
|
assert len(bot.send_calls) == 1
|
||||||
|
assert "nothing is currently running" in bot.send_calls[0]["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_cancel_cancels_running_task() -> None:
|
||||||
|
from takopi.exec_bridge import BridgeConfig, _handle_cancel
|
||||||
|
|
||||||
|
bot = _FakeBot()
|
||||||
|
runner = _FakeRunner(answer="ok")
|
||||||
|
cfg = BridgeConfig(
|
||||||
|
bot=bot, # type: ignore[arg-type]
|
||||||
|
runner=runner, # type: ignore[arg-type]
|
||||||
|
chat_id=123,
|
||||||
|
final_notify=True,
|
||||||
|
startup_msg="",
|
||||||
|
max_concurrency=1,
|
||||||
|
)
|
||||||
|
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
|
||||||
|
msg = {
|
||||||
|
"chat": {"id": 123},
|
||||||
|
"message_id": 10,
|
||||||
|
"reply_to_message": {"text": f"resume: `{session_id}`"},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def run_test():
|
||||||
|
task = asyncio.create_task(asyncio.sleep(10))
|
||||||
|
running_tasks = {session_id: task}
|
||||||
|
await _handle_cancel(cfg, msg, running_tasks)
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
cancelled = asyncio.run(run_test())
|
||||||
|
|
||||||
|
assert cancelled is True
|
||||||
|
assert len(bot.send_calls) == 0 # No error message sent
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRunnerCancellable:
|
||||||
|
def __init__(self, session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2"):
|
||||||
|
self._session_id = session_id
|
||||||
|
|
||||||
|
async def run_serialized(self, *_args, **kwargs) -> tuple[str, str, bool]:
|
||||||
|
on_event = kwargs.get("on_event")
|
||||||
|
if on_event:
|
||||||
|
await on_event({"type": "thread.started", "thread_id": self._session_id})
|
||||||
|
await asyncio.sleep(10) # Will be cancelled
|
||||||
|
return (self._session_id, "ok", True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_message_cancelled_renders_cancelled_state() -> None:
|
||||||
|
from takopi.exec_bridge import BridgeConfig, _handle_message
|
||||||
|
|
||||||
|
bot = _FakeBot()
|
||||||
|
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
|
||||||
|
runner = _FakeRunnerCancellable(session_id=session_id)
|
||||||
|
cfg = BridgeConfig(
|
||||||
|
bot=bot, # type: ignore[arg-type]
|
||||||
|
runner=runner, # type: ignore[arg-type]
|
||||||
|
chat_id=123,
|
||||||
|
final_notify=True,
|
||||||
|
startup_msg="",
|
||||||
|
max_concurrency=1,
|
||||||
|
)
|
||||||
|
running_tasks: dict = {}
|
||||||
|
|
||||||
|
async def run_test():
|
||||||
|
task = asyncio.create_task(
|
||||||
|
_handle_message(
|
||||||
|
cfg,
|
||||||
|
chat_id=123,
|
||||||
|
user_msg_id=10,
|
||||||
|
text="do something",
|
||||||
|
resume_session=None,
|
||||||
|
running_tasks=running_tasks,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.01) # Let task start and register
|
||||||
|
assert session_id in running_tasks
|
||||||
|
running_tasks[session_id].cancel()
|
||||||
|
await task
|
||||||
|
|
||||||
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
assert len(bot.send_calls) == 1 # Progress message
|
||||||
|
assert len(bot.edit_calls) >= 1
|
||||||
|
last_edit = bot.edit_calls[-1]["text"]
|
||||||
|
assert "cancelled" in last_edit.lower()
|
||||||
|
assert session_id in last_edit
|
||||||
|
|||||||
Reference in New Issue
Block a user