feat(telegram): add queued cancel placeholder (#136)

This commit is contained in:
banteg
2026-01-15 01:31:47 +04:00
committed by GitHub
parent 699ae3b38e
commit ff64741607
7 changed files with 285 additions and 33 deletions
+14 -13
View File
@@ -244,11 +244,11 @@ async def send_initial_progress(
reply_to: MessageRef, reply_to: MessageRef,
label: str, label: str,
tracker: ProgressTracker, tracker: ProgressTracker,
progress_ref: MessageRef | None = None,
resume_formatter: Callable[[ResumeToken], str] | None = None, resume_formatter: Callable[[ResumeToken], str] | None = None,
context_line: str | None = None, context_line: str | None = None,
thread_id: ThreadId | None = None, thread_id: ThreadId | None = None,
) -> ProgressMessageState: ) -> ProgressMessageState:
progress_ref: MessageRef | None = None
last_rendered: RenderedMessage | None = None last_rendered: RenderedMessage | None = None
state = tracker.snapshot( state = tracker.snapshot(
@@ -260,27 +260,26 @@ async def send_initial_progress(
elapsed_s=0.0, elapsed_s=0.0,
label=label, label=label,
) )
logger.debug( sent_ref, _ = await _send_or_edit_message(
"transport.send_message", cfg.transport,
channel_id=channel_id,
reply_to_message_id=reply_to.message_id,
rendered=initial_rendered.text,
)
progress_ref = await cfg.transport.send(
channel_id=channel_id, channel_id=channel_id,
message=initial_rendered, message=initial_rendered,
options=SendOptions(reply_to=reply_to, notify=False, thread_id=thread_id), edit_ref=progress_ref,
reply_to=reply_to,
notify=False,
replace_ref=progress_ref,
thread_id=thread_id,
) )
if progress_ref is not None: if sent_ref is not None:
last_rendered = initial_rendered last_rendered = initial_rendered
logger.debug( logger.debug(
"progress.sent", "progress.sent",
channel_id=progress_ref.channel_id, channel_id=sent_ref.channel_id,
message_id=progress_ref.message_id, message_id=sent_ref.message_id,
) )
return ProgressMessageState( return ProgressMessageState(
ref=progress_ref, ref=sent_ref,
last_rendered=last_rendered, last_rendered=last_rendered,
) )
@@ -396,6 +395,7 @@ async def handle_message(
running_tasks: RunningTasks | None = None, running_tasks: RunningTasks | None = None,
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
| None = None, | None = None,
progress_ref: MessageRef | None = None,
clock: Callable[[], float] = time.monotonic, clock: Callable[[], float] = time.monotonic,
) -> None: ) -> None:
logger.info( logger.info(
@@ -422,6 +422,7 @@ async def handle_message(
reply_to=user_ref, reply_to=user_ref,
label="starting", label="starting",
tracker=progress_tracker, tracker=progress_tracker,
progress_ref=progress_ref,
resume_formatter=runner.format_resume, resume_formatter=runner.format_resume,
context_line=context_line, context_line=context_line,
thread_id=incoming.thread_id, thread_id=incoming.thread_id,
+31 -1
View File
@@ -10,7 +10,7 @@ import anyio
from .context import RunContext from .context import RunContext
from .logging import get_logger from .logging import get_logger
from .model import ResumeToken from .model import ResumeToken
from .transport import ChannelId, MessageId, ThreadId from .transport import ChannelId, MessageId, MessageRef, ThreadId
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -24,6 +24,7 @@ class ThreadJob:
context: RunContext | None = None context: RunContext | None = None
thread_id: ThreadId | None = None thread_id: ThreadId | None = None
session_key: tuple[int, int | None] | None = None session_key: tuple[int, int | None] | None = None
progress_ref: MessageRef | None = None
RunJob = Callable[[ThreadJob], Awaitable[None]] RunJob = Callable[[ThreadJob], Awaitable[None]]
@@ -41,6 +42,7 @@ class ThreadScheduler:
self._run_job = run_job self._run_job = run_job
self._lock = anyio.Lock() self._lock = anyio.Lock()
self._pending_by_thread: dict[str, deque[ThreadJob]] = {} self._pending_by_thread: dict[str, deque[ThreadJob]] = {}
self._queued_by_progress: dict[tuple[ChannelId, MessageId], ThreadJob] = {}
self._active_threads: set[str] = set() self._active_threads: set[str] = set()
self._busy_until: dict[str, anyio.Event] = {} self._busy_until: dict[str, anyio.Event] = {}
@@ -64,6 +66,9 @@ class ThreadScheduler:
queue = deque() queue = deque()
self._pending_by_thread[key] = queue self._pending_by_thread[key] = queue
queue.append(job) queue.append(job)
if job.progress_ref is not None:
progress_key = (job.chat_id, job.progress_ref.message_id)
self._queued_by_progress[progress_key] = job
if key in self._active_threads: if key in self._active_threads:
return return
self._active_threads.add(key) self._active_threads.add(key)
@@ -78,6 +83,7 @@ class ThreadScheduler:
context: RunContext | None = None, context: RunContext | None = None,
thread_id: ThreadId | None = None, thread_id: ThreadId | None = None,
session_key: tuple[int, int | None] | None = None, session_key: tuple[int, int | None] | None = None,
progress_ref: MessageRef | None = None,
) -> None: ) -> None:
await self.enqueue( await self.enqueue(
ThreadJob( ThreadJob(
@@ -88,9 +94,30 @@ class ThreadScheduler:
context=context, context=context,
thread_id=thread_id, thread_id=thread_id,
session_key=session_key, session_key=session_key,
progress_ref=progress_ref,
) )
) )
async def cancel_queued(
self, chat_id: ChannelId, progress_msg_id: MessageId
) -> ThreadJob | None:
progress_key = (chat_id, progress_msg_id)
async with self._lock:
job = self._queued_by_progress.pop(progress_key, None)
if job is None:
return None
thread_key = self.thread_key(job.resume_token)
queue = self._pending_by_thread.get(thread_key)
if queue is None:
return None
try:
queue.remove(job)
except ValueError:
return None
if not queue:
self._pending_by_thread.pop(thread_key, None)
return job
async def _clear_busy(self, key: str, done: anyio.Event) -> None: async def _clear_busy(self, key: str, done: anyio.Event) -> None:
await done.wait() await done.wait()
async with self._lock: async with self._lock:
@@ -108,6 +135,9 @@ class ThreadScheduler:
self._active_threads.discard(key) self._active_threads.discard(key)
return return
job = queue.popleft() job = queue.popleft()
if job.progress_ref is not None:
progress_key = (job.chat_id, job.progress_ref.message_id)
self._queued_by_progress.pop(progress_key, None)
if done is not None and not done.is_set(): if done is not None and not done.is_set():
await done.wait() await done.wait()
+6 -2
View File
@@ -12,6 +12,7 @@ from ..transport import MessageRef, RenderedMessage, SendOptions, Transport
from ..transport_runtime import TransportRuntime from ..transport_runtime import TransportRuntime
from ..context import RunContext from ..context import RunContext
from ..model import ResumeToken from ..model import ResumeToken
from ..scheduler import ThreadScheduler
from ..settings import ( from ..settings import (
TelegramFilesSettings, TelegramFilesSettings,
TelegramTopicsSettings, TelegramTopicsSettings,
@@ -326,20 +327,22 @@ async def handle_cancel(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
running_tasks: RunningTasks, running_tasks: RunningTasks,
scheduler: ThreadScheduler | None = None,
) -> None: ) -> None:
from .commands import handle_cancel as _handle_cancel from .commands import handle_cancel as _handle_cancel
await _handle_cancel(cfg, msg, running_tasks) await _handle_cancel(cfg, msg, running_tasks, scheduler)
async def handle_callback_cancel( async def handle_callback_cancel(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
query: TelegramCallbackQuery, query: TelegramCallbackQuery,
running_tasks: RunningTasks, running_tasks: RunningTasks,
scheduler: ThreadScheduler | None = None,
) -> None: ) -> None:
from .commands import handle_callback_cancel as _handle_callback_cancel from .commands import handle_callback_cancel as _handle_callback_cancel
await _handle_callback_cancel(cfg, query, running_tasks) await _handle_callback_cancel(cfg, query, running_tasks, scheduler)
async def send_with_resume( async def send_with_resume(
@@ -353,6 +356,7 @@ async def send_with_resume(
RunContext | None, RunContext | None,
int | None, int | None,
tuple[int, int | None] | None, tuple[int, int | None] | None,
MessageRef | None,
], ],
Awaitable[None], Awaitable[None],
], ],
+47
View File
@@ -3,7 +3,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...logging import get_logger from ...logging import get_logger
from ...progress import ProgressTracker
from ...runner_bridge import RunningTasks from ...runner_bridge import RunningTasks
from ...scheduler import ThreadJob, ThreadScheduler
from ...transport import MessageRef from ...transport import MessageRef
from ..types import TelegramCallbackQuery, TelegramIncomingMessage from ..types import TelegramCallbackQuery, TelegramIncomingMessage
from .reply import make_reply from .reply import make_reply
@@ -18,6 +20,7 @@ async def handle_cancel(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
running_tasks: RunningTasks, running_tasks: RunningTasks,
scheduler: ThreadScheduler | None = None,
) -> None: ) -> None:
reply = make_reply(cfg, msg) reply = make_reply(cfg, msg)
chat_id = msg.chat_id chat_id = msg.chat_id
@@ -33,6 +36,17 @@ async def handle_cancel(
progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id) progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id)
running_task = running_tasks.get(progress_ref) running_task = running_tasks.get(progress_ref)
if running_task is None: if running_task is None:
if scheduler is not None:
job = await scheduler.cancel_queued(chat_id, reply_id)
if job is not None:
logger.info(
"cancel.queued",
chat_id=chat_id,
progress_message_id=reply_id,
resume=job.resume_token.value,
)
await _edit_cancelled_message(cfg, progress_ref, job)
return
await reply(text="nothing is currently running for that message.") await reply(text="nothing is currently running for that message.")
return return
@@ -48,10 +62,26 @@ async def handle_callback_cancel(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
query: TelegramCallbackQuery, query: TelegramCallbackQuery,
running_tasks: RunningTasks, running_tasks: RunningTasks,
scheduler: ThreadScheduler | None = None,
) -> None: ) -> None:
progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id) progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id)
running_task = running_tasks.get(progress_ref) running_task = running_tasks.get(progress_ref)
if running_task is None: if running_task is None:
if scheduler is not None:
job = await scheduler.cancel_queued(query.chat_id, query.message_id)
if job is not None:
logger.info(
"cancel.queued",
chat_id=query.chat_id,
progress_message_id=query.message_id,
resume=job.resume_token.value,
)
await _edit_cancelled_message(cfg, progress_ref, job)
await cfg.bot.answer_callback_query(
callback_query_id=query.callback_query_id,
text="dropped from queue.",
)
return
await cfg.bot.answer_callback_query( await cfg.bot.answer_callback_query(
callback_query_id=query.callback_query_id, callback_query_id=query.callback_query_id,
text="nothing is currently running for that message.", text="nothing is currently running for that message.",
@@ -67,3 +97,20 @@ async def handle_callback_cancel(
callback_query_id=query.callback_query_id, callback_query_id=query.callback_query_id,
text="cancelling...", text="cancelling...",
) )
async def _edit_cancelled_message(
cfg: TelegramBridgeConfig,
progress_ref: MessageRef,
job: ThreadJob,
) -> None:
tracker = ProgressTracker(engine=job.resume_token.engine)
tracker.set_resume(job.resume_token)
context_line = cfg.runtime.format_context_line(job.context)
state = tracker.snapshot(context_line=context_line)
message = cfg.exec_cfg.presenter.render_progress(
state,
elapsed_s=0.0,
label="`cancelled`",
)
await cfg.exec_cfg.transport.edit(ref=progress_ref, message=message)
+2
View File
@@ -107,6 +107,7 @@ async def _run_engine(
engine_override: EngineId | None = None, engine_override: EngineId | None = None,
thread_id: int | None = None, thread_id: int | None = None,
show_resume_line: bool = True, show_resume_line: bool = True,
progress_ref: MessageRef | None = None,
) -> None: ) -> None:
reply = partial( reply = partial(
send_plain, send_plain,
@@ -176,6 +177,7 @@ async def _run_engine(
strip_resume_line=runtime.is_resume_line, strip_resume_line=runtime.is_resume_line,
running_tasks=running_tasks, running_tasks=running_tasks,
on_thread_known=on_thread_known, on_thread_known=on_thread_known,
progress_ref=progress_ref,
) )
finally: finally:
reset_run_base_dir(run_base_token) reset_run_base_dir(run_base_token)
+68 -3
View File
@@ -15,8 +15,9 @@ from ..directives import DirectiveError
from ..logging import get_logger from ..logging import get_logger
from ..model import EngineId, ResumeToken from ..model import EngineId, ResumeToken
from ..scheduler import ThreadJob, ThreadScheduler from ..scheduler import ThreadJob, ThreadScheduler
from ..progress import ProgressTracker
from ..settings import TelegramTransportSettings from ..settings import TelegramTransportSettings
from ..transport import MessageRef from ..transport import MessageRef, SendOptions
from ..transport_runtime import ResolvedMessage from ..transport_runtime import ResolvedMessage
from ..context import RunContext from ..context import RunContext
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
@@ -257,6 +258,36 @@ async def _wait_for_resume(running_task) -> ResumeToken | None:
return resume return resume
async def _send_queued_progress(
cfg: TelegramBridgeConfig,
*,
chat_id: int,
user_msg_id: int,
thread_id: int | None,
resume_token: ResumeToken,
context: RunContext | None,
) -> MessageRef | None:
tracker = ProgressTracker(engine=resume_token.engine)
tracker.set_resume(resume_token)
context_line = cfg.runtime.format_context_line(context)
state = tracker.snapshot(context_line=context_line)
message = cfg.exec_cfg.presenter.render_progress(
state,
elapsed_s=0.0,
label="queued",
)
reply_ref = MessageRef(
channel_id=chat_id,
message_id=user_msg_id,
thread_id=thread_id,
)
return await cfg.exec_cfg.transport.send(
channel_id=chat_id,
message=message,
options=SendOptions(reply_to=reply_ref, notify=False, thread_id=thread_id),
)
async def send_with_resume( async def send_with_resume(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
enqueue: Callable[ enqueue: Callable[
@@ -268,6 +299,7 @@ async def send_with_resume(
RunContext | None, RunContext | None,
int | None, int | None,
tuple[int, int | None] | None, tuple[int, int | None] | None,
MessageRef | None,
], ],
Awaitable[None], Awaitable[None],
], ],
@@ -292,6 +324,14 @@ async def send_with_resume(
notify=False, notify=False,
) )
return return
progress_ref = await _send_queued_progress(
cfg,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=thread_id,
resume_token=resume,
context=running_task.context,
)
await enqueue( await enqueue(
chat_id, chat_id,
user_msg_id, user_msg_id,
@@ -300,6 +340,7 @@ async def send_with_resume(
running_task.context, running_task.context,
thread_id, thread_id,
session_key, session_key,
progress_ref,
) )
@@ -459,6 +500,7 @@ async def run_main_loop(
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
| None = None, | None = None,
engine_override: EngineId | None = None, engine_override: EngineId | None = None,
progress_ref: MessageRef | None = None,
) -> None: ) -> None:
topic_key = ( topic_key = (
(chat_id, thread_id) (chat_id, thread_id)
@@ -491,6 +533,7 @@ async def run_main_loop(
engine_override=engine_override, engine_override=engine_override,
thread_id=thread_id, thread_id=thread_id,
show_resume_line=show_resume_line, show_resume_line=show_resume_line,
progress_ref=progress_ref,
) )
async def run_thread_job(job: ThreadJob) -> None: async def run_thread_job(job: ThreadJob) -> None:
@@ -504,6 +547,8 @@ async def run_main_loop(
job.session_key, job.session_key,
None, None,
scheduler.note_thread_known, scheduler.note_thread_known,
None,
job.progress_ref,
) )
scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job) scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
@@ -673,6 +718,14 @@ async def run_main_loop(
engine_override, engine_override,
) )
return return
progress_ref = await _send_queued_progress(
cfg,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=msg.thread_id,
resume_token=resume_token,
context=context,
)
await scheduler.enqueue_resume( await scheduler.enqueue_resume(
chat_id, chat_id,
user_msg_id, user_msg_id,
@@ -681,6 +734,7 @@ async def run_main_loop(
context, context,
msg.thread_id, msg.thread_id,
chat_session_key, chat_session_key,
progress_ref,
) )
async def handle_prompt_upload( async def handle_prompt_upload(
@@ -735,7 +789,9 @@ async def run_main_loop(
async for msg in poller(cfg): async for msg in poller(cfg):
if isinstance(msg, TelegramCallbackQuery): if isinstance(msg, TelegramCallbackQuery):
if msg.data == CANCEL_CALLBACK_DATA: if msg.data == CANCEL_CALLBACK_DATA:
tg.start_soon(handle_callback_cancel, cfg, msg, running_tasks) tg.start_soon(
handle_callback_cancel, cfg, msg, running_tasks, scheduler
)
else: else:
tg.start_soon( tg.start_soon(
cfg.bot.answer_callback_query, cfg.bot.answer_callback_query,
@@ -798,7 +854,7 @@ async def run_main_loop(
continue continue
if is_cancel_command(text): if is_cancel_command(text):
tg.start_soon(handle_cancel, cfg, msg, running_tasks) tg.start_soon(handle_cancel, cfg, msg, running_tasks, scheduler)
continue continue
command_id, args_text = _parse_slash_command(text) command_id, args_text = _parse_slash_command(text)
@@ -1017,6 +1073,14 @@ async def run_main_loop(
engine_override, engine_override,
) )
else: else:
progress_ref = await _send_queued_progress(
cfg,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=msg.thread_id,
resume_token=resume_token,
context=context,
)
await scheduler.enqueue_resume( await scheduler.enqueue_resume(
chat_id, chat_id,
user_msg_id, user_msg_id,
@@ -1025,6 +1089,7 @@ async def run_main_loop(
context, context,
msg.thread_id, msg.thread_id,
chat_session_key, chat_session_key,
progress_ref,
) )
finally: finally:
await cfg.exec_cfg.transport.close() await cfg.exec_cfg.transport.close()
+117 -14
View File
@@ -44,6 +44,7 @@ from takopi.markdown import MarkdownPresenter
from takopi.model import ResumeToken from takopi.model import ResumeToken
from takopi.progress import ProgressTracker from takopi.progress import ProgressTracker
from takopi.router import AutoRouter, RunnerEntry from takopi.router import AutoRouter, RunnerEntry
from takopi.scheduler import ThreadScheduler
from takopi.transport_runtime import TransportRuntime from takopi.transport_runtime import TransportRuntime
from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait
from takopi.telegram.types import ( from takopi.telegram.types import (
@@ -68,6 +69,12 @@ def _make_router(runner) -> AutoRouter:
) )
class _NoopTaskGroup:
def start_soon(self, func, *args: Any) -> None:
_ = func, args
return None
class _FakeTransport: class _FakeTransport:
def __init__(self, progress_ready: anyio.Event | None = None) -> None: def __init__(self, progress_ready: anyio.Event | None = None) -> None:
self._next_id = 1 self._next_id = 1
@@ -849,6 +856,42 @@ async def test_handle_cancel_only_cancels_matching_progress_message() -> None:
assert len(transport.send_calls) == 0 assert len(transport.send_calls) == 0
@pytest.mark.anyio
async def test_handle_cancel_cancels_queued_job() -> None:
transport = _FakeTransport()
cfg = _make_cfg(transport)
async def _noop_run_job(_) -> None:
return None
scheduler = ThreadScheduler(task_group=_NoopTaskGroup(), run_job=_noop_run_job)
progress_id = 55
progress_ref = MessageRef(channel_id=123, message_id=progress_id)
resume = ResumeToken(engine=CODEX_ENGINE, value="sid")
await scheduler.enqueue_resume(
chat_id=123,
user_msg_id=10,
text="queued",
resume_token=resume,
progress_ref=progress_ref,
)
msg = TelegramIncomingMessage(
transport="telegram",
chat_id=123,
message_id=10,
text="/cancel",
reply_to_message_id=progress_id,
reply_to_text=None,
sender_id=123,
)
await handle_cancel(cfg, msg, {}, scheduler)
assert transport.edit_calls
assert "cancelled" in transport.edit_calls[0]["message"].text.lower()
assert await scheduler.cancel_queued(123, progress_ref.message_id) is None
@pytest.mark.anyio @pytest.mark.anyio
async def test_handle_file_put_writes_file(tmp_path: Path) -> None: async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
payload = b"hello" payload = b"hello"
@@ -998,6 +1041,43 @@ async def test_handle_callback_cancel_cancels_running_task() -> None:
assert bot.callback_calls[-1]["text"] == "cancelling..." assert bot.callback_calls[-1]["text"] == "cancelling..."
@pytest.mark.anyio
async def test_handle_callback_cancel_cancels_queued_job() -> None:
transport = _FakeTransport()
cfg = _make_cfg(transport)
async def _noop_run_job(_) -> None:
return None
scheduler = ThreadScheduler(task_group=_NoopTaskGroup(), run_job=_noop_run_job)
progress_id = 77
progress_ref = MessageRef(channel_id=123, message_id=progress_id)
resume = ResumeToken(engine=CODEX_ENGINE, value="sid")
await scheduler.enqueue_resume(
chat_id=123,
user_msg_id=10,
text="queued",
resume_token=resume,
progress_ref=progress_ref,
)
query = TelegramCallbackQuery(
transport="telegram",
chat_id=123,
message_id=progress_id,
callback_query_id="cbq-queued",
data="takopi:cancel",
sender_id=123,
)
await handle_callback_cancel(cfg, query, {}, scheduler)
assert transport.edit_calls
assert "cancelled" in transport.edit_calls[0]["message"].text.lower()
bot = cast(_FakeBot, cfg.bot)
assert bot.callback_calls
assert bot.callback_calls[-1]["text"] == "dropped from queue."
@pytest.mark.anyio @pytest.mark.anyio
async def test_handle_callback_cancel_without_task_acknowledges() -> None: async def test_handle_callback_cancel_without_task_acknowledges() -> None:
transport = _FakeTransport() transport = _FakeTransport()
@@ -1249,6 +1329,7 @@ async def test_send_with_resume_waits_for_token() -> None:
RunContext | None, RunContext | None,
int | None, int | None,
tuple[int, int | None] | None, tuple[int, int | None] | None,
MessageRef | None,
] ]
] = [] ] = []
@@ -1260,9 +1341,19 @@ async def test_send_with_resume_waits_for_token() -> None:
context: RunContext | None, context: RunContext | None,
thread_id: int | None, thread_id: int | None,
session_key: tuple[int, int | None] | None, session_key: tuple[int, int | None] | None,
progress_ref: MessageRef | None,
) -> None: ) -> None:
sent.append( sent.append(
(chat_id, user_msg_id, text, resume, context, thread_id, session_key) (
chat_id,
user_msg_id,
text,
resume,
context,
thread_id,
session_key,
progress_ref,
)
) )
running_task = RunningTask() running_task = RunningTask()
@@ -1285,18 +1376,19 @@ async def test_send_with_resume_waits_for_token() -> None:
"hello", "hello",
) )
assert sent == [ assert len(sent) == 1
( assert sent[0][:7] == (
123, 123,
10, 10,
"hello", "hello",
ResumeToken(engine=CODEX_ENGINE, value="abc123"), ResumeToken(engine=CODEX_ENGINE, value="abc123"),
None, None,
None, None,
None, None,
) )
] assert sent[0][7] == transport.send_calls[0]["ref"]
assert transport.send_calls == [] assert transport.send_calls
assert "queued" in transport.send_calls[0]["message"].text.lower()
@pytest.mark.anyio @pytest.mark.anyio
@@ -1312,6 +1404,7 @@ async def test_send_with_resume_reports_when_missing() -> None:
RunContext | None, RunContext | None,
int | None, int | None,
tuple[int, int | None] | None, tuple[int, int | None] | None,
MessageRef | None,
] ]
] = [] ] = []
@@ -1323,9 +1416,19 @@ async def test_send_with_resume_reports_when_missing() -> None:
context: RunContext | None, context: RunContext | None,
thread_id: int | None, thread_id: int | None,
session_key: tuple[int, int | None] | None, session_key: tuple[int, int | None] | None,
progress_ref: MessageRef | None,
) -> None: ) -> None:
sent.append( sent.append(
(chat_id, user_msg_id, text, resume, context, thread_id, session_key) (
chat_id,
user_msg_id,
text,
resume,
context,
thread_id,
session_key,
progress_ref,
)
) )
running_task = RunningTask() running_task = RunningTask()