diff --git a/src/takopi/runner_bridge.py b/src/takopi/runner_bridge.py index 14759a8..c22c11d 100644 --- a/src/takopi/runner_bridge.py +++ b/src/takopi/runner_bridge.py @@ -244,11 +244,11 @@ async def send_initial_progress( reply_to: MessageRef, label: str, tracker: ProgressTracker, + progress_ref: MessageRef | None = None, resume_formatter: Callable[[ResumeToken], str] | None = None, context_line: str | None = None, thread_id: ThreadId | None = None, ) -> ProgressMessageState: - progress_ref: MessageRef | None = None last_rendered: RenderedMessage | None = None state = tracker.snapshot( @@ -260,27 +260,26 @@ async def send_initial_progress( elapsed_s=0.0, label=label, ) - logger.debug( - "transport.send_message", - channel_id=channel_id, - reply_to_message_id=reply_to.message_id, - rendered=initial_rendered.text, - ) - progress_ref = await cfg.transport.send( + sent_ref, _ = await _send_or_edit_message( + cfg.transport, channel_id=channel_id, message=initial_rendered, - options=SendOptions(reply_to=reply_to, notify=False, thread_id=thread_id), + edit_ref=progress_ref, + reply_to=reply_to, + notify=False, + replace_ref=progress_ref, + thread_id=thread_id, ) - if progress_ref is not None: + if sent_ref is not None: last_rendered = initial_rendered logger.debug( "progress.sent", - channel_id=progress_ref.channel_id, - message_id=progress_ref.message_id, + channel_id=sent_ref.channel_id, + message_id=sent_ref.message_id, ) return ProgressMessageState( - ref=progress_ref, + ref=sent_ref, last_rendered=last_rendered, ) @@ -396,6 +395,7 @@ async def handle_message( running_tasks: RunningTasks | None = None, on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None = None, + progress_ref: MessageRef | None = None, clock: Callable[[], float] = time.monotonic, ) -> None: logger.info( @@ -422,6 +422,7 @@ async def handle_message( reply_to=user_ref, label="starting", tracker=progress_tracker, + progress_ref=progress_ref, resume_formatter=runner.format_resume, context_line=context_line, thread_id=incoming.thread_id, diff --git a/src/takopi/scheduler.py b/src/takopi/scheduler.py index 8daf3d0..86df778 100644 --- a/src/takopi/scheduler.py +++ b/src/takopi/scheduler.py @@ -10,7 +10,7 @@ import anyio from .context import RunContext from .logging import get_logger from .model import ResumeToken -from .transport import ChannelId, MessageId, ThreadId +from .transport import ChannelId, MessageId, MessageRef, ThreadId logger = get_logger(__name__) @@ -24,6 +24,7 @@ class ThreadJob: context: RunContext | None = None thread_id: ThreadId | None = None session_key: tuple[int, int | None] | None = None + progress_ref: MessageRef | None = None RunJob = Callable[[ThreadJob], Awaitable[None]] @@ -41,6 +42,7 @@ class ThreadScheduler: self._run_job = run_job self._lock = anyio.Lock() self._pending_by_thread: dict[str, deque[ThreadJob]] = {} + self._queued_by_progress: dict[tuple[ChannelId, MessageId], ThreadJob] = {} self._active_threads: set[str] = set() self._busy_until: dict[str, anyio.Event] = {} @@ -64,6 +66,9 @@ class ThreadScheduler: queue = deque() self._pending_by_thread[key] = queue queue.append(job) + if job.progress_ref is not None: + progress_key = (job.chat_id, job.progress_ref.message_id) + self._queued_by_progress[progress_key] = job if key in self._active_threads: return self._active_threads.add(key) @@ -78,6 +83,7 @@ class ThreadScheduler: context: RunContext | None = None, thread_id: ThreadId | None = None, session_key: tuple[int, int | None] | None = None, + progress_ref: MessageRef | None = None, ) -> None: await self.enqueue( ThreadJob( @@ -88,9 +94,30 @@ class ThreadScheduler: context=context, thread_id=thread_id, session_key=session_key, + progress_ref=progress_ref, ) ) + async def cancel_queued( + self, chat_id: ChannelId, progress_msg_id: MessageId + ) -> ThreadJob | None: + progress_key = (chat_id, progress_msg_id) + async with self._lock: + job = self._queued_by_progress.pop(progress_key, None) + if job is None: + return None + thread_key = self.thread_key(job.resume_token) + queue = self._pending_by_thread.get(thread_key) + if queue is None: + return None + try: + queue.remove(job) + except ValueError: + return None + if not queue: + self._pending_by_thread.pop(thread_key, None) + return job + async def _clear_busy(self, key: str, done: anyio.Event) -> None: await done.wait() async with self._lock: @@ -108,6 +135,9 @@ class ThreadScheduler: self._active_threads.discard(key) return job = queue.popleft() + if job.progress_ref is not None: + progress_key = (job.chat_id, job.progress_ref.message_id) + self._queued_by_progress.pop(progress_key, None) if done is not None and not done.is_set(): await done.wait() diff --git a/src/takopi/telegram/bridge.py b/src/takopi/telegram/bridge.py index d9d7e22..8ff89f7 100644 --- a/src/takopi/telegram/bridge.py +++ b/src/takopi/telegram/bridge.py @@ -12,6 +12,7 @@ from ..transport import MessageRef, RenderedMessage, SendOptions, Transport from ..transport_runtime import TransportRuntime from ..context import RunContext from ..model import ResumeToken +from ..scheduler import ThreadScheduler from ..settings import ( TelegramFilesSettings, TelegramTopicsSettings, @@ -326,20 +327,22 @@ async def handle_cancel( cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, running_tasks: RunningTasks, + scheduler: ThreadScheduler | None = None, ) -> None: from .commands import handle_cancel as _handle_cancel - await _handle_cancel(cfg, msg, running_tasks) + await _handle_cancel(cfg, msg, running_tasks, scheduler) async def handle_callback_cancel( cfg: TelegramBridgeConfig, query: TelegramCallbackQuery, running_tasks: RunningTasks, + scheduler: ThreadScheduler | None = None, ) -> None: from .commands import handle_callback_cancel as _handle_callback_cancel - await _handle_callback_cancel(cfg, query, running_tasks) + await _handle_callback_cancel(cfg, query, running_tasks, scheduler) async def send_with_resume( @@ -353,6 +356,7 @@ async def send_with_resume( RunContext | None, int | None, tuple[int, int | None] | None, + MessageRef | None, ], Awaitable[None], ], diff --git a/src/takopi/telegram/commands/cancel.py b/src/takopi/telegram/commands/cancel.py index cb0959e..e8b04ef 100644 --- a/src/takopi/telegram/commands/cancel.py +++ b/src/takopi/telegram/commands/cancel.py @@ -3,7 +3,9 @@ from __future__ import annotations from typing import TYPE_CHECKING from ...logging import get_logger +from ...progress import ProgressTracker from ...runner_bridge import RunningTasks +from ...scheduler import ThreadJob, ThreadScheduler from ...transport import MessageRef from ..types import TelegramCallbackQuery, TelegramIncomingMessage from .reply import make_reply @@ -18,6 +20,7 @@ async def handle_cancel( cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, running_tasks: RunningTasks, + scheduler: ThreadScheduler | None = None, ) -> None: reply = make_reply(cfg, msg) chat_id = msg.chat_id @@ -33,6 +36,17 @@ async def handle_cancel( progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id) running_task = running_tasks.get(progress_ref) if running_task is None: + if scheduler is not None: + job = await scheduler.cancel_queued(chat_id, reply_id) + if job is not None: + logger.info( + "cancel.queued", + chat_id=chat_id, + progress_message_id=reply_id, + resume=job.resume_token.value, + ) + await _edit_cancelled_message(cfg, progress_ref, job) + return await reply(text="nothing is currently running for that message.") return @@ -48,10 +62,26 @@ async def handle_callback_cancel( cfg: TelegramBridgeConfig, query: TelegramCallbackQuery, running_tasks: RunningTasks, + scheduler: ThreadScheduler | None = None, ) -> None: progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id) running_task = running_tasks.get(progress_ref) if running_task is None: + if scheduler is not None: + job = await scheduler.cancel_queued(query.chat_id, query.message_id) + if job is not None: + logger.info( + "cancel.queued", + chat_id=query.chat_id, + progress_message_id=query.message_id, + resume=job.resume_token.value, + ) + await _edit_cancelled_message(cfg, progress_ref, job) + await cfg.bot.answer_callback_query( + callback_query_id=query.callback_query_id, + text="dropped from queue.", + ) + return await cfg.bot.answer_callback_query( callback_query_id=query.callback_query_id, text="nothing is currently running for that message.", @@ -67,3 +97,20 @@ async def handle_callback_cancel( callback_query_id=query.callback_query_id, text="cancelling...", ) + + +async def _edit_cancelled_message( + cfg: TelegramBridgeConfig, + progress_ref: MessageRef, + job: ThreadJob, +) -> None: + tracker = ProgressTracker(engine=job.resume_token.engine) + tracker.set_resume(job.resume_token) + context_line = cfg.runtime.format_context_line(job.context) + state = tracker.snapshot(context_line=context_line) + message = cfg.exec_cfg.presenter.render_progress( + state, + elapsed_s=0.0, + label="`cancelled`", + ) + await cfg.exec_cfg.transport.edit(ref=progress_ref, message=message) diff --git a/src/takopi/telegram/commands/executor.py b/src/takopi/telegram/commands/executor.py index 842ac2b..6009313 100644 --- a/src/takopi/telegram/commands/executor.py +++ b/src/takopi/telegram/commands/executor.py @@ -107,6 +107,7 @@ async def _run_engine( engine_override: EngineId | None = None, thread_id: int | None = None, show_resume_line: bool = True, + progress_ref: MessageRef | None = None, ) -> None: reply = partial( send_plain, @@ -176,6 +177,7 @@ async def _run_engine( strip_resume_line=runtime.is_resume_line, running_tasks=running_tasks, on_thread_known=on_thread_known, + progress_ref=progress_ref, ) finally: reset_run_base_dir(run_base_token) diff --git a/src/takopi/telegram/loop.py b/src/takopi/telegram/loop.py index 2040495..5c00aed 100644 --- a/src/takopi/telegram/loop.py +++ b/src/takopi/telegram/loop.py @@ -15,8 +15,9 @@ from ..directives import DirectiveError from ..logging import get_logger from ..model import EngineId, ResumeToken from ..scheduler import ThreadJob, ThreadScheduler +from ..progress import ProgressTracker from ..settings import TelegramTransportSettings -from ..transport import MessageRef +from ..transport import MessageRef, SendOptions from ..transport_runtime import ResolvedMessage from ..context import RunContext from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain @@ -257,6 +258,36 @@ async def _wait_for_resume(running_task) -> ResumeToken | None: return resume +async def _send_queued_progress( + cfg: TelegramBridgeConfig, + *, + chat_id: int, + user_msg_id: int, + thread_id: int | None, + resume_token: ResumeToken, + context: RunContext | None, +) -> MessageRef | None: + tracker = ProgressTracker(engine=resume_token.engine) + tracker.set_resume(resume_token) + context_line = cfg.runtime.format_context_line(context) + state = tracker.snapshot(context_line=context_line) + message = cfg.exec_cfg.presenter.render_progress( + state, + elapsed_s=0.0, + label="queued", + ) + reply_ref = MessageRef( + channel_id=chat_id, + message_id=user_msg_id, + thread_id=thread_id, + ) + return await cfg.exec_cfg.transport.send( + channel_id=chat_id, + message=message, + options=SendOptions(reply_to=reply_ref, notify=False, thread_id=thread_id), + ) + + async def send_with_resume( cfg: TelegramBridgeConfig, enqueue: Callable[ @@ -268,6 +299,7 @@ async def send_with_resume( RunContext | None, int | None, tuple[int, int | None] | None, + MessageRef | None, ], Awaitable[None], ], @@ -292,6 +324,14 @@ async def send_with_resume( notify=False, ) return + progress_ref = await _send_queued_progress( + cfg, + chat_id=chat_id, + user_msg_id=user_msg_id, + thread_id=thread_id, + resume_token=resume, + context=running_task.context, + ) await enqueue( chat_id, user_msg_id, @@ -300,6 +340,7 @@ async def send_with_resume( running_task.context, thread_id, session_key, + progress_ref, ) @@ -459,6 +500,7 @@ async def run_main_loop( on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None = None, engine_override: EngineId | None = None, + progress_ref: MessageRef | None = None, ) -> None: topic_key = ( (chat_id, thread_id) @@ -491,6 +533,7 @@ async def run_main_loop( engine_override=engine_override, thread_id=thread_id, show_resume_line=show_resume_line, + progress_ref=progress_ref, ) async def run_thread_job(job: ThreadJob) -> None: @@ -504,6 +547,8 @@ async def run_main_loop( job.session_key, None, scheduler.note_thread_known, + None, + job.progress_ref, ) scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job) @@ -673,6 +718,14 @@ async def run_main_loop( engine_override, ) return + progress_ref = await _send_queued_progress( + cfg, + chat_id=chat_id, + user_msg_id=user_msg_id, + thread_id=msg.thread_id, + resume_token=resume_token, + context=context, + ) await scheduler.enqueue_resume( chat_id, user_msg_id, @@ -681,6 +734,7 @@ async def run_main_loop( context, msg.thread_id, chat_session_key, + progress_ref, ) async def handle_prompt_upload( @@ -735,7 +789,9 @@ async def run_main_loop( async for msg in poller(cfg): if isinstance(msg, TelegramCallbackQuery): if msg.data == CANCEL_CALLBACK_DATA: - tg.start_soon(handle_callback_cancel, cfg, msg, running_tasks) + tg.start_soon( + handle_callback_cancel, cfg, msg, running_tasks, scheduler + ) else: tg.start_soon( cfg.bot.answer_callback_query, @@ -798,7 +854,7 @@ async def run_main_loop( continue if is_cancel_command(text): - tg.start_soon(handle_cancel, cfg, msg, running_tasks) + tg.start_soon(handle_cancel, cfg, msg, running_tasks, scheduler) continue command_id, args_text = _parse_slash_command(text) @@ -1017,6 +1073,14 @@ async def run_main_loop( engine_override, ) else: + progress_ref = await _send_queued_progress( + cfg, + chat_id=chat_id, + user_msg_id=user_msg_id, + thread_id=msg.thread_id, + resume_token=resume_token, + context=context, + ) await scheduler.enqueue_resume( chat_id, user_msg_id, @@ -1025,6 +1089,7 @@ async def run_main_loop( context, msg.thread_id, chat_session_key, + progress_ref, ) finally: await cfg.exec_cfg.transport.close() diff --git a/tests/test_telegram_bridge.py b/tests/test_telegram_bridge.py index 6b720d1..116bbc1 100644 --- a/tests/test_telegram_bridge.py +++ b/tests/test_telegram_bridge.py @@ -44,6 +44,7 @@ from takopi.markdown import MarkdownPresenter from takopi.model import ResumeToken from takopi.progress import ProgressTracker from takopi.router import AutoRouter, RunnerEntry +from takopi.scheduler import ThreadScheduler from takopi.transport_runtime import TransportRuntime from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait from takopi.telegram.types import ( @@ -68,6 +69,12 @@ def _make_router(runner) -> AutoRouter: ) +class _NoopTaskGroup: + def start_soon(self, func, *args: Any) -> None: + _ = func, args + return None + + class _FakeTransport: def __init__(self, progress_ready: anyio.Event | None = None) -> None: self._next_id = 1 @@ -849,6 +856,42 @@ async def test_handle_cancel_only_cancels_matching_progress_message() -> None: assert len(transport.send_calls) == 0 +@pytest.mark.anyio +async def test_handle_cancel_cancels_queued_job() -> None: + transport = _FakeTransport() + cfg = _make_cfg(transport) + + async def _noop_run_job(_) -> None: + return None + + scheduler = ThreadScheduler(task_group=_NoopTaskGroup(), run_job=_noop_run_job) + progress_id = 55 + progress_ref = MessageRef(channel_id=123, message_id=progress_id) + resume = ResumeToken(engine=CODEX_ENGINE, value="sid") + await scheduler.enqueue_resume( + chat_id=123, + user_msg_id=10, + text="queued", + resume_token=resume, + progress_ref=progress_ref, + ) + msg = TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=10, + text="/cancel", + reply_to_message_id=progress_id, + reply_to_text=None, + sender_id=123, + ) + + await handle_cancel(cfg, msg, {}, scheduler) + + assert transport.edit_calls + assert "cancelled" in transport.edit_calls[0]["message"].text.lower() + assert await scheduler.cancel_queued(123, progress_ref.message_id) is None + + @pytest.mark.anyio async def test_handle_file_put_writes_file(tmp_path: Path) -> None: payload = b"hello" @@ -998,6 +1041,43 @@ async def test_handle_callback_cancel_cancels_running_task() -> None: assert bot.callback_calls[-1]["text"] == "cancelling..." +@pytest.mark.anyio +async def test_handle_callback_cancel_cancels_queued_job() -> None: + transport = _FakeTransport() + cfg = _make_cfg(transport) + + async def _noop_run_job(_) -> None: + return None + + scheduler = ThreadScheduler(task_group=_NoopTaskGroup(), run_job=_noop_run_job) + progress_id = 77 + progress_ref = MessageRef(channel_id=123, message_id=progress_id) + resume = ResumeToken(engine=CODEX_ENGINE, value="sid") + await scheduler.enqueue_resume( + chat_id=123, + user_msg_id=10, + text="queued", + resume_token=resume, + progress_ref=progress_ref, + ) + query = TelegramCallbackQuery( + transport="telegram", + chat_id=123, + message_id=progress_id, + callback_query_id="cbq-queued", + data="takopi:cancel", + sender_id=123, + ) + + await handle_callback_cancel(cfg, query, {}, scheduler) + + assert transport.edit_calls + assert "cancelled" in transport.edit_calls[0]["message"].text.lower() + bot = cast(_FakeBot, cfg.bot) + assert bot.callback_calls + assert bot.callback_calls[-1]["text"] == "dropped from queue." + + @pytest.mark.anyio async def test_handle_callback_cancel_without_task_acknowledges() -> None: transport = _FakeTransport() @@ -1249,6 +1329,7 @@ async def test_send_with_resume_waits_for_token() -> None: RunContext | None, int | None, tuple[int, int | None] | None, + MessageRef | None, ] ] = [] @@ -1260,9 +1341,19 @@ async def test_send_with_resume_waits_for_token() -> None: context: RunContext | None, thread_id: int | None, session_key: tuple[int, int | None] | None, + progress_ref: MessageRef | None, ) -> None: sent.append( - (chat_id, user_msg_id, text, resume, context, thread_id, session_key) + ( + chat_id, + user_msg_id, + text, + resume, + context, + thread_id, + session_key, + progress_ref, + ) ) running_task = RunningTask() @@ -1285,18 +1376,19 @@ async def test_send_with_resume_waits_for_token() -> None: "hello", ) - assert sent == [ - ( - 123, - 10, - "hello", - ResumeToken(engine=CODEX_ENGINE, value="abc123"), - None, - None, - None, - ) - ] - assert transport.send_calls == [] + assert len(sent) == 1 + assert sent[0][:7] == ( + 123, + 10, + "hello", + ResumeToken(engine=CODEX_ENGINE, value="abc123"), + None, + None, + None, + ) + assert sent[0][7] == transport.send_calls[0]["ref"] + assert transport.send_calls + assert "queued" in transport.send_calls[0]["message"].text.lower() @pytest.mark.anyio @@ -1312,6 +1404,7 @@ async def test_send_with_resume_reports_when_missing() -> None: RunContext | None, int | None, tuple[int, int | None] | None, + MessageRef | None, ] ] = [] @@ -1323,9 +1416,19 @@ async def test_send_with_resume_reports_when_missing() -> None: context: RunContext | None, thread_id: int | None, session_key: tuple[int, int | None] | None, + progress_ref: MessageRef | None, ) -> None: sent.append( - (chat_id, user_msg_id, text, resume, context, thread_id, session_key) + ( + chat_id, + user_msg_id, + text, + resume, + context, + thread_id, + session_key, + progress_ref, + ) ) running_task = RunningTask()