feat(telegram): add queued cancel placeholder (#136)

This commit is contained in:
banteg
2026-01-15 01:31:47 +04:00
committed by GitHub
parent 699ae3b38e
commit ff64741607
7 changed files with 285 additions and 33 deletions
+14 -13
View File
@@ -244,11 +244,11 @@ async def send_initial_progress(
reply_to: MessageRef,
label: str,
tracker: ProgressTracker,
progress_ref: MessageRef | None = None,
resume_formatter: Callable[[ResumeToken], str] | None = None,
context_line: str | None = None,
thread_id: ThreadId | None = None,
) -> ProgressMessageState:
progress_ref: MessageRef | None = None
last_rendered: RenderedMessage | None = None
state = tracker.snapshot(
@@ -260,27 +260,26 @@ async def send_initial_progress(
elapsed_s=0.0,
label=label,
)
logger.debug(
"transport.send_message",
channel_id=channel_id,
reply_to_message_id=reply_to.message_id,
rendered=initial_rendered.text,
)
progress_ref = await cfg.transport.send(
sent_ref, _ = await _send_or_edit_message(
cfg.transport,
channel_id=channel_id,
message=initial_rendered,
options=SendOptions(reply_to=reply_to, notify=False, thread_id=thread_id),
edit_ref=progress_ref,
reply_to=reply_to,
notify=False,
replace_ref=progress_ref,
thread_id=thread_id,
)
if progress_ref is not None:
if sent_ref is not None:
last_rendered = initial_rendered
logger.debug(
"progress.sent",
channel_id=progress_ref.channel_id,
message_id=progress_ref.message_id,
channel_id=sent_ref.channel_id,
message_id=sent_ref.message_id,
)
return ProgressMessageState(
ref=progress_ref,
ref=sent_ref,
last_rendered=last_rendered,
)
@@ -396,6 +395,7 @@ async def handle_message(
running_tasks: RunningTasks | None = None,
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
| None = None,
progress_ref: MessageRef | None = None,
clock: Callable[[], float] = time.monotonic,
) -> None:
logger.info(
@@ -422,6 +422,7 @@ async def handle_message(
reply_to=user_ref,
label="starting",
tracker=progress_tracker,
progress_ref=progress_ref,
resume_formatter=runner.format_resume,
context_line=context_line,
thread_id=incoming.thread_id,
+31 -1
View File
@@ -10,7 +10,7 @@ import anyio
from .context import RunContext
from .logging import get_logger
from .model import ResumeToken
from .transport import ChannelId, MessageId, ThreadId
from .transport import ChannelId, MessageId, MessageRef, ThreadId
logger = get_logger(__name__)
@@ -24,6 +24,7 @@ class ThreadJob:
context: RunContext | None = None
thread_id: ThreadId | None = None
session_key: tuple[int, int | None] | None = None
progress_ref: MessageRef | None = None
RunJob = Callable[[ThreadJob], Awaitable[None]]
@@ -41,6 +42,7 @@ class ThreadScheduler:
self._run_job = run_job
self._lock = anyio.Lock()
self._pending_by_thread: dict[str, deque[ThreadJob]] = {}
self._queued_by_progress: dict[tuple[ChannelId, MessageId], ThreadJob] = {}
self._active_threads: set[str] = set()
self._busy_until: dict[str, anyio.Event] = {}
@@ -64,6 +66,9 @@ class ThreadScheduler:
queue = deque()
self._pending_by_thread[key] = queue
queue.append(job)
if job.progress_ref is not None:
progress_key = (job.chat_id, job.progress_ref.message_id)
self._queued_by_progress[progress_key] = job
if key in self._active_threads:
return
self._active_threads.add(key)
@@ -78,6 +83,7 @@ class ThreadScheduler:
context: RunContext | None = None,
thread_id: ThreadId | None = None,
session_key: tuple[int, int | None] | None = None,
progress_ref: MessageRef | None = None,
) -> None:
await self.enqueue(
ThreadJob(
@@ -88,9 +94,30 @@ class ThreadScheduler:
context=context,
thread_id=thread_id,
session_key=session_key,
progress_ref=progress_ref,
)
)
async def cancel_queued(
self, chat_id: ChannelId, progress_msg_id: MessageId
) -> ThreadJob | None:
progress_key = (chat_id, progress_msg_id)
async with self._lock:
job = self._queued_by_progress.pop(progress_key, None)
if job is None:
return None
thread_key = self.thread_key(job.resume_token)
queue = self._pending_by_thread.get(thread_key)
if queue is None:
return None
try:
queue.remove(job)
except ValueError:
return None
if not queue:
self._pending_by_thread.pop(thread_key, None)
return job
async def _clear_busy(self, key: str, done: anyio.Event) -> None:
await done.wait()
async with self._lock:
@@ -108,6 +135,9 @@ class ThreadScheduler:
self._active_threads.discard(key)
return
job = queue.popleft()
if job.progress_ref is not None:
progress_key = (job.chat_id, job.progress_ref.message_id)
self._queued_by_progress.pop(progress_key, None)
if done is not None and not done.is_set():
await done.wait()
+6 -2
View File
@@ -12,6 +12,7 @@ from ..transport import MessageRef, RenderedMessage, SendOptions, Transport
from ..transport_runtime import TransportRuntime
from ..context import RunContext
from ..model import ResumeToken
from ..scheduler import ThreadScheduler
from ..settings import (
TelegramFilesSettings,
TelegramTopicsSettings,
@@ -326,20 +327,22 @@ async def handle_cancel(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
running_tasks: RunningTasks,
scheduler: ThreadScheduler | None = None,
) -> None:
from .commands import handle_cancel as _handle_cancel
await _handle_cancel(cfg, msg, running_tasks)
await _handle_cancel(cfg, msg, running_tasks, scheduler)
async def handle_callback_cancel(
cfg: TelegramBridgeConfig,
query: TelegramCallbackQuery,
running_tasks: RunningTasks,
scheduler: ThreadScheduler | None = None,
) -> None:
from .commands import handle_callback_cancel as _handle_callback_cancel
await _handle_callback_cancel(cfg, query, running_tasks)
await _handle_callback_cancel(cfg, query, running_tasks, scheduler)
async def send_with_resume(
@@ -353,6 +356,7 @@ async def send_with_resume(
RunContext | None,
int | None,
tuple[int, int | None] | None,
MessageRef | None,
],
Awaitable[None],
],
+47
View File
@@ -3,7 +3,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from ...logging import get_logger
from ...progress import ProgressTracker
from ...runner_bridge import RunningTasks
from ...scheduler import ThreadJob, ThreadScheduler
from ...transport import MessageRef
from ..types import TelegramCallbackQuery, TelegramIncomingMessage
from .reply import make_reply
@@ -18,6 +20,7 @@ async def handle_cancel(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
running_tasks: RunningTasks,
scheduler: ThreadScheduler | None = None,
) -> None:
reply = make_reply(cfg, msg)
chat_id = msg.chat_id
@@ -33,6 +36,17 @@ async def handle_cancel(
progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id)
running_task = running_tasks.get(progress_ref)
if running_task is None:
if scheduler is not None:
job = await scheduler.cancel_queued(chat_id, reply_id)
if job is not None:
logger.info(
"cancel.queued",
chat_id=chat_id,
progress_message_id=reply_id,
resume=job.resume_token.value,
)
await _edit_cancelled_message(cfg, progress_ref, job)
return
await reply(text="nothing is currently running for that message.")
return
@@ -48,10 +62,26 @@ async def handle_callback_cancel(
cfg: TelegramBridgeConfig,
query: TelegramCallbackQuery,
running_tasks: RunningTasks,
scheduler: ThreadScheduler | None = None,
) -> None:
progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id)
running_task = running_tasks.get(progress_ref)
if running_task is None:
if scheduler is not None:
job = await scheduler.cancel_queued(query.chat_id, query.message_id)
if job is not None:
logger.info(
"cancel.queued",
chat_id=query.chat_id,
progress_message_id=query.message_id,
resume=job.resume_token.value,
)
await _edit_cancelled_message(cfg, progress_ref, job)
await cfg.bot.answer_callback_query(
callback_query_id=query.callback_query_id,
text="dropped from queue.",
)
return
await cfg.bot.answer_callback_query(
callback_query_id=query.callback_query_id,
text="nothing is currently running for that message.",
@@ -67,3 +97,20 @@ async def handle_callback_cancel(
callback_query_id=query.callback_query_id,
text="cancelling...",
)
async def _edit_cancelled_message(
cfg: TelegramBridgeConfig,
progress_ref: MessageRef,
job: ThreadJob,
) -> None:
tracker = ProgressTracker(engine=job.resume_token.engine)
tracker.set_resume(job.resume_token)
context_line = cfg.runtime.format_context_line(job.context)
state = tracker.snapshot(context_line=context_line)
message = cfg.exec_cfg.presenter.render_progress(
state,
elapsed_s=0.0,
label="`cancelled`",
)
await cfg.exec_cfg.transport.edit(ref=progress_ref, message=message)
+2
View File
@@ -107,6 +107,7 @@ async def _run_engine(
engine_override: EngineId | None = None,
thread_id: int | None = None,
show_resume_line: bool = True,
progress_ref: MessageRef | None = None,
) -> None:
reply = partial(
send_plain,
@@ -176,6 +177,7 @@ async def _run_engine(
strip_resume_line=runtime.is_resume_line,
running_tasks=running_tasks,
on_thread_known=on_thread_known,
progress_ref=progress_ref,
)
finally:
reset_run_base_dir(run_base_token)
+68 -3
View File
@@ -15,8 +15,9 @@ from ..directives import DirectiveError
from ..logging import get_logger
from ..model import EngineId, ResumeToken
from ..scheduler import ThreadJob, ThreadScheduler
from ..progress import ProgressTracker
from ..settings import TelegramTransportSettings
from ..transport import MessageRef
from ..transport import MessageRef, SendOptions
from ..transport_runtime import ResolvedMessage
from ..context import RunContext
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
@@ -257,6 +258,36 @@ async def _wait_for_resume(running_task) -> ResumeToken | None:
return resume
async def _send_queued_progress(
cfg: TelegramBridgeConfig,
*,
chat_id: int,
user_msg_id: int,
thread_id: int | None,
resume_token: ResumeToken,
context: RunContext | None,
) -> MessageRef | None:
tracker = ProgressTracker(engine=resume_token.engine)
tracker.set_resume(resume_token)
context_line = cfg.runtime.format_context_line(context)
state = tracker.snapshot(context_line=context_line)
message = cfg.exec_cfg.presenter.render_progress(
state,
elapsed_s=0.0,
label="queued",
)
reply_ref = MessageRef(
channel_id=chat_id,
message_id=user_msg_id,
thread_id=thread_id,
)
return await cfg.exec_cfg.transport.send(
channel_id=chat_id,
message=message,
options=SendOptions(reply_to=reply_ref, notify=False, thread_id=thread_id),
)
async def send_with_resume(
cfg: TelegramBridgeConfig,
enqueue: Callable[
@@ -268,6 +299,7 @@ async def send_with_resume(
RunContext | None,
int | None,
tuple[int, int | None] | None,
MessageRef | None,
],
Awaitable[None],
],
@@ -292,6 +324,14 @@ async def send_with_resume(
notify=False,
)
return
progress_ref = await _send_queued_progress(
cfg,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=thread_id,
resume_token=resume,
context=running_task.context,
)
await enqueue(
chat_id,
user_msg_id,
@@ -300,6 +340,7 @@ async def send_with_resume(
running_task.context,
thread_id,
session_key,
progress_ref,
)
@@ -459,6 +500,7 @@ async def run_main_loop(
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
| None = None,
engine_override: EngineId | None = None,
progress_ref: MessageRef | None = None,
) -> None:
topic_key = (
(chat_id, thread_id)
@@ -491,6 +533,7 @@ async def run_main_loop(
engine_override=engine_override,
thread_id=thread_id,
show_resume_line=show_resume_line,
progress_ref=progress_ref,
)
async def run_thread_job(job: ThreadJob) -> None:
@@ -504,6 +547,8 @@ async def run_main_loop(
job.session_key,
None,
scheduler.note_thread_known,
None,
job.progress_ref,
)
scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
@@ -673,6 +718,14 @@ async def run_main_loop(
engine_override,
)
return
progress_ref = await _send_queued_progress(
cfg,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=msg.thread_id,
resume_token=resume_token,
context=context,
)
await scheduler.enqueue_resume(
chat_id,
user_msg_id,
@@ -681,6 +734,7 @@ async def run_main_loop(
context,
msg.thread_id,
chat_session_key,
progress_ref,
)
async def handle_prompt_upload(
@@ -735,7 +789,9 @@ async def run_main_loop(
async for msg in poller(cfg):
if isinstance(msg, TelegramCallbackQuery):
if msg.data == CANCEL_CALLBACK_DATA:
tg.start_soon(handle_callback_cancel, cfg, msg, running_tasks)
tg.start_soon(
handle_callback_cancel, cfg, msg, running_tasks, scheduler
)
else:
tg.start_soon(
cfg.bot.answer_callback_query,
@@ -798,7 +854,7 @@ async def run_main_loop(
continue
if is_cancel_command(text):
tg.start_soon(handle_cancel, cfg, msg, running_tasks)
tg.start_soon(handle_cancel, cfg, msg, running_tasks, scheduler)
continue
command_id, args_text = _parse_slash_command(text)
@@ -1017,6 +1073,14 @@ async def run_main_loop(
engine_override,
)
else:
progress_ref = await _send_queued_progress(
cfg,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=msg.thread_id,
resume_token=resume_token,
context=context,
)
await scheduler.enqueue_resume(
chat_id,
user_msg_id,
@@ -1025,6 +1089,7 @@ async def run_main_loop(
context,
msg.thread_id,
chat_session_key,
progress_ref,
)
finally:
await cfg.exec_cfg.transport.close()
+109 -6
View File
@@ -44,6 +44,7 @@ from takopi.markdown import MarkdownPresenter
from takopi.model import ResumeToken
from takopi.progress import ProgressTracker
from takopi.router import AutoRouter, RunnerEntry
from takopi.scheduler import ThreadScheduler
from takopi.transport_runtime import TransportRuntime
from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait
from takopi.telegram.types import (
@@ -68,6 +69,12 @@ def _make_router(runner) -> AutoRouter:
)
class _NoopTaskGroup:
def start_soon(self, func, *args: Any) -> None:
_ = func, args
return None
class _FakeTransport:
def __init__(self, progress_ready: anyio.Event | None = None) -> None:
self._next_id = 1
@@ -849,6 +856,42 @@ async def test_handle_cancel_only_cancels_matching_progress_message() -> None:
assert len(transport.send_calls) == 0
@pytest.mark.anyio
async def test_handle_cancel_cancels_queued_job() -> None:
transport = _FakeTransport()
cfg = _make_cfg(transport)
async def _noop_run_job(_) -> None:
return None
scheduler = ThreadScheduler(task_group=_NoopTaskGroup(), run_job=_noop_run_job)
progress_id = 55
progress_ref = MessageRef(channel_id=123, message_id=progress_id)
resume = ResumeToken(engine=CODEX_ENGINE, value="sid")
await scheduler.enqueue_resume(
chat_id=123,
user_msg_id=10,
text="queued",
resume_token=resume,
progress_ref=progress_ref,
)
msg = TelegramIncomingMessage(
transport="telegram",
chat_id=123,
message_id=10,
text="/cancel",
reply_to_message_id=progress_id,
reply_to_text=None,
sender_id=123,
)
await handle_cancel(cfg, msg, {}, scheduler)
assert transport.edit_calls
assert "cancelled" in transport.edit_calls[0]["message"].text.lower()
assert await scheduler.cancel_queued(123, progress_ref.message_id) is None
@pytest.mark.anyio
async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
payload = b"hello"
@@ -998,6 +1041,43 @@ async def test_handle_callback_cancel_cancels_running_task() -> None:
assert bot.callback_calls[-1]["text"] == "cancelling..."
@pytest.mark.anyio
async def test_handle_callback_cancel_cancels_queued_job() -> None:
transport = _FakeTransport()
cfg = _make_cfg(transport)
async def _noop_run_job(_) -> None:
return None
scheduler = ThreadScheduler(task_group=_NoopTaskGroup(), run_job=_noop_run_job)
progress_id = 77
progress_ref = MessageRef(channel_id=123, message_id=progress_id)
resume = ResumeToken(engine=CODEX_ENGINE, value="sid")
await scheduler.enqueue_resume(
chat_id=123,
user_msg_id=10,
text="queued",
resume_token=resume,
progress_ref=progress_ref,
)
query = TelegramCallbackQuery(
transport="telegram",
chat_id=123,
message_id=progress_id,
callback_query_id="cbq-queued",
data="takopi:cancel",
sender_id=123,
)
await handle_callback_cancel(cfg, query, {}, scheduler)
assert transport.edit_calls
assert "cancelled" in transport.edit_calls[0]["message"].text.lower()
bot = cast(_FakeBot, cfg.bot)
assert bot.callback_calls
assert bot.callback_calls[-1]["text"] == "dropped from queue."
@pytest.mark.anyio
async def test_handle_callback_cancel_without_task_acknowledges() -> None:
transport = _FakeTransport()
@@ -1249,6 +1329,7 @@ async def test_send_with_resume_waits_for_token() -> None:
RunContext | None,
int | None,
tuple[int, int | None] | None,
MessageRef | None,
]
] = []
@@ -1260,9 +1341,19 @@ async def test_send_with_resume_waits_for_token() -> None:
context: RunContext | None,
thread_id: int | None,
session_key: tuple[int, int | None] | None,
progress_ref: MessageRef | None,
) -> None:
sent.append(
(chat_id, user_msg_id, text, resume, context, thread_id, session_key)
(
chat_id,
user_msg_id,
text,
resume,
context,
thread_id,
session_key,
progress_ref,
)
)
running_task = RunningTask()
@@ -1285,8 +1376,8 @@ async def test_send_with_resume_waits_for_token() -> None:
"hello",
)
assert sent == [
(
assert len(sent) == 1
assert sent[0][:7] == (
123,
10,
"hello",
@@ -1295,8 +1386,9 @@ async def test_send_with_resume_waits_for_token() -> None:
None,
None,
)
]
assert transport.send_calls == []
assert sent[0][7] == transport.send_calls[0]["ref"]
assert transport.send_calls
assert "queued" in transport.send_calls[0]["message"].text.lower()
@pytest.mark.anyio
@@ -1312,6 +1404,7 @@ async def test_send_with_resume_reports_when_missing() -> None:
RunContext | None,
int | None,
tuple[int, int | None] | None,
MessageRef | None,
]
] = []
@@ -1323,9 +1416,19 @@ async def test_send_with_resume_reports_when_missing() -> None:
context: RunContext | None,
thread_id: int | None,
session_key: tuple[int, int | None] | None,
progress_ref: MessageRef | None,
) -> None:
sent.append(
(chat_id, user_msg_id, text, resume, context, thread_id, session_key)
(
chat_id,
user_msg_id,
text,
resume,
context,
thread_id,
session_key,
progress_ref,
)
)
running_task = RunningTask()