feat(telegram): add inline cancel button (#79)

This commit is contained in:
banteg
2026-01-10 03:33:57 +04:00
committed by GitHub
parent 801d04cfdf
commit 5c1635ccb5
8 changed files with 397 additions and 18 deletions
+2 -1
View File
@@ -1,2 +1,3 @@
after you finish work, commit with a conventional message. only commit the files you edited. after you finish work, commit with a conventional message. only commit the files you edited.
run `just check` and fix any errors before committing. always run `just check` before code commits.
if you fix anything from `just check`, rerun it and confirm it passes before committing.
+8 -1
View File
@@ -1,10 +1,17 @@
"""Telegram-specific clients and adapters.""" """Telegram-specific clients and adapters."""
from .client import parse_incoming_update, poll_incoming from .client import parse_incoming_update, poll_incoming
from .types import TelegramIncomingMessage, TelegramVoice from .types import (
TelegramCallbackQuery,
TelegramIncomingMessage,
TelegramIncomingUpdate,
TelegramVoice,
)
__all__ = [ __all__ = [
"TelegramCallbackQuery",
"TelegramIncomingMessage", "TelegramIncomingMessage",
"TelegramIncomingUpdate",
"TelegramVoice", "TelegramVoice",
"parse_incoming_update", "parse_incoming_update",
"poll_incoming", "poll_incoming",
+71 -8
View File
@@ -41,7 +41,11 @@ from ..plugins import COMMAND_GROUP, list_entrypoints
from ..utils.paths import reset_run_base_dir, set_run_base_dir from ..utils.paths import reset_run_base_dir, set_run_base_dir
from ..transport_runtime import TransportRuntime from ..transport_runtime import TransportRuntime
from .client import BotClient, poll_incoming from .client import BotClient, poll_incoming
from .types import TelegramIncomingMessage from .types import (
TelegramCallbackQuery,
TelegramIncomingMessage,
TelegramIncomingUpdate,
)
from .render import prepare_telegram from .render import prepare_telegram
from .transcribe import transcribe_audio from .transcribe import transcribe_audio
@@ -51,6 +55,11 @@ _MAX_BOT_COMMANDS = 100
_OPENAI_AUDIO_MAX_BYTES = 25 * 1024 * 1024 _OPENAI_AUDIO_MAX_BYTES = 25 * 1024 * 1024
_OPENAI_TRANSCRIPTION_MODEL = "gpt-4o-mini-transcribe" _OPENAI_TRANSCRIPTION_MODEL = "gpt-4o-mini-transcribe"
_OPENAI_TRANSCRIPTION_CHUNKING = "auto" _OPENAI_TRANSCRIPTION_CHUNKING = "auto"
CANCEL_CALLBACK_DATA = "takopi:cancel"
CANCEL_MARKUP = {
"inline_keyboard": [[{"text": "cancel", "callback_data": CANCEL_CALLBACK_DATA}]]
}
CLEAR_MARKUP = {"inline_keyboard": []}
def _is_cancel_command(text: str) -> bool: def _is_cancel_command(text: str) -> bool:
@@ -218,7 +227,11 @@ class TelegramPresenter:
state, elapsed_s=elapsed_s, label=label state, elapsed_s=elapsed_s, label=label
) )
text, entities = prepare_telegram(parts) text, entities = prepare_telegram(parts)
return RenderedMessage(text=text, extra={"entities": entities}) reply_markup = CLEAR_MARKUP if _is_cancelled_label(label) else CANCEL_MARKUP
return RenderedMessage(
text=text,
extra={"entities": entities, "reply_markup": reply_markup},
)
def render_final( def render_final(
self, self,
@@ -232,7 +245,17 @@ class TelegramPresenter:
state, elapsed_s=elapsed_s, status=status, answer=answer state, elapsed_s=elapsed_s, status=status, answer=answer
) )
text, entities = prepare_telegram(parts) text, entities = prepare_telegram(parts)
return RenderedMessage(text=text, extra={"entities": entities}) return RenderedMessage(
text=text,
extra={"entities": entities, "reply_markup": CLEAR_MARKUP},
)
def _is_cancelled_label(label: str) -> bool:
stripped = label.strip()
if stripped.startswith("`") and stripped.endswith("`") and len(stripped) >= 2:
stripped = stripped[1:-1]
return stripped.lower() == "cancelled"
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -276,6 +299,7 @@ class TelegramTransport:
) )
entities = message.extra.get("entities") entities = message.extra.get("entities")
parse_mode = message.extra.get("parse_mode") parse_mode = message.extra.get("parse_mode")
reply_markup = message.extra.get("reply_markup")
sent = await self._bot.send_message( sent = await self._bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=message.text, text=message.text,
@@ -283,6 +307,7 @@ class TelegramTransport:
disable_notification=disable_notification, disable_notification=disable_notification,
entities=entities, entities=entities,
parse_mode=parse_mode, parse_mode=parse_mode,
reply_markup=reply_markup,
replace_message_id=replace_message_id, replace_message_id=replace_message_id,
) )
if sent is None: if sent is None:
@@ -303,12 +328,14 @@ class TelegramTransport:
message_id = _as_int(ref.message_id, label="message_id") message_id = _as_int(ref.message_id, label="message_id")
entities = message.extra.get("entities") entities = message.extra.get("entities")
parse_mode = message.extra.get("parse_mode") parse_mode = message.extra.get("parse_mode")
reply_markup = message.extra.get("reply_markup")
edited = await self._bot.edit_message_text( edited = await self._bot.edit_message_text(
chat_id=chat_id, chat_id=chat_id,
message_id=message_id, message_id=message_id,
text=message.text, text=message.text,
entities=entities, entities=entities,
parse_mode=parse_mode, parse_mode=parse_mode,
reply_markup=reply_markup,
wait=wait, wait=wait,
) )
if edited is None: if edited is None:
@@ -378,7 +405,9 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int |
drained = 0 drained = 0
while True: while True:
updates = await cfg.bot.get_updates( updates = await cfg.bot.get_updates(
offset=offset, timeout_s=0, allowed_updates=["message"] offset=offset,
timeout_s=0,
allowed_updates=["message", "callback_query"],
) )
if updates is None: if updates is None:
logger.info("startup.backlog.failed") logger.info("startup.backlog.failed")
@@ -394,7 +423,7 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int |
async def poll_updates( async def poll_updates(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
) -> AsyncIterator[TelegramIncomingMessage]: ) -> AsyncIterator[TelegramIncomingUpdate]:
offset: int | None = None offset: int | None = None
offset = await _drain_backlog(cfg, offset) offset = await _drain_backlog(cfg, offset)
await _send_startup(cfg) await _send_startup(cfg)
@@ -571,6 +600,31 @@ async def _handle_cancel(
running_task.cancel_requested.set() running_task.cancel_requested.set()
async def _handle_callback_cancel(
cfg: TelegramBridgeConfig,
query: TelegramCallbackQuery,
running_tasks: RunningTasks,
) -> None:
progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id)
running_task = running_tasks.get(progress_ref)
if running_task is None:
await cfg.bot.answer_callback_query(
callback_query_id=query.callback_query_id,
text="nothing is currently running for that message.",
)
return
logger.info(
"cancel.requested",
chat_id=query.chat_id,
progress_message_id=query.message_id,
)
running_task.cancel_requested.set()
await cfg.bot.answer_callback_query(
callback_query_id=query.callback_query_id,
text="cancelling...",
)
async def _wait_for_resume(running_task: RunningTask) -> ResumeToken | None: async def _wait_for_resume(running_task: RunningTask) -> ResumeToken | None:
if running_task.resume is not None: if running_task.resume is not None:
return running_task.resume return running_task.resume
@@ -963,9 +1017,9 @@ async def _dispatch_command(
async def run_main_loop( async def run_main_loop(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
poller: Callable[[TelegramBridgeConfig], AsyncIterator[TelegramIncomingMessage]] = ( poller: Callable[
poll_updates [TelegramBridgeConfig], AsyncIterator[TelegramIncomingUpdate]
), ] = poll_updates,
*, *,
watch_config: bool | None = None, watch_config: bool | None = None,
default_engine_override: str | None = None, default_engine_override: str | None = None,
@@ -1061,6 +1115,15 @@ async def run_main_loop(
scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job) scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
async for msg in poller(cfg): async for msg in poller(cfg):
if isinstance(msg, TelegramCallbackQuery):
if msg.data == CANCEL_CALLBACK_DATA:
tg.start_soon(_handle_callback_cancel, cfg, msg, running_tasks)
else:
tg.start_soon(
cfg.bot.answer_callback_query,
msg.callback_query_id,
)
continue
text = msg.text text = msg.text
if msg.voice is not None: if msg.voice is not None:
text = await _transcribe_voice(cfg, msg) text = await _transcribe_voice(cfg, msg)
+120 -5
View File
@@ -19,7 +19,12 @@ import httpx
import anyio import anyio
from ..logging import get_logger from ..logging import get_logger
from .types import TelegramIncomingMessage, TelegramVoice from .types import (
TelegramCallbackQuery,
TelegramIncomingMessage,
TelegramIncomingUpdate,
TelegramVoice,
)
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -49,10 +54,26 @@ def parse_incoming_update(
*, *,
chat_id: int | None = None, chat_id: int | None = None,
chat_ids: set[int] | None = None, chat_ids: set[int] | None = None,
) -> TelegramIncomingMessage | None: ) -> TelegramIncomingUpdate | None:
msg = update.get("message") msg = update.get("message")
if not isinstance(msg, dict): if isinstance(msg, dict):
return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids)
callback_query = update.get("callback_query")
if isinstance(callback_query, dict):
return _parse_callback_query(
callback_query,
chat_id=chat_id,
chat_ids=chat_ids,
)
return None return None
def _parse_incoming_message(
msg: dict[str, Any],
*,
chat_id: int | None = None,
chat_ids: set[int] | None = None,
) -> TelegramIncomingMessage | None:
text = msg.get("text") text = msg.get("text")
voice_payload: TelegramVoice | None = None voice_payload: TelegramVoice | None = None
if not isinstance(text, str): if not isinstance(text, str):
@@ -123,16 +144,62 @@ def parse_incoming_update(
) )
def _parse_callback_query(
query: dict[str, Any],
*,
chat_id: int | None = None,
chat_ids: set[int] | None = None,
) -> TelegramCallbackQuery | None:
callback_id = query.get("id")
if not isinstance(callback_id, str) or not callback_id:
return None
msg = query.get("message")
if not isinstance(msg, dict):
return None
chat = msg.get("chat")
if not isinstance(chat, dict):
return None
msg_chat_id = chat.get("id")
if not isinstance(msg_chat_id, int):
return None
allowed = chat_ids
if allowed is None and chat_id is not None:
allowed = {chat_id}
if allowed is not None and msg_chat_id not in allowed:
return None
message_id = msg.get("message_id")
if not isinstance(message_id, int):
return None
data = query.get("data") if isinstance(query.get("data"), str) else None
sender = query.get("from")
sender_id = (
sender.get("id")
if isinstance(sender, dict) and isinstance(sender.get("id"), int)
else None
)
return TelegramCallbackQuery(
transport="telegram",
chat_id=msg_chat_id,
message_id=message_id,
callback_query_id=callback_id,
data=data,
sender_id=sender_id,
raw=query,
)
async def poll_incoming( async def poll_incoming(
bot: BotClient, bot: BotClient,
*, *,
chat_id: int | None = None, chat_id: int | None = None,
chat_ids: Iterable[int] | Callable[[], Iterable[int]] | None = None, chat_ids: Iterable[int] | Callable[[], Iterable[int]] | None = None,
offset: int | None = None, offset: int | None = None,
) -> AsyncIterator[TelegramIncomingMessage]: ) -> AsyncIterator[TelegramIncomingUpdate]:
while True: while True:
updates = await bot.get_updates( updates = await bot.get_updates(
offset=offset, timeout_s=50, allowed_updates=["message"] offset=offset,
timeout_s=50,
allowed_updates=["message", "callback_query"],
) )
if updates is None: if updates is None:
logger.info("loop.get_updates.failed") logger.info("loop.get_updates.failed")
@@ -172,6 +239,7 @@ class BotClient(Protocol):
disable_notification: bool | None = False, disable_notification: bool | None = False,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict[str, Any] | None = None,
*, *,
replace_message_id: int | None = None, replace_message_id: int | None = None,
) -> dict | None: ... ) -> dict | None: ...
@@ -183,6 +251,7 @@ class BotClient(Protocol):
text: str, text: str,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict[str, Any] | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict | None: ... ) -> dict | None: ...
@@ -203,6 +272,13 @@ class BotClient(Protocol):
async def get_me(self) -> dict | None: ... async def get_me(self) -> dict | None: ...
async def answer_callback_query(
self,
callback_query_id: str,
text: str | None = None,
show_alert: bool | None = None,
) -> bool: ...
if TYPE_CHECKING: if TYPE_CHECKING:
from anyio.abc import TaskGroup from anyio.abc import TaskGroup
@@ -647,6 +723,7 @@ class TelegramClient:
disable_notification: bool | None = False, disable_notification: bool | None = False,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict[str, Any] | None = None,
*, *,
replace_message_id: int | None = None, replace_message_id: int | None = None,
) -> dict | None: ) -> dict | None:
@@ -659,6 +736,7 @@ class TelegramClient:
disable_notification=disable_notification, disable_notification=disable_notification,
entities=entities, entities=entities,
parse_mode=parse_mode, parse_mode=parse_mode,
reply_markup=reply_markup,
replace_message_id=replace_message_id, replace_message_id=replace_message_id,
) )
params: dict[str, Any] = {"chat_id": chat_id, "text": text} params: dict[str, Any] = {"chat_id": chat_id, "text": text}
@@ -670,6 +748,8 @@ class TelegramClient:
params["entities"] = entities params["entities"] = entities
if parse_mode is not None: if parse_mode is not None:
params["parse_mode"] = parse_mode params["parse_mode"] = parse_mode
if reply_markup is not None:
params["reply_markup"] = reply_markup
result = await self._post("sendMessage", params) result = await self._post("sendMessage", params)
return result if isinstance(result, dict) else None return result if isinstance(result, dict) else None
@@ -697,6 +777,7 @@ class TelegramClient:
text: str, text: str,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict[str, Any] | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict | None: ) -> dict | None:
@@ -708,6 +789,7 @@ class TelegramClient:
text=text, text=text,
entities=entities, entities=entities,
parse_mode=parse_mode, parse_mode=parse_mode,
reply_markup=reply_markup,
wait=wait, wait=wait,
) )
params: dict[str, Any] = { params: dict[str, Any] = {
@@ -719,6 +801,8 @@ class TelegramClient:
params["entities"] = entities params["entities"] = entities
if parse_mode is not None: if parse_mode is not None:
params["parse_mode"] = parse_mode params["parse_mode"] = parse_mode
if reply_markup is not None:
params["reply_markup"] = reply_markup
result = await self._post("editMessageText", params) result = await self._post("editMessageText", params)
return result if isinstance(result, dict) else None return result if isinstance(result, dict) else None
@@ -806,3 +890,34 @@ class TelegramClient:
priority=SEND_PRIORITY, priority=SEND_PRIORITY,
chat_id=None, chat_id=None,
) )
async def answer_callback_query(
self,
callback_query_id: str,
text: str | None = None,
show_alert: bool | None = None,
) -> bool:
async def execute() -> bool:
if self._client_override is not None:
return await self._client_override.answer_callback_query(
callback_query_id=callback_query_id,
text=text,
show_alert=show_alert,
)
params: dict[str, Any] = {"callback_query_id": callback_query_id}
if text is not None:
params["text"] = text
if show_alert is not None:
params["show_alert"] = show_alert
result = await self._post("answerCallbackQuery", params)
return bool(result)
return bool(
await self.enqueue_op(
key=self.unique_key("answer_callback_query"),
label="answer_callback_query",
execute=execute,
priority=SEND_PRIORITY,
chat_id=None,
)
)
+14
View File
@@ -24,3 +24,17 @@ class TelegramIncomingMessage:
sender_id: int | None sender_id: int | None
voice: TelegramVoice | None = None voice: TelegramVoice | None = None
raw: dict[str, Any] | None = None raw: dict[str, Any] | None = None
@dataclass(frozen=True, slots=True)
class TelegramCallbackQuery:
transport: str
chat_id: int
message_id: int
callback_query_id: str
data: str | None
sender_id: int | None
raw: dict[str, Any] | None = None
TelegramIncomingUpdate = TelegramIncomingMessage | TelegramCallbackQuery
+135 -1
View File
@@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import cast
import anyio import anyio
import pytest import pytest
@@ -8,8 +9,10 @@ import takopi.telegram.bridge as bridge
from takopi.directives import parse_directives from takopi.directives import parse_directives
from takopi.telegram.bridge import ( from takopi.telegram.bridge import (
TelegramBridgeConfig, TelegramBridgeConfig,
TelegramPresenter,
TelegramTransport, TelegramTransport,
_build_bot_commands, _build_bot_commands,
_handle_callback_cancel,
_handle_cancel, _handle_cancel,
_is_cancel_command, _is_cancel_command,
_send_with_resume, _send_with_resume,
@@ -20,10 +23,11 @@ from takopi.config import ProjectConfig, ProjectsConfig, empty_projects_config
from takopi.runner_bridge import ExecBridgeConfig, RunningTask from takopi.runner_bridge import ExecBridgeConfig, RunningTask
from takopi.markdown import MarkdownPresenter from takopi.markdown import MarkdownPresenter
from takopi.model import EngineId, ResumeToken from takopi.model import EngineId, ResumeToken
from takopi.progress import ProgressTracker
from takopi.router import AutoRouter, RunnerEntry from takopi.router import AutoRouter, RunnerEntry
from takopi.transport_runtime import TransportRuntime from takopi.transport_runtime import TransportRuntime
from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait
from takopi.telegram.types import TelegramIncomingMessage from takopi.telegram.types import TelegramCallbackQuery, TelegramIncomingMessage
from takopi.transport import MessageRef, RenderedMessage, SendOptions from takopi.transport import MessageRef, RenderedMessage, SendOptions
from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints
@@ -91,6 +95,7 @@ class _FakeTransport:
class _FakeBot: class _FakeBot:
def __init__(self) -> None: def __init__(self) -> None:
self.command_calls: list[dict] = [] self.command_calls: list[dict] = []
self.callback_calls: list[dict] = []
self.send_calls: list[dict] = [] self.send_calls: list[dict] = []
self.edit_calls: list[dict] = [] self.edit_calls: list[dict] = []
self.delete_calls: list[dict] = [] self.delete_calls: list[dict] = []
@@ -122,6 +127,7 @@ class _FakeBot:
disable_notification: bool | None = False, disable_notification: bool | None = False,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict | None = None,
*, *,
replace_message_id: int | None = None, replace_message_id: int | None = None,
) -> dict: ) -> dict:
@@ -133,6 +139,7 @@ class _FakeBot:
"disable_notification": disable_notification, "disable_notification": disable_notification,
"entities": entities, "entities": entities,
"parse_mode": parse_mode, "parse_mode": parse_mode,
"reply_markup": reply_markup,
"replace_message_id": replace_message_id, "replace_message_id": replace_message_id,
} }
) )
@@ -145,6 +152,7 @@ class _FakeBot:
text: str, text: str,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict: ) -> dict:
@@ -155,6 +163,7 @@ class _FakeBot:
"text": text, "text": text,
"entities": entities, "entities": entities,
"parse_mode": parse_mode, "parse_mode": parse_mode,
"reply_markup": reply_markup,
"wait": wait, "wait": wait,
} }
) )
@@ -186,6 +195,21 @@ class _FakeBot:
async def close(self) -> None: async def close(self) -> None:
return None return None
async def answer_callback_query(
self,
callback_query_id: str,
text: str | None = None,
show_alert: bool | None = None,
) -> bool:
self.callback_calls.append(
{
"callback_query_id": callback_query_id,
"text": text,
"show_alert": show_alert,
}
)
return True
def _make_cfg( def _make_cfg(
transport: _FakeTransport, runner: ScriptRunner | None = None transport: _FakeTransport, runner: ScriptRunner | None = None
@@ -356,6 +380,35 @@ def test_build_bot_commands_caps_total() -> None:
assert any(cmd["command"] == "cancel" for cmd in commands) assert any(cmd["command"] == "cancel" for cmd in commands)
def test_telegram_presenter_progress_shows_cancel_button() -> None:
presenter = TelegramPresenter()
state = ProgressTracker(engine="codex").snapshot()
rendered = presenter.render_progress(state, elapsed_s=0.0)
reply_markup = rendered.extra["reply_markup"]
assert reply_markup["inline_keyboard"][0][0]["text"] == "cancel"
assert reply_markup["inline_keyboard"][0][0]["callback_data"] == "takopi:cancel"
def test_telegram_presenter_clears_button_on_cancelled() -> None:
presenter = TelegramPresenter()
state = ProgressTracker(engine="codex").snapshot()
rendered = presenter.render_progress(state, elapsed_s=0.0, label="`cancelled`")
assert rendered.extra["reply_markup"]["inline_keyboard"] == []
def test_telegram_presenter_final_clears_button() -> None:
presenter = TelegramPresenter()
state = ProgressTracker(engine="codex").snapshot()
rendered = presenter.render_final(state, elapsed_s=0.0, status="done", answer="ok")
assert rendered.extra["reply_markup"]["inline_keyboard"] == []
@pytest.mark.anyio @pytest.mark.anyio
async def test_telegram_transport_passes_replace_and_wait() -> None: async def test_telegram_transport_passes_replace_and_wait() -> None:
bot = _FakeBot() bot = _FakeBot()
@@ -380,6 +433,28 @@ async def test_telegram_transport_passes_replace_and_wait() -> None:
assert bot.edit_calls[0]["wait"] is False assert bot.edit_calls[0]["wait"] is False
@pytest.mark.anyio
async def test_telegram_transport_passes_reply_markup() -> None:
bot = _FakeBot()
transport = TelegramTransport(bot)
markup = {"inline_keyboard": []}
await transport.send(
channel_id=123,
message=RenderedMessage(text="hello", extra={"reply_markup": markup}),
)
assert bot.send_calls
assert bot.send_calls[0]["reply_markup"] == markup
ref = MessageRef(channel_id=123, message_id=1)
await transport.edit(
ref=ref,
message=RenderedMessage(text="edit", extra={"reply_markup": markup}),
)
assert bot.edit_calls
assert bot.edit_calls[0]["reply_markup"] == markup
@pytest.mark.anyio @pytest.mark.anyio
async def test_telegram_transport_edit_wait_false_returns_ref() -> None: async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
class _OutboxBot: class _OutboxBot:
@@ -410,9 +485,11 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
disable_notification: bool | None = False, disable_notification: bool | None = False,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict | None = None,
*, *,
replace_message_id: int | None = None, replace_message_id: int | None = None,
) -> dict | None: ) -> dict | None:
_ = reply_markup
return None return None
async def edit_message_text( async def edit_message_text(
@@ -422,6 +499,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
text: str, text: str,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict | None: ) -> dict | None:
@@ -432,6 +510,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
"text": text, "text": text,
"entities": entities, "entities": entities,
"parse_mode": parse_mode, "parse_mode": parse_mode,
"reply_markup": reply_markup,
"wait": wait, "wait": wait,
} }
) )
@@ -461,6 +540,15 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
async def close(self) -> None: async def close(self) -> None:
return None return None
async def answer_callback_query(
self,
callback_query_id: str,
text: str | None = None,
show_alert: bool | None = None,
) -> bool:
_ = callback_query_id, text, show_alert
return True
bot = _OutboxBot() bot = _OutboxBot()
transport = TelegramTransport(bot) transport = TelegramTransport(bot)
ref = MessageRef(channel_id=123, message_id=1) ref = MessageRef(channel_id=123, message_id=1)
@@ -590,6 +678,52 @@ async def test_handle_cancel_only_cancels_matching_progress_message() -> None:
assert len(transport.send_calls) == 0 assert len(transport.send_calls) == 0
@pytest.mark.anyio
async def test_handle_callback_cancel_cancels_running_task() -> None:
transport = _FakeTransport()
cfg = _make_cfg(transport)
progress_id = 42
running_task = RunningTask()
running_tasks = {MessageRef(channel_id=123, message_id=progress_id): running_task}
query = TelegramCallbackQuery(
transport="telegram",
chat_id=123,
message_id=progress_id,
callback_query_id="cbq-1",
data="takopi:cancel",
sender_id=123,
)
await _handle_callback_cancel(cfg, query, running_tasks)
assert running_task.cancel_requested.is_set() is True
assert len(transport.send_calls) == 0
bot = cast(_FakeBot, cfg.bot)
assert bot.callback_calls
assert bot.callback_calls[-1]["text"] == "cancelling..."
@pytest.mark.anyio
async def test_handle_callback_cancel_without_task_acknowledges() -> None:
transport = _FakeTransport()
cfg = _make_cfg(transport)
query = TelegramCallbackQuery(
transport="telegram",
chat_id=123,
message_id=99,
callback_query_id="cbq-2",
data="takopi:cancel",
sender_id=123,
)
await _handle_callback_cancel(cfg, query, {})
assert len(transport.send_calls) == 0
bot = cast(_FakeBot, cfg.bot)
assert bot.callback_calls
assert "nothing is currently running" in bot.callback_calls[-1]["text"].lower()
def test_cancel_command_accepts_extra_text() -> None: def test_cancel_command_accepts_extra_text() -> None:
assert _is_cancel_command("/cancel now") is True assert _is_cancel_command("/cancel now") is True
assert _is_cancel_command("/cancel@takopi please") is True assert _is_cancel_command("/cancel@takopi please") is True
+31 -1
View File
@@ -1,4 +1,8 @@
from takopi.telegram import parse_incoming_update from takopi.telegram import (
TelegramCallbackQuery,
TelegramIncomingMessage,
parse_incoming_update,
)
def test_parse_incoming_update_maps_fields() -> None: def test_parse_incoming_update_maps_fields() -> None:
@@ -15,6 +19,7 @@ def test_parse_incoming_update_maps_fields() -> None:
msg = parse_incoming_update(update, chat_id=123) msg = parse_incoming_update(update, chat_id=123)
assert msg is not None assert msg is not None
assert isinstance(msg, TelegramIncomingMessage)
assert msg.transport == "telegram" assert msg.transport == "telegram"
assert msg.chat_id == 123 assert msg.chat_id == 123
assert msg.message_id == 10 assert msg.message_id == 10
@@ -66,9 +71,34 @@ def test_parse_incoming_update_voice_message() -> None:
msg = parse_incoming_update(update, chat_id=123) msg = parse_incoming_update(update, chat_id=123)
assert msg is not None assert msg is not None
assert isinstance(msg, TelegramIncomingMessage)
assert msg.text == "" assert msg.text == ""
assert msg.voice is not None assert msg.voice is not None
assert msg.voice.file_id == "voice-id" assert msg.voice.file_id == "voice-id"
assert msg.voice.mime_type == "audio/ogg" assert msg.voice.mime_type == "audio/ogg"
assert msg.voice.file_size == 1234 assert msg.voice.file_size == 1234
assert msg.voice.duration == 3 assert msg.voice.duration == 3
def test_parse_incoming_update_callback_query() -> None:
update = {
"update_id": 1,
"callback_query": {
"id": "cbq-1",
"data": "takopi:cancel",
"from": {"id": 321},
"message": {
"message_id": 55,
"chat": {"id": 123},
},
},
}
msg = parse_incoming_update(update, chat_id=123)
assert isinstance(msg, TelegramCallbackQuery)
assert msg.transport == "telegram"
assert msg.chat_id == 123
assert msg.message_id == 55
assert msg.callback_query_id == "cbq-1"
assert msg.data == "takopi:cancel"
assert msg.sender_id == 321
+15
View File
@@ -22,6 +22,7 @@ class _FakeBot:
disable_notification: bool | None = False, disable_notification: bool | None = False,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict | None = None,
*, *,
replace_message_id: int | None = None, replace_message_id: int | None = None,
) -> dict: ) -> dict:
@@ -29,6 +30,7 @@ class _FakeBot:
_ = disable_notification _ = disable_notification
_ = entities _ = entities
_ = parse_mode _ = parse_mode
_ = reply_markup
_ = replace_message_id _ = replace_message_id
self.calls.append("send_message") self.calls.append("send_message")
return {"message_id": 1} return {"message_id": 1}
@@ -40,6 +42,7 @@ class _FakeBot:
text: str, text: str,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict: ) -> dict:
@@ -47,6 +50,7 @@ class _FakeBot:
_ = message_id _ = message_id
_ = entities _ = entities
_ = parse_mode _ = parse_mode
_ = reply_markup
_ = wait _ = wait
self.calls.append("edit_message_text") self.calls.append("edit_message_text")
self.edit_calls.append(text) self.edit_calls.append(text)
@@ -106,6 +110,15 @@ class _FakeBot:
async def get_me(self) -> dict | None: async def get_me(self) -> dict | None:
return {"id": 1} return {"id": 1}
async def answer_callback_query(
self,
callback_query_id: str,
text: str | None = None,
show_alert: bool | None = None,
) -> bool:
_ = callback_query_id, text, show_alert
return True
@pytest.mark.anyio @pytest.mark.anyio
async def test_edits_coalesce_latest() -> None: async def test_edits_coalesce_latest() -> None:
@@ -123,6 +136,7 @@ async def test_edits_coalesce_latest() -> None:
text: str, text: str,
entities: list[dict] | None = None, entities: list[dict] | None = None,
parse_mode: str | None = None, parse_mode: str | None = None,
reply_markup: dict | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict: ) -> dict:
@@ -136,6 +150,7 @@ async def test_edits_coalesce_latest() -> None:
text=text, text=text,
entities=entities, entities=entities,
parse_mode=parse_mode, parse_mode=parse_mode,
reply_markup=reply_markup,
wait=wait, wait=wait,
) )