feat(telegram): add queued cancel placeholder (#136)
This commit is contained in:
+14
-13
@@ -244,11 +244,11 @@ async def send_initial_progress(
|
|||||||
reply_to: MessageRef,
|
reply_to: MessageRef,
|
||||||
label: str,
|
label: str,
|
||||||
tracker: ProgressTracker,
|
tracker: ProgressTracker,
|
||||||
|
progress_ref: MessageRef | None = None,
|
||||||
resume_formatter: Callable[[ResumeToken], str] | None = None,
|
resume_formatter: Callable[[ResumeToken], str] | None = None,
|
||||||
context_line: str | None = None,
|
context_line: str | None = None,
|
||||||
thread_id: ThreadId | None = None,
|
thread_id: ThreadId | None = None,
|
||||||
) -> ProgressMessageState:
|
) -> ProgressMessageState:
|
||||||
progress_ref: MessageRef | None = None
|
|
||||||
last_rendered: RenderedMessage | None = None
|
last_rendered: RenderedMessage | None = None
|
||||||
|
|
||||||
state = tracker.snapshot(
|
state = tracker.snapshot(
|
||||||
@@ -260,27 +260,26 @@ async def send_initial_progress(
|
|||||||
elapsed_s=0.0,
|
elapsed_s=0.0,
|
||||||
label=label,
|
label=label,
|
||||||
)
|
)
|
||||||
logger.debug(
|
sent_ref, _ = await _send_or_edit_message(
|
||||||
"transport.send_message",
|
cfg.transport,
|
||||||
channel_id=channel_id,
|
|
||||||
reply_to_message_id=reply_to.message_id,
|
|
||||||
rendered=initial_rendered.text,
|
|
||||||
)
|
|
||||||
progress_ref = await cfg.transport.send(
|
|
||||||
channel_id=channel_id,
|
channel_id=channel_id,
|
||||||
message=initial_rendered,
|
message=initial_rendered,
|
||||||
options=SendOptions(reply_to=reply_to, notify=False, thread_id=thread_id),
|
edit_ref=progress_ref,
|
||||||
|
reply_to=reply_to,
|
||||||
|
notify=False,
|
||||||
|
replace_ref=progress_ref,
|
||||||
|
thread_id=thread_id,
|
||||||
)
|
)
|
||||||
if progress_ref is not None:
|
if sent_ref is not None:
|
||||||
last_rendered = initial_rendered
|
last_rendered = initial_rendered
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"progress.sent",
|
"progress.sent",
|
||||||
channel_id=progress_ref.channel_id,
|
channel_id=sent_ref.channel_id,
|
||||||
message_id=progress_ref.message_id,
|
message_id=sent_ref.message_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ProgressMessageState(
|
return ProgressMessageState(
|
||||||
ref=progress_ref,
|
ref=sent_ref,
|
||||||
last_rendered=last_rendered,
|
last_rendered=last_rendered,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -396,6 +395,7 @@ async def handle_message(
|
|||||||
running_tasks: RunningTasks | None = None,
|
running_tasks: RunningTasks | None = None,
|
||||||
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
|
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
|
||||||
| None = None,
|
| None = None,
|
||||||
|
progress_ref: MessageRef | None = None,
|
||||||
clock: Callable[[], float] = time.monotonic,
|
clock: Callable[[], float] = time.monotonic,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -422,6 +422,7 @@ async def handle_message(
|
|||||||
reply_to=user_ref,
|
reply_to=user_ref,
|
||||||
label="starting",
|
label="starting",
|
||||||
tracker=progress_tracker,
|
tracker=progress_tracker,
|
||||||
|
progress_ref=progress_ref,
|
||||||
resume_formatter=runner.format_resume,
|
resume_formatter=runner.format_resume,
|
||||||
context_line=context_line,
|
context_line=context_line,
|
||||||
thread_id=incoming.thread_id,
|
thread_id=incoming.thread_id,
|
||||||
|
|||||||
+31
-1
@@ -10,7 +10,7 @@ import anyio
|
|||||||
from .context import RunContext
|
from .context import RunContext
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
from .model import ResumeToken
|
from .model import ResumeToken
|
||||||
from .transport import ChannelId, MessageId, ThreadId
|
from .transport import ChannelId, MessageId, MessageRef, ThreadId
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -24,6 +24,7 @@ class ThreadJob:
|
|||||||
context: RunContext | None = None
|
context: RunContext | None = None
|
||||||
thread_id: ThreadId | None = None
|
thread_id: ThreadId | None = None
|
||||||
session_key: tuple[int, int | None] | None = None
|
session_key: tuple[int, int | None] | None = None
|
||||||
|
progress_ref: MessageRef | None = None
|
||||||
|
|
||||||
|
|
||||||
RunJob = Callable[[ThreadJob], Awaitable[None]]
|
RunJob = Callable[[ThreadJob], Awaitable[None]]
|
||||||
@@ -41,6 +42,7 @@ class ThreadScheduler:
|
|||||||
self._run_job = run_job
|
self._run_job = run_job
|
||||||
self._lock = anyio.Lock()
|
self._lock = anyio.Lock()
|
||||||
self._pending_by_thread: dict[str, deque[ThreadJob]] = {}
|
self._pending_by_thread: dict[str, deque[ThreadJob]] = {}
|
||||||
|
self._queued_by_progress: dict[tuple[ChannelId, MessageId], ThreadJob] = {}
|
||||||
self._active_threads: set[str] = set()
|
self._active_threads: set[str] = set()
|
||||||
self._busy_until: dict[str, anyio.Event] = {}
|
self._busy_until: dict[str, anyio.Event] = {}
|
||||||
|
|
||||||
@@ -64,6 +66,9 @@ class ThreadScheduler:
|
|||||||
queue = deque()
|
queue = deque()
|
||||||
self._pending_by_thread[key] = queue
|
self._pending_by_thread[key] = queue
|
||||||
queue.append(job)
|
queue.append(job)
|
||||||
|
if job.progress_ref is not None:
|
||||||
|
progress_key = (job.chat_id, job.progress_ref.message_id)
|
||||||
|
self._queued_by_progress[progress_key] = job
|
||||||
if key in self._active_threads:
|
if key in self._active_threads:
|
||||||
return
|
return
|
||||||
self._active_threads.add(key)
|
self._active_threads.add(key)
|
||||||
@@ -78,6 +83,7 @@ class ThreadScheduler:
|
|||||||
context: RunContext | None = None,
|
context: RunContext | None = None,
|
||||||
thread_id: ThreadId | None = None,
|
thread_id: ThreadId | None = None,
|
||||||
session_key: tuple[int, int | None] | None = None,
|
session_key: tuple[int, int | None] | None = None,
|
||||||
|
progress_ref: MessageRef | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.enqueue(
|
await self.enqueue(
|
||||||
ThreadJob(
|
ThreadJob(
|
||||||
@@ -88,9 +94,30 @@ class ThreadScheduler:
|
|||||||
context=context,
|
context=context,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
|
progress_ref=progress_ref,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def cancel_queued(
|
||||||
|
self, chat_id: ChannelId, progress_msg_id: MessageId
|
||||||
|
) -> ThreadJob | None:
|
||||||
|
progress_key = (chat_id, progress_msg_id)
|
||||||
|
async with self._lock:
|
||||||
|
job = self._queued_by_progress.pop(progress_key, None)
|
||||||
|
if job is None:
|
||||||
|
return None
|
||||||
|
thread_key = self.thread_key(job.resume_token)
|
||||||
|
queue = self._pending_by_thread.get(thread_key)
|
||||||
|
if queue is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
queue.remove(job)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
if not queue:
|
||||||
|
self._pending_by_thread.pop(thread_key, None)
|
||||||
|
return job
|
||||||
|
|
||||||
async def _clear_busy(self, key: str, done: anyio.Event) -> None:
|
async def _clear_busy(self, key: str, done: anyio.Event) -> None:
|
||||||
await done.wait()
|
await done.wait()
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
@@ -108,6 +135,9 @@ class ThreadScheduler:
|
|||||||
self._active_threads.discard(key)
|
self._active_threads.discard(key)
|
||||||
return
|
return
|
||||||
job = queue.popleft()
|
job = queue.popleft()
|
||||||
|
if job.progress_ref is not None:
|
||||||
|
progress_key = (job.chat_id, job.progress_ref.message_id)
|
||||||
|
self._queued_by_progress.pop(progress_key, None)
|
||||||
|
|
||||||
if done is not None and not done.is_set():
|
if done is not None and not done.is_set():
|
||||||
await done.wait()
|
await done.wait()
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from ..transport import MessageRef, RenderedMessage, SendOptions, Transport
|
|||||||
from ..transport_runtime import TransportRuntime
|
from ..transport_runtime import TransportRuntime
|
||||||
from ..context import RunContext
|
from ..context import RunContext
|
||||||
from ..model import ResumeToken
|
from ..model import ResumeToken
|
||||||
|
from ..scheduler import ThreadScheduler
|
||||||
from ..settings import (
|
from ..settings import (
|
||||||
TelegramFilesSettings,
|
TelegramFilesSettings,
|
||||||
TelegramTopicsSettings,
|
TelegramTopicsSettings,
|
||||||
@@ -326,20 +327,22 @@ async def handle_cancel(
|
|||||||
cfg: TelegramBridgeConfig,
|
cfg: TelegramBridgeConfig,
|
||||||
msg: TelegramIncomingMessage,
|
msg: TelegramIncomingMessage,
|
||||||
running_tasks: RunningTasks,
|
running_tasks: RunningTasks,
|
||||||
|
scheduler: ThreadScheduler | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
from .commands import handle_cancel as _handle_cancel
|
from .commands import handle_cancel as _handle_cancel
|
||||||
|
|
||||||
await _handle_cancel(cfg, msg, running_tasks)
|
await _handle_cancel(cfg, msg, running_tasks, scheduler)
|
||||||
|
|
||||||
|
|
||||||
async def handle_callback_cancel(
|
async def handle_callback_cancel(
|
||||||
cfg: TelegramBridgeConfig,
|
cfg: TelegramBridgeConfig,
|
||||||
query: TelegramCallbackQuery,
|
query: TelegramCallbackQuery,
|
||||||
running_tasks: RunningTasks,
|
running_tasks: RunningTasks,
|
||||||
|
scheduler: ThreadScheduler | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
from .commands import handle_callback_cancel as _handle_callback_cancel
|
from .commands import handle_callback_cancel as _handle_callback_cancel
|
||||||
|
|
||||||
await _handle_callback_cancel(cfg, query, running_tasks)
|
await _handle_callback_cancel(cfg, query, running_tasks, scheduler)
|
||||||
|
|
||||||
|
|
||||||
async def send_with_resume(
|
async def send_with_resume(
|
||||||
@@ -353,6 +356,7 @@ async def send_with_resume(
|
|||||||
RunContext | None,
|
RunContext | None,
|
||||||
int | None,
|
int | None,
|
||||||
tuple[int, int | None] | None,
|
tuple[int, int | None] | None,
|
||||||
|
MessageRef | None,
|
||||||
],
|
],
|
||||||
Awaitable[None],
|
Awaitable[None],
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -3,7 +3,9 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...logging import get_logger
|
from ...logging import get_logger
|
||||||
|
from ...progress import ProgressTracker
|
||||||
from ...runner_bridge import RunningTasks
|
from ...runner_bridge import RunningTasks
|
||||||
|
from ...scheduler import ThreadJob, ThreadScheduler
|
||||||
from ...transport import MessageRef
|
from ...transport import MessageRef
|
||||||
from ..types import TelegramCallbackQuery, TelegramIncomingMessage
|
from ..types import TelegramCallbackQuery, TelegramIncomingMessage
|
||||||
from .reply import make_reply
|
from .reply import make_reply
|
||||||
@@ -18,6 +20,7 @@ async def handle_cancel(
|
|||||||
cfg: TelegramBridgeConfig,
|
cfg: TelegramBridgeConfig,
|
||||||
msg: TelegramIncomingMessage,
|
msg: TelegramIncomingMessage,
|
||||||
running_tasks: RunningTasks,
|
running_tasks: RunningTasks,
|
||||||
|
scheduler: ThreadScheduler | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
reply = make_reply(cfg, msg)
|
reply = make_reply(cfg, msg)
|
||||||
chat_id = msg.chat_id
|
chat_id = msg.chat_id
|
||||||
@@ -33,6 +36,17 @@ async def handle_cancel(
|
|||||||
progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id)
|
progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id)
|
||||||
running_task = running_tasks.get(progress_ref)
|
running_task = running_tasks.get(progress_ref)
|
||||||
if running_task is None:
|
if running_task is None:
|
||||||
|
if scheduler is not None:
|
||||||
|
job = await scheduler.cancel_queued(chat_id, reply_id)
|
||||||
|
if job is not None:
|
||||||
|
logger.info(
|
||||||
|
"cancel.queued",
|
||||||
|
chat_id=chat_id,
|
||||||
|
progress_message_id=reply_id,
|
||||||
|
resume=job.resume_token.value,
|
||||||
|
)
|
||||||
|
await _edit_cancelled_message(cfg, progress_ref, job)
|
||||||
|
return
|
||||||
await reply(text="nothing is currently running for that message.")
|
await reply(text="nothing is currently running for that message.")
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -48,10 +62,26 @@ async def handle_callback_cancel(
|
|||||||
cfg: TelegramBridgeConfig,
|
cfg: TelegramBridgeConfig,
|
||||||
query: TelegramCallbackQuery,
|
query: TelegramCallbackQuery,
|
||||||
running_tasks: RunningTasks,
|
running_tasks: RunningTasks,
|
||||||
|
scheduler: ThreadScheduler | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id)
|
progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id)
|
||||||
running_task = running_tasks.get(progress_ref)
|
running_task = running_tasks.get(progress_ref)
|
||||||
if running_task is None:
|
if running_task is None:
|
||||||
|
if scheduler is not None:
|
||||||
|
job = await scheduler.cancel_queued(query.chat_id, query.message_id)
|
||||||
|
if job is not None:
|
||||||
|
logger.info(
|
||||||
|
"cancel.queued",
|
||||||
|
chat_id=query.chat_id,
|
||||||
|
progress_message_id=query.message_id,
|
||||||
|
resume=job.resume_token.value,
|
||||||
|
)
|
||||||
|
await _edit_cancelled_message(cfg, progress_ref, job)
|
||||||
|
await cfg.bot.answer_callback_query(
|
||||||
|
callback_query_id=query.callback_query_id,
|
||||||
|
text="dropped from queue.",
|
||||||
|
)
|
||||||
|
return
|
||||||
await cfg.bot.answer_callback_query(
|
await cfg.bot.answer_callback_query(
|
||||||
callback_query_id=query.callback_query_id,
|
callback_query_id=query.callback_query_id,
|
||||||
text="nothing is currently running for that message.",
|
text="nothing is currently running for that message.",
|
||||||
@@ -67,3 +97,20 @@ async def handle_callback_cancel(
|
|||||||
callback_query_id=query.callback_query_id,
|
callback_query_id=query.callback_query_id,
|
||||||
text="cancelling...",
|
text="cancelling...",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _edit_cancelled_message(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
progress_ref: MessageRef,
|
||||||
|
job: ThreadJob,
|
||||||
|
) -> None:
|
||||||
|
tracker = ProgressTracker(engine=job.resume_token.engine)
|
||||||
|
tracker.set_resume(job.resume_token)
|
||||||
|
context_line = cfg.runtime.format_context_line(job.context)
|
||||||
|
state = tracker.snapshot(context_line=context_line)
|
||||||
|
message = cfg.exec_cfg.presenter.render_progress(
|
||||||
|
state,
|
||||||
|
elapsed_s=0.0,
|
||||||
|
label="`cancelled`",
|
||||||
|
)
|
||||||
|
await cfg.exec_cfg.transport.edit(ref=progress_ref, message=message)
|
||||||
|
|||||||
@@ -107,6 +107,7 @@ async def _run_engine(
|
|||||||
engine_override: EngineId | None = None,
|
engine_override: EngineId | None = None,
|
||||||
thread_id: int | None = None,
|
thread_id: int | None = None,
|
||||||
show_resume_line: bool = True,
|
show_resume_line: bool = True,
|
||||||
|
progress_ref: MessageRef | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
reply = partial(
|
reply = partial(
|
||||||
send_plain,
|
send_plain,
|
||||||
@@ -176,6 +177,7 @@ async def _run_engine(
|
|||||||
strip_resume_line=runtime.is_resume_line,
|
strip_resume_line=runtime.is_resume_line,
|
||||||
running_tasks=running_tasks,
|
running_tasks=running_tasks,
|
||||||
on_thread_known=on_thread_known,
|
on_thread_known=on_thread_known,
|
||||||
|
progress_ref=progress_ref,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
reset_run_base_dir(run_base_token)
|
reset_run_base_dir(run_base_token)
|
||||||
|
|||||||
@@ -15,8 +15,9 @@ from ..directives import DirectiveError
|
|||||||
from ..logging import get_logger
|
from ..logging import get_logger
|
||||||
from ..model import EngineId, ResumeToken
|
from ..model import EngineId, ResumeToken
|
||||||
from ..scheduler import ThreadJob, ThreadScheduler
|
from ..scheduler import ThreadJob, ThreadScheduler
|
||||||
|
from ..progress import ProgressTracker
|
||||||
from ..settings import TelegramTransportSettings
|
from ..settings import TelegramTransportSettings
|
||||||
from ..transport import MessageRef
|
from ..transport import MessageRef, SendOptions
|
||||||
from ..transport_runtime import ResolvedMessage
|
from ..transport_runtime import ResolvedMessage
|
||||||
from ..context import RunContext
|
from ..context import RunContext
|
||||||
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
|
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
|
||||||
@@ -257,6 +258,36 @@ async def _wait_for_resume(running_task) -> ResumeToken | None:
|
|||||||
return resume
|
return resume
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_queued_progress(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
*,
|
||||||
|
chat_id: int,
|
||||||
|
user_msg_id: int,
|
||||||
|
thread_id: int | None,
|
||||||
|
resume_token: ResumeToken,
|
||||||
|
context: RunContext | None,
|
||||||
|
) -> MessageRef | None:
|
||||||
|
tracker = ProgressTracker(engine=resume_token.engine)
|
||||||
|
tracker.set_resume(resume_token)
|
||||||
|
context_line = cfg.runtime.format_context_line(context)
|
||||||
|
state = tracker.snapshot(context_line=context_line)
|
||||||
|
message = cfg.exec_cfg.presenter.render_progress(
|
||||||
|
state,
|
||||||
|
elapsed_s=0.0,
|
||||||
|
label="queued",
|
||||||
|
)
|
||||||
|
reply_ref = MessageRef(
|
||||||
|
channel_id=chat_id,
|
||||||
|
message_id=user_msg_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
return await cfg.exec_cfg.transport.send(
|
||||||
|
channel_id=chat_id,
|
||||||
|
message=message,
|
||||||
|
options=SendOptions(reply_to=reply_ref, notify=False, thread_id=thread_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def send_with_resume(
|
async def send_with_resume(
|
||||||
cfg: TelegramBridgeConfig,
|
cfg: TelegramBridgeConfig,
|
||||||
enqueue: Callable[
|
enqueue: Callable[
|
||||||
@@ -268,6 +299,7 @@ async def send_with_resume(
|
|||||||
RunContext | None,
|
RunContext | None,
|
||||||
int | None,
|
int | None,
|
||||||
tuple[int, int | None] | None,
|
tuple[int, int | None] | None,
|
||||||
|
MessageRef | None,
|
||||||
],
|
],
|
||||||
Awaitable[None],
|
Awaitable[None],
|
||||||
],
|
],
|
||||||
@@ -292,6 +324,14 @@ async def send_with_resume(
|
|||||||
notify=False,
|
notify=False,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
progress_ref = await _send_queued_progress(
|
||||||
|
cfg,
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_msg_id=user_msg_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
resume_token=resume,
|
||||||
|
context=running_task.context,
|
||||||
|
)
|
||||||
await enqueue(
|
await enqueue(
|
||||||
chat_id,
|
chat_id,
|
||||||
user_msg_id,
|
user_msg_id,
|
||||||
@@ -300,6 +340,7 @@ async def send_with_resume(
|
|||||||
running_task.context,
|
running_task.context,
|
||||||
thread_id,
|
thread_id,
|
||||||
session_key,
|
session_key,
|
||||||
|
progress_ref,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -459,6 +500,7 @@ async def run_main_loop(
|
|||||||
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
|
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
|
||||||
| None = None,
|
| None = None,
|
||||||
engine_override: EngineId | None = None,
|
engine_override: EngineId | None = None,
|
||||||
|
progress_ref: MessageRef | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
topic_key = (
|
topic_key = (
|
||||||
(chat_id, thread_id)
|
(chat_id, thread_id)
|
||||||
@@ -491,6 +533,7 @@ async def run_main_loop(
|
|||||||
engine_override=engine_override,
|
engine_override=engine_override,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
show_resume_line=show_resume_line,
|
show_resume_line=show_resume_line,
|
||||||
|
progress_ref=progress_ref,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_thread_job(job: ThreadJob) -> None:
|
async def run_thread_job(job: ThreadJob) -> None:
|
||||||
@@ -504,6 +547,8 @@ async def run_main_loop(
|
|||||||
job.session_key,
|
job.session_key,
|
||||||
None,
|
None,
|
||||||
scheduler.note_thread_known,
|
scheduler.note_thread_known,
|
||||||
|
None,
|
||||||
|
job.progress_ref,
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
|
scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
|
||||||
@@ -673,6 +718,14 @@ async def run_main_loop(
|
|||||||
engine_override,
|
engine_override,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
progress_ref = await _send_queued_progress(
|
||||||
|
cfg,
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_msg_id=user_msg_id,
|
||||||
|
thread_id=msg.thread_id,
|
||||||
|
resume_token=resume_token,
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
await scheduler.enqueue_resume(
|
await scheduler.enqueue_resume(
|
||||||
chat_id,
|
chat_id,
|
||||||
user_msg_id,
|
user_msg_id,
|
||||||
@@ -681,6 +734,7 @@ async def run_main_loop(
|
|||||||
context,
|
context,
|
||||||
msg.thread_id,
|
msg.thread_id,
|
||||||
chat_session_key,
|
chat_session_key,
|
||||||
|
progress_ref,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_prompt_upload(
|
async def handle_prompt_upload(
|
||||||
@@ -735,7 +789,9 @@ async def run_main_loop(
|
|||||||
async for msg in poller(cfg):
|
async for msg in poller(cfg):
|
||||||
if isinstance(msg, TelegramCallbackQuery):
|
if isinstance(msg, TelegramCallbackQuery):
|
||||||
if msg.data == CANCEL_CALLBACK_DATA:
|
if msg.data == CANCEL_CALLBACK_DATA:
|
||||||
tg.start_soon(handle_callback_cancel, cfg, msg, running_tasks)
|
tg.start_soon(
|
||||||
|
handle_callback_cancel, cfg, msg, running_tasks, scheduler
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
tg.start_soon(
|
tg.start_soon(
|
||||||
cfg.bot.answer_callback_query,
|
cfg.bot.answer_callback_query,
|
||||||
@@ -798,7 +854,7 @@ async def run_main_loop(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if is_cancel_command(text):
|
if is_cancel_command(text):
|
||||||
tg.start_soon(handle_cancel, cfg, msg, running_tasks)
|
tg.start_soon(handle_cancel, cfg, msg, running_tasks, scheduler)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
command_id, args_text = _parse_slash_command(text)
|
command_id, args_text = _parse_slash_command(text)
|
||||||
@@ -1017,6 +1073,14 @@ async def run_main_loop(
|
|||||||
engine_override,
|
engine_override,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
progress_ref = await _send_queued_progress(
|
||||||
|
cfg,
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_msg_id=user_msg_id,
|
||||||
|
thread_id=msg.thread_id,
|
||||||
|
resume_token=resume_token,
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
await scheduler.enqueue_resume(
|
await scheduler.enqueue_resume(
|
||||||
chat_id,
|
chat_id,
|
||||||
user_msg_id,
|
user_msg_id,
|
||||||
@@ -1025,6 +1089,7 @@ async def run_main_loop(
|
|||||||
context,
|
context,
|
||||||
msg.thread_id,
|
msg.thread_id,
|
||||||
chat_session_key,
|
chat_session_key,
|
||||||
|
progress_ref,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
await cfg.exec_cfg.transport.close()
|
await cfg.exec_cfg.transport.close()
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from takopi.markdown import MarkdownPresenter
|
|||||||
from takopi.model import ResumeToken
|
from takopi.model import ResumeToken
|
||||||
from takopi.progress import ProgressTracker
|
from takopi.progress import ProgressTracker
|
||||||
from takopi.router import AutoRouter, RunnerEntry
|
from takopi.router import AutoRouter, RunnerEntry
|
||||||
|
from takopi.scheduler import ThreadScheduler
|
||||||
from takopi.transport_runtime import TransportRuntime
|
from takopi.transport_runtime import TransportRuntime
|
||||||
from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait
|
from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait
|
||||||
from takopi.telegram.types import (
|
from takopi.telegram.types import (
|
||||||
@@ -68,6 +69,12 @@ def _make_router(runner) -> AutoRouter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _NoopTaskGroup:
|
||||||
|
def start_soon(self, func, *args: Any) -> None:
|
||||||
|
_ = func, args
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class _FakeTransport:
|
class _FakeTransport:
|
||||||
def __init__(self, progress_ready: anyio.Event | None = None) -> None:
|
def __init__(self, progress_ready: anyio.Event | None = None) -> None:
|
||||||
self._next_id = 1
|
self._next_id = 1
|
||||||
@@ -849,6 +856,42 @@ async def test_handle_cancel_only_cancels_matching_progress_message() -> None:
|
|||||||
assert len(transport.send_calls) == 0
|
assert len(transport.send_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_handle_cancel_cancels_queued_job() -> None:
|
||||||
|
transport = _FakeTransport()
|
||||||
|
cfg = _make_cfg(transport)
|
||||||
|
|
||||||
|
async def _noop_run_job(_) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scheduler = ThreadScheduler(task_group=_NoopTaskGroup(), run_job=_noop_run_job)
|
||||||
|
progress_id = 55
|
||||||
|
progress_ref = MessageRef(channel_id=123, message_id=progress_id)
|
||||||
|
resume = ResumeToken(engine=CODEX_ENGINE, value="sid")
|
||||||
|
await scheduler.enqueue_resume(
|
||||||
|
chat_id=123,
|
||||||
|
user_msg_id=10,
|
||||||
|
text="queued",
|
||||||
|
resume_token=resume,
|
||||||
|
progress_ref=progress_ref,
|
||||||
|
)
|
||||||
|
msg = TelegramIncomingMessage(
|
||||||
|
transport="telegram",
|
||||||
|
chat_id=123,
|
||||||
|
message_id=10,
|
||||||
|
text="/cancel",
|
||||||
|
reply_to_message_id=progress_id,
|
||||||
|
reply_to_text=None,
|
||||||
|
sender_id=123,
|
||||||
|
)
|
||||||
|
|
||||||
|
await handle_cancel(cfg, msg, {}, scheduler)
|
||||||
|
|
||||||
|
assert transport.edit_calls
|
||||||
|
assert "cancelled" in transport.edit_calls[0]["message"].text.lower()
|
||||||
|
assert await scheduler.cancel_queued(123, progress_ref.message_id) is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
|
async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
|
||||||
payload = b"hello"
|
payload = b"hello"
|
||||||
@@ -998,6 +1041,43 @@ async def test_handle_callback_cancel_cancels_running_task() -> None:
|
|||||||
assert bot.callback_calls[-1]["text"] == "cancelling..."
|
assert bot.callback_calls[-1]["text"] == "cancelling..."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_handle_callback_cancel_cancels_queued_job() -> None:
|
||||||
|
transport = _FakeTransport()
|
||||||
|
cfg = _make_cfg(transport)
|
||||||
|
|
||||||
|
async def _noop_run_job(_) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scheduler = ThreadScheduler(task_group=_NoopTaskGroup(), run_job=_noop_run_job)
|
||||||
|
progress_id = 77
|
||||||
|
progress_ref = MessageRef(channel_id=123, message_id=progress_id)
|
||||||
|
resume = ResumeToken(engine=CODEX_ENGINE, value="sid")
|
||||||
|
await scheduler.enqueue_resume(
|
||||||
|
chat_id=123,
|
||||||
|
user_msg_id=10,
|
||||||
|
text="queued",
|
||||||
|
resume_token=resume,
|
||||||
|
progress_ref=progress_ref,
|
||||||
|
)
|
||||||
|
query = TelegramCallbackQuery(
|
||||||
|
transport="telegram",
|
||||||
|
chat_id=123,
|
||||||
|
message_id=progress_id,
|
||||||
|
callback_query_id="cbq-queued",
|
||||||
|
data="takopi:cancel",
|
||||||
|
sender_id=123,
|
||||||
|
)
|
||||||
|
|
||||||
|
await handle_callback_cancel(cfg, query, {}, scheduler)
|
||||||
|
|
||||||
|
assert transport.edit_calls
|
||||||
|
assert "cancelled" in transport.edit_calls[0]["message"].text.lower()
|
||||||
|
bot = cast(_FakeBot, cfg.bot)
|
||||||
|
assert bot.callback_calls
|
||||||
|
assert bot.callback_calls[-1]["text"] == "dropped from queue."
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_handle_callback_cancel_without_task_acknowledges() -> None:
|
async def test_handle_callback_cancel_without_task_acknowledges() -> None:
|
||||||
transport = _FakeTransport()
|
transport = _FakeTransport()
|
||||||
@@ -1249,6 +1329,7 @@ async def test_send_with_resume_waits_for_token() -> None:
|
|||||||
RunContext | None,
|
RunContext | None,
|
||||||
int | None,
|
int | None,
|
||||||
tuple[int, int | None] | None,
|
tuple[int, int | None] | None,
|
||||||
|
MessageRef | None,
|
||||||
]
|
]
|
||||||
] = []
|
] = []
|
||||||
|
|
||||||
@@ -1260,9 +1341,19 @@ async def test_send_with_resume_waits_for_token() -> None:
|
|||||||
context: RunContext | None,
|
context: RunContext | None,
|
||||||
thread_id: int | None,
|
thread_id: int | None,
|
||||||
session_key: tuple[int, int | None] | None,
|
session_key: tuple[int, int | None] | None,
|
||||||
|
progress_ref: MessageRef | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
sent.append(
|
sent.append(
|
||||||
(chat_id, user_msg_id, text, resume, context, thread_id, session_key)
|
(
|
||||||
|
chat_id,
|
||||||
|
user_msg_id,
|
||||||
|
text,
|
||||||
|
resume,
|
||||||
|
context,
|
||||||
|
thread_id,
|
||||||
|
session_key,
|
||||||
|
progress_ref,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
running_task = RunningTask()
|
running_task = RunningTask()
|
||||||
@@ -1285,8 +1376,8 @@ async def test_send_with_resume_waits_for_token() -> None:
|
|||||||
"hello",
|
"hello",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert sent == [
|
assert len(sent) == 1
|
||||||
(
|
assert sent[0][:7] == (
|
||||||
123,
|
123,
|
||||||
10,
|
10,
|
||||||
"hello",
|
"hello",
|
||||||
@@ -1295,8 +1386,9 @@ async def test_send_with_resume_waits_for_token() -> None:
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
]
|
assert sent[0][7] == transport.send_calls[0]["ref"]
|
||||||
assert transport.send_calls == []
|
assert transport.send_calls
|
||||||
|
assert "queued" in transport.send_calls[0]["message"].text.lower()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -1312,6 +1404,7 @@ async def test_send_with_resume_reports_when_missing() -> None:
|
|||||||
RunContext | None,
|
RunContext | None,
|
||||||
int | None,
|
int | None,
|
||||||
tuple[int, int | None] | None,
|
tuple[int, int | None] | None,
|
||||||
|
MessageRef | None,
|
||||||
]
|
]
|
||||||
] = []
|
] = []
|
||||||
|
|
||||||
@@ -1323,9 +1416,19 @@ async def test_send_with_resume_reports_when_missing() -> None:
|
|||||||
context: RunContext | None,
|
context: RunContext | None,
|
||||||
thread_id: int | None,
|
thread_id: int | None,
|
||||||
session_key: tuple[int, int | None] | None,
|
session_key: tuple[int, int | None] | None,
|
||||||
|
progress_ref: MessageRef | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
sent.append(
|
sent.append(
|
||||||
(chat_id, user_msg_id, text, resume, context, thread_id, session_key)
|
(
|
||||||
|
chat_id,
|
||||||
|
user_msg_id,
|
||||||
|
text,
|
||||||
|
resume,
|
||||||
|
context,
|
||||||
|
thread_id,
|
||||||
|
session_key,
|
||||||
|
progress_ref,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
running_task = RunningTask()
|
running_task = RunningTask()
|
||||||
|
|||||||
Reference in New Issue
Block a user