diff --git a/src/takopi/cli.py b/src/takopi/cli.py index d1cbe63..02f180e 100644 --- a/src/takopi/cli.py +++ b/src/takopi/cli.py @@ -105,11 +105,7 @@ def _default_engine_for_setup( if settings is None or config_path is None: return "codex" value = settings.default_engine - if not isinstance(value, str) or not value.strip(): - raise ConfigError( - f"Invalid `default_engine` in {config_path}; expected a non-empty string." - ) - return value.strip() + return value def _config_path_display(path: Path) -> str: @@ -235,9 +231,12 @@ def _run_auto_router( default_engine_override=default_engine_override, reserved=("cancel",), ) - transport_config = settings.transport_config( - settings.transport, config_path=config_path - ) + if settings.transport == "telegram": + transport_config = settings.transports.telegram + else: + transport_config = settings.transport_config( + settings.transport, config_path=config_path + ) lock_token = transport_backend.lock_token( transport_config=transport_config, config_path=config_path, diff --git a/src/takopi/router.py b/src/takopi/router.py index f08804b..c94a4a0 100644 --- a/src/takopi/router.py +++ b/src/takopi/router.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Iterable +from typing import Iterable, Literal, TypeAlias from .model import EngineId, ResumeToken from .runner import Runner @@ -17,13 +17,22 @@ class RunnerUnavailableError(RuntimeError): self.issue = issue +EngineStatus: TypeAlias = Literal["ok", "missing_cli", "bad_config", "load_error"] + + @dataclass(frozen=True, slots=True) class RunnerEntry: engine: EngineId runner: Runner - available: bool = True + status: EngineStatus = "ok" issue: str | None = None + @property + def available(self) -> bool: + # "bad_config" means we ignored user config and built the runner with defaults. + # The engine is still runnable, but a warning should be surfaced to the user. + return self.status in {"ok", "bad_config"} + class AutoRouter: def __init__( diff --git a/src/takopi/runtime_loader.py b/src/takopi/runtime_loader.py index acbba3b..afecb65 100644 --- a/src/takopi/runtime_loader.py +++ b/src/takopi/runtime_loader.py @@ -9,7 +9,7 @@ from .backends import EngineBackend from .config import ConfigError, ProjectsConfig from .engines import get_backend, list_backend_ids from .logging import get_logger -from .router import AutoRouter, RunnerEntry +from .router import AutoRouter, EngineStatus, RunnerEntry from .settings import TakopiSettings from .transport_runtime import TransportRuntime @@ -83,6 +83,7 @@ def build_router( for backend in backends: engine_id = backend.id issue: str | None = None + status: EngineStatus = "ok" engine_cfg: dict try: engine_cfg = settings.engine_config(engine_id, config_path=config_path) @@ -90,6 +91,7 @@ def build_router( if engine_id == default_engine: raise issue = str(exc) + status = "bad_config" engine_cfg = {} try: @@ -104,26 +106,31 @@ def build_router( except Exception as fallback_exc: warnings.append(f"{engine_id}: {issue or str(fallback_exc)}") continue + status = "bad_config" else: + status = "load_error" warnings.append(f"{engine_id}: {issue}") continue cmd = backend.cli_cmd or backend.id if shutil.which(cmd) is None: - issue = issue or f"{cmd} not found on PATH" + status = "missing_cli" + if issue: + issue = f"{issue}; {cmd} not found on PATH" + else: + issue = f"{cmd} not found on PATH" - if issue and engine_id == default_engine: + if status != "ok" and engine_id == default_engine: raise ConfigError(f"Default engine {engine_id!r} unavailable: {issue}") - available = issue is None - if issue and engine_id != default_engine: + if status != "ok" and engine_id != default_engine: warnings.append(f"{engine_id}: {issue}") entries.append( RunnerEntry( engine=engine_id, runner=runner, - available=available, + status=status, issue=issue, ) ) diff --git a/src/takopi/settings.py b/src/takopi/settings.py index 5cc2467..ea8ac1a 100644 --- a/src/takopi/settings.py +++ b/src/takopi/settings.py @@ -83,6 +83,14 @@ class TelegramFilesSettings(BaseModel): raise ValueError("files.uploads_dir must be a relative path") return value + @property + def max_upload_bytes(self) -> int: + return 20 * 1024 * 1024 + + @property + def max_download_bytes(self) -> int: + return 50 * 1024 * 1024 + class TelegramTransportSettings(BaseModel): model_config = ConfigDict(extra="forbid", str_strip_whitespace=True) @@ -90,14 +98,13 @@ class TelegramTransportSettings(BaseModel): bot_token: NonEmptyStr chat_id: StrictInt voice_transcription: bool = False + voice_max_bytes: StrictInt = 10 * 1024 * 1024 topics: TelegramTopicsSettings = Field(default_factory=TelegramTopicsSettings) files: TelegramFilesSettings = Field(default_factory=TelegramFilesSettings) class TransportsSettings(BaseModel): - telegram: TelegramTransportSettings = Field( - default_factory=TelegramTransportSettings - ) + telegram: TelegramTransportSettings model_config = ConfigDict(extra="allow") @@ -132,7 +139,7 @@ class TakopiSettings(BaseSettings): projects: dict[str, ProjectSettings] = Field(default_factory=dict) transport: NonEmptyStr = "telegram" - transports: TransportsSettings = Field(default_factory=TransportsSettings) + transports: TransportsSettings plugins: PluginsSettings = Field(default_factory=PluginsSettings) @@ -310,8 +317,6 @@ def require_telegram(settings: TakopiSettings, config_path: Path) -> tuple[str, "(telegram only for now)." ) tg = settings.transports.telegram - if not tg.bot_token: - raise ConfigError(f"Missing bot token in {config_path}.") return tg.bot_token, tg.chat_id diff --git a/src/takopi/telegram/api_models.py b/src/takopi/telegram/api_models.py new file mode 100644 index 0000000..1497afe --- /dev/null +++ b/src/takopi/telegram/api_models.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import Any + +import msgspec + +__all__ = [ + "Chat", + "ChatMember", + "File", + "ForumTopic", + "Message", + "Update", + "User", +] + + +class User(msgspec.Struct, forbid_unknown_fields=False): + id: int + username: str | None = None + first_name: str | None = None + last_name: str | None = None + + +class Chat(msgspec.Struct, forbid_unknown_fields=False): + id: int + type: str + is_forum: bool | None = None + + +class ChatMember(msgspec.Struct, forbid_unknown_fields=False): + status: str + can_manage_topics: bool | None = None + + +class Message(msgspec.Struct, forbid_unknown_fields=False): + message_id: int + message_thread_id: int | None = None + text: str | None = None + + +class File(msgspec.Struct, forbid_unknown_fields=False): + file_path: str + + +class ForumTopic(msgspec.Struct, forbid_unknown_fields=False): + message_thread_id: int + + +class Update(msgspec.Struct, forbid_unknown_fields=False): + update_id: int + message: dict[str, Any] | None = None + callback_query: dict[str, Any] | None = None diff --git a/src/takopi/telegram/backend.py b/src/takopi/telegram/backend.py index 4916ab1..1c748ef 100644 --- a/src/takopi/telegram/backend.py +++ b/src/takopi/telegram/backend.py @@ -2,22 +2,19 @@ from __future__ import annotations import os from pathlib import Path -from typing import cast import anyio from ..backends import EngineBackend -from ..runner_bridge import ExecBridgeConfig from ..logging import get_logger - -from ..transports import SetupResult, TransportBackend +from ..runner_bridge import ExecBridgeConfig +from ..settings import TelegramTransportSettings from ..transport_runtime import TransportRuntime +from ..transports import SetupResult, TransportBackend from .bridge import ( TelegramBridgeConfig, TelegramPresenter, TelegramTransport, - TelegramFilesConfig, - TelegramTopicsConfig, run_main_loop, ) from .client import TelegramClient @@ -26,6 +23,12 @@ from .onboarding import check_setup, interactive_setup logger = get_logger(__name__) +def _expect_transport_settings(transport_config: object) -> TelegramTransportSettings: + if isinstance(transport_config, TelegramTransportSettings): + return transport_config + raise TypeError("transport_config must be TelegramTransportSettings") + + def _build_startup_message( runtime: TransportRuntime, *, @@ -33,9 +36,20 @@ def _build_startup_message( ) -> str: available_engines = list(runtime.available_engine_ids()) missing_engines = list(runtime.missing_engine_ids()) + misconfigured_engines = list(runtime.engine_ids_with_status("bad_config")) + failed_engines = list(runtime.engine_ids_with_status("load_error")) + engine_list = ", ".join(available_engines) if available_engines else "none" + + notes: list[str] = [] if missing_engines: - engine_list = f"{engine_list} (not installed: {', '.join(missing_engines)})" + notes.append(f"not installed: {', '.join(missing_engines)}") + if misconfigured_engines: + notes.append(f"misconfigured: {', '.join(misconfigured_engines)}") + if failed_engines: + notes.append(f"failed to load: {', '.join(failed_engines)}") + if notes: + engine_list = f"{engine_list} ({'; '.join(notes)})" project_aliases = sorted( {alias for alias in runtime.project_aliases()}, key=str.lower ) @@ -49,34 +63,6 @@ def _build_startup_message( ) -def _build_topics_config(transport_config: dict[str, object]) -> TelegramTopicsConfig: - raw = cast(dict[str, object], transport_config.get("topics", {})) - return TelegramTopicsConfig( - enabled=cast(bool, raw.get("enabled", False)), - scope=cast(str, raw.get("scope", "auto")), - ) - - -def _build_files_config(transport_config: dict[str, object]) -> TelegramFilesConfig: - defaults = TelegramFilesConfig() - raw = cast(dict[str, object], transport_config.get("files", {})) - return TelegramFilesConfig( - enabled=cast(bool, raw.get("enabled", defaults.enabled)), - auto_put=cast(bool, raw.get("auto_put", defaults.auto_put)), - uploads_dir=cast(str, raw.get("uploads_dir", defaults.uploads_dir)), - max_upload_bytes=defaults.max_upload_bytes, - max_download_bytes=defaults.max_download_bytes, - allowed_user_ids=frozenset( - cast( - list[int], raw.get("allowed_user_ids", list(defaults.allowed_user_ids)) - ) - ), - deny_globs=tuple( - cast(list[str], raw.get("deny_globs", list(defaults.deny_globs))) - ), - ) - - class TelegramBackend(TransportBackend): id = "telegram" description = "Telegram bot" @@ -92,23 +78,23 @@ class TelegramBackend(TransportBackend): def interactive_setup(self, *, force: bool) -> bool: return interactive_setup(force=force) - def lock_token( - self, *, transport_config: dict[str, object], config_path: Path - ) -> str | None: + def lock_token(self, *, transport_config: object, config_path: Path) -> str | None: _ = config_path - return cast(str, transport_config.get("bot_token")) + settings = _expect_transport_settings(transport_config) + return settings.bot_token def build_and_run( self, *, - transport_config: dict[str, object], + transport_config: object, config_path: Path, runtime: TransportRuntime, final_notify: bool, default_engine_override: str | None, ) -> None: - token = cast(str, transport_config.get("bot_token")) - chat_id = cast(int, transport_config.get("chat_id")) + settings = _expect_transport_settings(transport_config) + token = settings.bot_token + chat_id = settings.chat_id startup_msg = _build_startup_message( runtime, startup_pwd=os.getcwd(), @@ -121,19 +107,16 @@ class TelegramBackend(TransportBackend): presenter=presenter, final_notify=final_notify, ) - topics = _build_topics_config(transport_config) - files = _build_files_config(transport_config) cfg = TelegramBridgeConfig( bot=bot, runtime=runtime, chat_id=chat_id, startup_msg=startup_msg, exec_cfg=exec_cfg, - voice_transcription=cast( - bool, transport_config.get("voice_transcription", False) - ), - topics=topics, - files=files, + voice_transcription=settings.voice_transcription, + voice_max_bytes=int(settings.voice_max_bytes), + topics=settings.topics, + files=settings.files, ) async def run_loop() -> None: @@ -142,7 +125,7 @@ class TelegramBackend(TransportBackend): watch_config=runtime.watch_config, default_engine_override=default_engine_override, transport_id=self.id, - transport_config=transport_config, + transport_config=settings, ) anyio.run(run_loop) diff --git a/src/takopi/telegram/bridge.py b/src/takopi/telegram/bridge.py index 1f1c526..e904be9 100644 --- a/src/takopi/telegram/bridge.py +++ b/src/takopi/telegram/bridge.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Awaitable, Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import cast from ..logging import get_logger @@ -12,6 +12,11 @@ from ..transport import MessageRef, RenderedMessage, SendOptions, Transport from ..transport_runtime import TransportRuntime from ..context import RunContext from ..model import ResumeToken +from ..settings import ( + TelegramFilesSettings, + TelegramTopicsSettings, + TelegramTransportSettings, +) from .client import BotClient from .render import prepare_telegram from .types import TelegramCallbackQuery, TelegramIncomingMessage @@ -20,8 +25,6 @@ logger = get_logger(__name__) __all__ = [ "TelegramBridgeConfig", - "TelegramFilesConfig", - "TelegramTopicsConfig", "TelegramPresenter", "TelegramTransport", "build_bot_commands", @@ -85,29 +88,6 @@ def _is_cancelled_label(label: str) -> bool: return stripped.lower() == "cancelled" -@dataclass(frozen=True) -class TelegramFilesConfig: - enabled: bool = False - auto_put: bool = True - uploads_dir: str = "incoming" - max_upload_bytes: int = 20 * 1024 * 1024 - max_download_bytes: int = 50 * 1024 * 1024 - allowed_user_ids: frozenset[int] = frozenset() - deny_globs: tuple[str, ...] = ( - ".git/**", - ".env", - ".envrc", - "**/*.pem", - "**/.ssh/**", - ) - - -@dataclass(frozen=True) -class TelegramTopicsConfig: - enabled: bool = False - scope: str = "auto" - - @dataclass(frozen=True) class TelegramBridgeConfig: bot: BotClient @@ -116,9 +96,10 @@ class TelegramBridgeConfig: startup_msg: str exec_cfg: ExecBridgeConfig voice_transcription: bool = False - files: TelegramFilesConfig = TelegramFilesConfig() + voice_max_bytes: int = 10 * 1024 * 1024 + files: TelegramFilesSettings = field(default_factory=TelegramFilesSettings) chat_ids: tuple[int, ...] | None = None - topics: TelegramTopicsConfig = TelegramTopicsConfig() + topics: TelegramTopicsSettings = field(default_factory=TelegramTopicsSettings) class TelegramTransport: @@ -166,7 +147,7 @@ class TelegramTransport: ) if sent is None: return None - message_id = cast(int, sent["message_id"]) + message_id = sent.message_id return MessageRef( channel_id=chat_id, message_id=message_id, @@ -192,7 +173,7 @@ class TelegramTransport: ) if edited is None: return ref if not wait else None - message_id = cast(int, edited.get("message_id", message_id)) + message_id = edited.message_id return MessageRef( channel_id=chat_id, message_id=message_id, @@ -287,7 +268,7 @@ async def run_main_loop( watch_config: bool | None = None, default_engine_override: str | None = None, transport_id: str | None = None, - transport_config: dict[str, object] | None = None, + transport_config: TelegramTransportSettings | None = None, ) -> None: from .loop import run_main_loop as _run_main_loop diff --git a/src/takopi/telegram/client.py b/src/takopi/telegram/client.py index f64d9d9..46dca64 100644 --- a/src/takopi/telegram/client.py +++ b/src/takopi/telegram/client.py @@ -12,13 +12,16 @@ from typing import ( Iterable, Protocol, TYPE_CHECKING, + TypeVar, ) +import msgspec import httpx import anyio from ..logging import get_logger +from .api_models import Chat, ChatMember, File, ForumTopic, Message, Update, User from .types import ( TelegramCallbackQuery, TelegramDocument, @@ -29,6 +32,8 @@ from .types import ( logger = get_logger(__name__) +T = TypeVar("T") + SEND_PRIORITY = 0 DELETE_PRIORITY = 1 @@ -51,15 +56,20 @@ def is_group_chat_id(chat_id: int) -> bool: def parse_incoming_update( - update: dict[str, Any], + update: Update | dict[str, Any], *, chat_id: int | None = None, chat_ids: set[int] | None = None, ) -> TelegramIncomingUpdate | None: - msg = update.get("message") + if isinstance(update, Update): + msg = update.message + callback_query = update.callback_query + else: + msg = update.get("message") + callback_query = update.get("callback_query") + if isinstance(msg, dict): return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids) - callback_query = update.get("callback_query") if isinstance(callback_query, dict): return _parse_callback_query( callback_query, @@ -303,7 +313,7 @@ async def poll_incoming( if allowed is None and chat_id is not None: allowed = {chat_id} for upd in updates: - offset = upd["update_id"] + 1 + offset = upd.update_id + 1 msg = parse_incoming_update(upd, chat_ids=allowed) if msg is not None: yield msg @@ -317,9 +327,9 @@ class BotClient(Protocol): offset: int | None, timeout_s: int = 50, allowed_updates: list[str] | None = None, - ) -> list[dict] | None: ... + ) -> list[Update] | None: ... - async def get_file(self, file_id: str) -> dict | None: ... + async def get_file(self, file_id: str) -> File | None: ... async def download_file(self, file_path: str) -> bytes | None: ... @@ -335,7 +345,7 @@ class BotClient(Protocol): reply_markup: dict[str, Any] | None = None, *, replace_message_id: int | None = None, - ) -> dict | None: ... + ) -> Message | None: ... async def send_document( self, @@ -346,7 +356,7 @@ class BotClient(Protocol): message_thread_id: int | None = None, disable_notification: bool | None = False, caption: str | None = None, - ) -> dict | None: ... + ) -> Message | None: ... async def edit_message_text( self, @@ -358,7 +368,7 @@ class BotClient(Protocol): reply_markup: dict[str, Any] | None = None, *, wait: bool = True, - ) -> dict | None: ... + ) -> Message | None: ... async def delete_message( self, @@ -374,7 +384,7 @@ class BotClient(Protocol): language_code: str | None = None, ) -> bool: ... - async def get_me(self) -> dict | None: ... + async def get_me(self) -> User | None: ... async def answer_callback_query( self, @@ -383,15 +393,17 @@ class BotClient(Protocol): show_alert: bool | None = None, ) -> bool: ... - async def get_chat(self, chat_id: int) -> dict | None: ... + async def get_chat(self, chat_id: int) -> Chat | None: ... - async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None: ... + async def get_chat_member( + self, chat_id: int, user_id: int + ) -> ChatMember | None: ... async def create_forum_topic( self, chat_id: int, name: str, - ) -> dict | None: ... + ) -> ForumTopic | None: ... async def edit_forum_topic( self, @@ -672,71 +684,13 @@ class TelegramClient: if self._owns_http_client and self._http_client is not None: await self._http_client.aclose() - async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None: - if self._http_client is None or self._base is None: - raise RuntimeError("TelegramClient is configured without an HTTP client.") - logger.debug("telegram.request", method=method, payload=json_data) - try: - resp = await self._http_client.post( - f"{self._base}/{method}", json=json_data - ) - except httpx.HTTPError as e: - url = getattr(e.request, "url", None) - logger.error( - "telegram.network_error", - method=method, - url=str(url) if url is not None else None, - error=str(e), - error_type=e.__class__.__name__, - ) - return None - - try: - resp.raise_for_status() - except httpx.HTTPStatusError as e: - if resp.status_code == 429: - retry_after: float | None = None - try: - payload = resp.json() - except Exception: - payload = None - if isinstance(payload, dict): - retry_after = retry_after_from_payload(payload) - retry_after = 5.0 if retry_after is None else retry_after - logger.warning( - "telegram.rate_limited", - method=method, - status=resp.status_code, - url=str(resp.request.url), - retry_after=retry_after, - ) - raise TelegramRetryAfter(retry_after) from e - body = resp.text - logger.error( - "telegram.http_error", - method=method, - status=resp.status_code, - url=str(resp.request.url), - error=str(e), - body=body, - ) - return None - - try: - payload = resp.json() - except Exception as e: - body = resp.text - logger.error( - "telegram.bad_response", - method=method, - status=resp.status_code, - url=str(resp.request.url), - error=str(e), - error_type=e.__class__.__name__, - body=body, - ) - return None - + def _parse_telegram_envelope( + self, + *, + method: str, + resp: httpx.Response, + payload: Any, + ) -> Any | None: if not isinstance(payload, dict): logger.error( "telegram.invalid_payload", @@ -768,169 +722,237 @@ class TelegramClient: logger.debug("telegram.response", method=method, payload=payload) return payload.get("result") + async def _request( + self, + method: str, + *, + json: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, + files: dict[str, Any] | None = None, + ) -> Any | None: + if self._http_client is None or self._base is None: + raise RuntimeError("TelegramClient is configured without an HTTP client.") + request_payload = json if json is not None else data + logger.debug("telegram.request", method=method, payload=request_payload) + try: + if json is not None: + resp = await self._http_client.post(f"{self._base}/{method}", json=json) + else: + resp = await self._http_client.post( + f"{self._base}/{method}", data=data, files=files + ) + except httpx.HTTPError as exc: + url = getattr(exc.request, "url", None) + logger.error( + "telegram.network_error", + method=method, + url=str(url) if url is not None else None, + error=str(exc), + error_type=exc.__class__.__name__, + ) + return None + + try: + resp.raise_for_status() + except httpx.HTTPStatusError as exc: + if resp.status_code == 429: + retry_after: float | None = None + try: + response_payload = resp.json() + except Exception: + response_payload = None + if isinstance(response_payload, dict): + retry_after = retry_after_from_payload(response_payload) + retry_after = 5.0 if retry_after is None else retry_after + logger.warning( + "telegram.rate_limited", + method=method, + status=resp.status_code, + url=str(resp.request.url), + retry_after=retry_after, + ) + raise TelegramRetryAfter(retry_after) from exc + body = resp.text + logger.error( + "telegram.http_error", + method=method, + status=resp.status_code, + url=str(resp.request.url), + error=str(exc), + body=body, + ) + return None + + try: + response_payload = resp.json() + except Exception as exc: + body = resp.text + logger.error( + "telegram.bad_response", + method=method, + status=resp.status_code, + url=str(resp.request.url), + error=str(exc), + error_type=exc.__class__.__name__, + body=body, + ) + return None + + return self._parse_telegram_envelope( + method=method, + resp=resp, + payload=response_payload, + ) + + def _decode_result( + self, + *, + method: str, + payload: Any, + model: type[T], + ) -> T | None: + if payload is None: + return None + try: + return msgspec.convert(payload, type=model) + except Exception as exc: + logger.error( + "telegram.decode_error", + method=method, + error=str(exc), + error_type=exc.__class__.__name__, + ) + return None + + async def _call_with_retry_after( + self, + fn: Callable[[], Awaitable[T]], + ) -> T: + while True: + try: + return await fn() + except TelegramRetryAfter as exc: + await self._sleep(exc.retry_after) + + async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None: + return await self._request(method, json=json_data) + async def _post_form( self, method: str, data: dict[str, Any], files: dict[str, Any], ) -> Any | None: - if self._http_client is None or self._base is None: - raise RuntimeError("TelegramClient is configured without an HTTP client.") - logger.debug("telegram.request", method=method, payload=data) - try: - resp = await self._http_client.post( - f"{self._base}/{method}", data=data, files=files - ) - except httpx.HTTPError as e: - url = getattr(e.request, "url", None) - logger.error( - "telegram.network_error", - method=method, - url=str(url) if url is not None else None, - error=str(e), - error_type=e.__class__.__name__, - ) - return None - - try: - resp.raise_for_status() - except httpx.HTTPStatusError as e: - if resp.status_code == 429: - retry_after: float | None = None - try: - payload = resp.json() - except Exception: - payload = None - if isinstance(payload, dict): - retry_after = retry_after_from_payload(payload) - retry_after = 5.0 if retry_after is None else retry_after - logger.warning( - "telegram.rate_limited", - method=method, - status=resp.status_code, - url=str(resp.request.url), - retry_after=retry_after, - ) - raise TelegramRetryAfter(retry_after) from e - body = resp.text - logger.error( - "telegram.http_error", - method=method, - status=resp.status_code, - url=str(resp.request.url), - error=str(e), - body=body, - ) - return None - - try: - payload = resp.json() - except Exception as e: - body = resp.text - logger.error( - "telegram.bad_response", - method=method, - status=resp.status_code, - error=str(e), - error_type=e.__class__.__name__, - body=body, - ) - return None - - if not isinstance(payload, dict): - logger.error( - "telegram.invalid_payload", - method=method, - url=str(resp.request.url), - payload=payload, - ) - return None - - if not payload.get("ok"): - if payload.get("error_code") == 429: - retry_after = retry_after_from_payload(payload) - retry_after = 5.0 if retry_after is None else retry_after - logger.warning( - "telegram.rate_limited", - method=method, - url=str(resp.request.url), - retry_after=retry_after, - ) - raise TelegramRetryAfter(retry_after) - logger.error( - "telegram.api_error", - method=method, - url=str(resp.request.url), - payload=payload, - ) - return None - - logger.debug("telegram.response", method=method, payload=payload) - return payload.get("result") + return await self._request(method, data=data, files=files) async def get_updates( self, offset: int | None, timeout_s: int = 50, allowed_updates: list[str] | None = None, - ) -> list[dict] | None: - while True: - try: - if self._client_override is not None: - return await self._client_override.get_updates( - offset=offset, - timeout_s=timeout_s, - allowed_updates=allowed_updates, + ) -> list[Update] | None: + async def execute() -> list[Update] | None: + if self._client_override is not None: + raw = await self._client_override.get_updates( + offset=offset, + timeout_s=timeout_s, + allowed_updates=allowed_updates, + ) + if raw is None: + return None + try: + return msgspec.convert(raw, type=list[Update]) + except Exception as exc: + logger.error( + "telegram.decode_error", + method="getUpdates", + error=str(exc), + error_type=exc.__class__.__name__, ) - params: dict[str, Any] = {"timeout": timeout_s} - if offset is not None: - params["offset"] = offset - if allowed_updates is not None: - params["allowed_updates"] = allowed_updates - result = await self._post("getUpdates", params) - return result if isinstance(result, list) else None - except TelegramRetryAfter as exc: - await self._sleep(exc.retry_after) + return None - async def get_file(self, file_id: str) -> dict | None: - while True: + params: dict[str, Any] = {"timeout": timeout_s} + if offset is not None: + params["offset"] = offset + if allowed_updates is not None: + params["allowed_updates"] = allowed_updates + result = await self._post("getUpdates", params) + if result is None or not isinstance(result, list): + return None try: - if self._client_override is not None: - return await self._client_override.get_file(file_id) - result = await self._post("getFile", {"file_id": file_id}) - return result if isinstance(result, dict) else None - except TelegramRetryAfter as exc: - await self._sleep(exc.retry_after) + return msgspec.convert(result, type=list[Update]) + except Exception as exc: + logger.error( + "telegram.decode_error", + method="getUpdates", + error=str(exc), + error_type=exc.__class__.__name__, + ) + return None + + return await self._call_with_retry_after(execute) + + async def get_file(self, file_id: str) -> File | None: + async def execute() -> File | None: + if self._client_override is not None: + return await self._client_override.get_file(file_id) + result = await self._post("getFile", {"file_id": file_id}) + return self._decode_result(method="getFile", payload=result, model=File) + + return await self._call_with_retry_after(execute) async def download_file(self, file_path: str) -> bytes | None: - if self._client_override is not None: - return await self._client_override.download_file(file_path) - if self._http_client is None or self._file_base is None: - raise RuntimeError("TelegramClient is configured without an HTTP client.") - url = f"{self._file_base}/{file_path}" - try: - resp = await self._http_client.get(url) - except httpx.HTTPError as exc: - request_url = getattr(exc.request, "url", None) - logger.error( - "telegram.file_network_error", - url=str(request_url) if request_url is not None else None, - error=str(exc), - error_type=exc.__class__.__name__, - ) - return None - try: - resp.raise_for_status() - except httpx.HTTPStatusError as exc: - logger.error( - "telegram.file_http_error", - status=resp.status_code, - url=str(resp.request.url), - error=str(exc), - body=resp.text, - ) - return None - return resp.content + async def execute() -> bytes | None: + if self._client_override is not None: + return await self._client_override.download_file(file_path) + if self._http_client is None or self._file_base is None: + raise RuntimeError( + "TelegramClient is configured without an HTTP client." + ) + url = f"{self._file_base}/{file_path}" + try: + resp = await self._http_client.get(url) + except httpx.HTTPError as exc: + request_url = getattr(exc.request, "url", None) + logger.error( + "telegram.file_network_error", + url=str(request_url) if request_url is not None else None, + error=str(exc), + error_type=exc.__class__.__name__, + ) + return None + try: + resp.raise_for_status() + except httpx.HTTPStatusError as exc: + if resp.status_code == 429: + retry_after: float | None = None + try: + response_payload = resp.json() + except Exception: + response_payload = None + if isinstance(response_payload, dict): + retry_after = retry_after_from_payload(response_payload) + retry_after = 5.0 if retry_after is None else retry_after + logger.warning( + "telegram.rate_limited", + method="download_file", + status=resp.status_code, + url=str(resp.request.url), + retry_after=retry_after, + ) + raise TelegramRetryAfter(retry_after) from exc + + logger.error( + "telegram.file_http_error", + status=resp.status_code, + url=str(resp.request.url), + error=str(exc), + body=resp.text, + ) + return None + return resp.content + + return await self._call_with_retry_after(execute) async def send_message( self, @@ -944,8 +966,8 @@ class TelegramClient: reply_markup: dict[str, Any] | None = None, *, replace_message_id: int | None = None, - ) -> dict | None: - async def execute() -> dict | None: + ) -> Message | None: + async def execute() -> Message | None: if self._client_override is not None: return await self._client_override.send_message( chat_id=chat_id, @@ -972,7 +994,11 @@ class TelegramClient: if reply_markup is not None: params["reply_markup"] = reply_markup result = await self._post("sendMessage", params) - return result if isinstance(result, dict) else None + return self._decode_result( + method="sendMessage", + payload=result, + model=Message, + ) if replace_message_id is not None: await self._outbox.drop_pending(key=("edit", chat_id, replace_message_id)) @@ -1000,8 +1026,8 @@ class TelegramClient: message_thread_id: int | None = None, disable_notification: bool | None = False, caption: str | None = None, - ) -> dict | None: - async def execute() -> dict | None: + ) -> Message | None: + async def execute() -> Message | None: if self._client_override is not None: return await self._client_override.send_document( chat_id=chat_id, @@ -1026,7 +1052,11 @@ class TelegramClient: params, files={"document": (filename, content)}, ) - return result if isinstance(result, dict) else None + return self._decode_result( + method="sendDocument", + payload=result, + model=Message, + ) return await self.enqueue_op( key=self.unique_key("send_document"), @@ -1046,8 +1076,8 @@ class TelegramClient: reply_markup: dict[str, Any] | None = None, *, wait: bool = True, - ) -> dict | None: - async def execute() -> dict | None: + ) -> Message | None: + async def execute() -> Message | None: if self._client_override is not None: return await self._client_override.edit_message_text( chat_id=chat_id, @@ -1070,7 +1100,11 @@ class TelegramClient: if reply_markup is not None: params["reply_markup"] = reply_markup result = await self._post("editMessageText", params) - return result if isinstance(result, dict) else None + return self._decode_result( + method="editMessageText", + payload=result, + model=Message, + ) return await self.enqueue_op( key=("edit", chat_id, message_id), @@ -1142,12 +1176,12 @@ class TelegramClient: ) ) - async def get_me(self) -> dict | None: - async def execute() -> dict | None: + async def get_me(self) -> User | None: + async def execute() -> User | None: if self._client_override is not None: return await self._client_override.get_me() result = await self._post("getMe", {}) - return result if isinstance(result, dict) else None + return self._decode_result(method="getMe", payload=result, model=User) return await self.enqueue_op( key=self.unique_key("get_me"), @@ -1188,12 +1222,12 @@ class TelegramClient: ) ) - async def get_chat(self, chat_id: int) -> dict | None: - async def execute() -> dict | None: + async def get_chat(self, chat_id: int) -> Chat | None: + async def execute() -> Chat | None: if self._client_override is not None: return await self._client_override.get_chat(chat_id) result = await self._post("getChat", {"chat_id": chat_id}) - return result if isinstance(result, dict) else None + return self._decode_result(method="getChat", payload=result, model=Chat) return await self.enqueue_op( key=self.unique_key("get_chat"), @@ -1203,14 +1237,18 @@ class TelegramClient: chat_id=chat_id, ) - async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None: - async def execute() -> dict | None: + async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None: + async def execute() -> ChatMember | None: if self._client_override is not None: return await self._client_override.get_chat_member(chat_id, user_id) result = await self._post( "getChatMember", {"chat_id": chat_id, "user_id": user_id} ) - return result if isinstance(result, dict) else None + return self._decode_result( + method="getChatMember", + payload=result, + model=ChatMember, + ) return await self.enqueue_op( key=self.unique_key("get_chat_member"), @@ -1220,14 +1258,18 @@ class TelegramClient: chat_id=chat_id, ) - async def create_forum_topic(self, chat_id: int, name: str) -> dict | None: - async def execute() -> dict | None: + async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None: + async def execute() -> ForumTopic | None: if self._client_override is not None: return await self._client_override.create_forum_topic(chat_id, name) result = await self._post( "createForumTopic", {"chat_id": chat_id, "name": name} ) - return result if isinstance(result, dict) else None + return self._decode_result( + method="createForumTopic", + payload=result, + model=ForumTopic, + ) return await self.enqueue_op( key=self.unique_key("create_forum_topic"), diff --git a/src/takopi/telegram/commands.py b/src/takopi/telegram/commands.py index 373a971..2943dd5 100644 --- a/src/takopi/telegram/commands.py +++ b/src/takopi/telegram/commands.py @@ -289,11 +289,10 @@ async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool: if is_private: return True member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) - if not isinstance(member, dict): + if member is None: await reply(text="failed to verify file transfer permissions.") return False - status = member.get("status") - if status in {"creator", "administrator"}: + if member.status in {"creator", "administrator"}: return True await reply(text="file transfer is restricted to group admins.") return False @@ -377,21 +376,14 @@ async def _save_document_payload( error="file is too large to upload.", ) file_info = await cfg.bot.get_file(document.file_id) - if not isinstance(file_info, dict): - return _FilePutResult( - name=name, - rel_path=None, - size=None, - error="failed to fetch file metadata.", - ) - file_path = file_info.get("file_path") - if not isinstance(file_path, str) or not file_path: + if file_info is None: return _FilePutResult( name=name, rel_path=None, size=None, error="failed to fetch file metadata.", ) + file_path = file_info.file_path name = default_upload_name(document.file_name, file_path) resolved_path = rel_path if resolved_path is None: @@ -972,10 +964,10 @@ async def _handle_topic_command( return title = _topic_title(runtime=cfg.runtime, context=context) created = await cfg.bot.create_forum_topic(msg.chat_id, title) - thread_id = created.get("message_thread_id") if isinstance(created, dict) else None - if isinstance(thread_id, bool) or not isinstance(thread_id, int): + if created is None: await reply(text="failed to create topic.") return + thread_id = created.message_thread_id await store.set_context( msg.chat_id, thread_id, diff --git a/src/takopi/telegram/loop.py b/src/takopi/telegram/loop.py index ac0437f..1cd9e57 100644 --- a/src/takopi/telegram/loop.py +++ b/src/takopi/telegram/loop.py @@ -14,6 +14,7 @@ from ..directives import DirectiveError from ..logging import get_logger from ..model import EngineId, ResumeToken from ..scheduler import ThreadJob, ThreadScheduler +from ..settings import TelegramTransportSettings from ..transport import MessageRef from ..context import RunContext from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain @@ -169,7 +170,7 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int | if drained: logger.info("startup.backlog.drained", count=drained) return offset - offset = updates[-1]["update_id"] + 1 + offset = updates[-1].update_id + 1 drained += len(updates) @@ -266,7 +267,7 @@ async def run_main_loop( watch_config: bool | None = None, default_engine_override: str | None = None, transport_id: str | None = None, - transport_config: dict[str, object] | None = None, + transport_config: TelegramTransportSettings | None = None, ) -> None: from ..runner_bridge import RunningTasks @@ -277,7 +278,7 @@ async def run_main_loop( } reserved_commands = _reserved_commands(cfg.runtime) transport_snapshot = ( - dict(transport_config) if transport_config is not None else None + transport_config.model_dump() if transport_config is not None else None ) topic_store: TopicStateStore | None = None media_groups: dict[tuple[int, str], _MediaGroupState] = {} @@ -476,6 +477,7 @@ async def run_main_loop( bot=cfg.bot, msg=msg, enabled=cfg.voice_transcription, + max_bytes=cfg.voice_max_bytes, reply=reply, ) if text is None: diff --git a/src/takopi/telegram/onboarding.py b/src/takopi/telegram/onboarding.py index dea4cef..9323354 100644 --- a/src/takopi/telegram/onboarding.py +++ b/src/takopi/telegram/onboarding.py @@ -33,6 +33,7 @@ from ..engines import list_backends from ..logging import suppress_logs from ..settings import HOME_CONFIG_PATH, load_settings, require_telegram from ..transports import SetupResult +from .api_models import User from .client import TelegramClient, TelegramRetryAfter __all__ = [ @@ -131,7 +132,7 @@ def mask_token(token: str) -> str: return f"{token[:9]}...{token[-5:]}" -async def get_bot_info(token: str) -> dict[str, Any] | None: +async def get_bot_info(token: str) -> User | None: bot = TelegramClient(token) try: for _ in range(3): @@ -153,23 +154,19 @@ async def wait_for_chat(token: str) -> ChatInfo: offset=None, timeout_s=0, allowed_updates=allowed_updates ) if drained: - offset = drained[-1]["update_id"] + 1 + offset = drained[-1].update_id + 1 while True: - try: - updates = await bot.get_updates( - offset=offset, timeout_s=50, allowed_updates=allowed_updates - ) - except TelegramRetryAfter as exc: - await anyio.sleep(exc.retry_after) - continue + updates = await bot.get_updates( + offset=offset, timeout_s=50, allowed_updates=allowed_updates + ) if updates is None: await anyio.sleep(1) continue if not updates: continue - offset = updates[-1]["update_id"] + 1 + offset = updates[-1].update_id + 1 update = updates[-1] - msg = update.get("message") + msg = update.message if not isinstance(msg, dict): continue sender = msg.get("from") @@ -298,7 +295,7 @@ def _confirm(message: str, *, default: bool = True) -> bool | None: return question.ask() -def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None: +def _prompt_token(console: Console) -> tuple[str, User] | None: while True: token = questionary.password("paste your bot token:").ask() if token is None: @@ -310,11 +307,10 @@ def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None: console.print(" validating...") info = anyio.run(get_bot_info, token) if info: - username = info.get("username") - if isinstance(username, str) and username: - console.print(f" connected to @{username}") + if info.username: + console.print(f" connected to @{info.username}") else: - name = info.get("first_name") or "your bot" + name = info.first_name or "your bot" console.print(f" connected to {name}") return token, info console.print(" failed to connect, check the token and try again") @@ -342,7 +338,7 @@ def capture_chat_id(*, token: str | None = None) -> ChatInfo | None: return None token, info = token_info - bot_ref = f"@{info['username']}" + bot_ref = f"@{info.username}" console.print("") console.print(f" send /start to {bot_ref} (works in groups too)") console.print(" waiting...") @@ -401,7 +397,7 @@ def interactive_setup(*, force: bool) -> bool: if token_info is None: return False token, info = token_info - bot_ref = f"@{info['username']}" + bot_ref = f"@{info.username}" console.print("") console.print(f" send /start to {bot_ref} (works in groups too)") diff --git a/src/takopi/telegram/topic_state.py b/src/takopi/telegram/topic_state.py index f9737bc..074ac34 100644 --- a/src/takopi/telegram/topic_state.py +++ b/src/takopi/telegram/topic_state.py @@ -4,9 +4,9 @@ import json import os from dataclasses import dataclass from pathlib import Path -from typing import Any, cast import anyio +import msgspec from ..context import RunContext from ..logging import get_logger @@ -27,6 +27,26 @@ class TopicThreadSnapshot: topic_title: str | None +class _ContextState(msgspec.Struct, forbid_unknown_fields=False): + project: str | None = None + branch: str | None = None + + +class _SessionState(msgspec.Struct, forbid_unknown_fields=False): + resume: str + + +class _ThreadState(msgspec.Struct, forbid_unknown_fields=False): + context: _ContextState | None = None + sessions: dict[str, _SessionState] = msgspec.field(default_factory=dict) + topic_title: str | None = None + + +class _TopicState(msgspec.Struct, forbid_unknown_fields=False): + version: int + threads: dict[str, _ThreadState] = msgspec.field(default_factory=dict) + + def resolve_state_path(config_path: Path) -> Path: return config_path.with_name(STATE_FILENAME) @@ -35,34 +55,31 @@ def _thread_key(chat_id: int, thread_id: int) -> str: return f"{chat_id}:{thread_id}" -def _parse_context(raw: object) -> RunContext | None: - if not isinstance(raw, dict): +def _normalize_text(value: str | None) -> str | None: + if value is None: return None - payload = cast(dict[str, object], raw) - project = payload.get("project") - branch = payload.get("branch") - if project is not None and not isinstance(project, str): - project = None - if isinstance(project, str): - project = project.strip() or None - if branch is not None and not isinstance(branch, str): - branch = None - if isinstance(branch, str): - branch = branch.strip() or None + value = value.strip() + return value or None + + +def _context_from_state(state: _ContextState | None) -> RunContext | None: + if state is None: + return None + project = _normalize_text(state.project) + branch = _normalize_text(state.branch) if project is None and branch is None: return None return RunContext(project=project, branch=branch) -def _dump_context(context: RunContext | None) -> dict[str, str] | None: - if context is None or (context.project is None and context.branch is None): +def _context_to_state(context: RunContext | None) -> _ContextState | None: + if context is None: return None - payload: dict[str, str] = {} - if context.project is not None: - payload["project"] = context.project - if context.branch is not None: - payload["branch"] = context.branch - return payload or None + project = _normalize_text(context.project) + branch = _normalize_text(context.branch) + if project is None and branch is None: + return None + return _ContextState(project=project, branch=branch) class TopicStateStore: @@ -71,10 +88,7 @@ class TopicStateStore: self._lock = anyio.Lock() self._loaded = False self._mtime_ns: int | None = None - self._data: dict[str, Any] = { - "version": STATE_VERSION, - "threads": {}, - } + self._state = _TopicState(version=STATE_VERSION, threads={}) async def get_thread( self, chat_id: int, thread_id: int @@ -92,7 +106,7 @@ class TopicStateStore: thread = self._get_thread_locked(chat_id, thread_id) if thread is None: return None - return _parse_context(thread.get("context")) + return _context_from_state(thread.context) async def set_context( self, @@ -105,9 +119,9 @@ class TopicStateStore: async with self._lock: self._reload_locked_if_needed() thread = self._ensure_thread_locked(chat_id, thread_id) - thread["context"] = _dump_context(context) + thread.context = _context_to_state(context) if topic_title is not None: - thread["topic_title"] = topic_title + thread.topic_title = topic_title self._save_locked() async def clear_context(self, chat_id: int, thread_id: int) -> None: @@ -116,7 +130,7 @@ class TopicStateStore: thread = self._get_thread_locked(chat_id, thread_id) if thread is None: return - thread.pop("context", None) + thread.context = None self._save_locked() async def get_session_resume( @@ -127,16 +141,10 @@ class TopicStateStore: thread = self._get_thread_locked(chat_id, thread_id) if thread is None: return None - sessions = thread.get("sessions") - if not isinstance(sessions, dict): + entry = thread.sessions.get(engine) + if entry is None or not entry.resume: return None - entry = sessions.get(engine) - if not isinstance(entry, dict): - return None - value = entry.get("resume") - if not isinstance(value, str) or not value: - return None - return ResumeToken(engine=engine, value=value) + return ResumeToken(engine=engine, value=entry.resume) async def set_session_resume( self, chat_id: int, thread_id: int, token: ResumeToken @@ -144,13 +152,7 @@ class TopicStateStore: async with self._lock: self._reload_locked_if_needed() thread = self._ensure_thread_locked(chat_id, thread_id) - sessions = thread.get("sessions") - if not isinstance(sessions, dict): - sessions = {} - thread["sessions"] = sessions - sessions[token.engine] = { - "resume": token.value, - } + thread.sessions[token.engine] = _SessionState(resume=token.value) self._save_locked() async def clear_sessions(self, chat_id: int, thread_id: int) -> None: @@ -159,7 +161,7 @@ class TopicStateStore: thread = self._get_thread_locked(chat_id, thread_id) if thread is None: return - thread.pop("sessions", None) + thread.sessions = {} self._save_locked() async def find_thread_for_context( @@ -167,47 +169,37 @@ class TopicStateStore: ) -> int | None: async with self._lock: self._reload_locked_if_needed() - threads = self._data.get("threads") - if not isinstance(threads, dict): - return None - for raw_key, payload in threads.items(): - if not isinstance(raw_key, str) or not isinstance(payload, dict): + target_project = _normalize_text(context.project) + target_branch = _normalize_text(context.branch) + for raw_key, thread in self._state.threads.items(): + if not raw_key.startswith(f"{chat_id}:"): continue - parsed = _parse_context(payload.get("context")) + parsed = _context_from_state(thread.context) if parsed is None: continue - if parsed.project != context.project or parsed.branch != context.branch: - continue - if not raw_key.startswith(f"{chat_id}:"): + if parsed.project != target_project or parsed.branch != target_branch: continue try: _, thread_str = raw_key.split(":", 1) return int(thread_str) - except (ValueError, TypeError): + except ValueError: continue return None def _snapshot_locked( - self, thread: dict[str, Any], chat_id: int, thread_id: int + self, thread: _ThreadState, chat_id: int, thread_id: int ) -> TopicThreadSnapshot: - sessions: dict[str, str] = {} - raw_sessions = thread.get("sessions") - if isinstance(raw_sessions, dict): - for engine, entry in raw_sessions.items(): - if not isinstance(engine, str) or not isinstance(entry, dict): - continue - value = entry.get("resume") - if isinstance(value, str) and value: - sessions[engine] = value - topic_title = thread.get("topic_title") - if not isinstance(topic_title, str): - topic_title = None + sessions = { + engine: entry.resume + for engine, entry in thread.sessions.items() + if entry.resume + } return TopicThreadSnapshot( chat_id=chat_id, thread_id=thread_id, - context=_parse_context(thread.get("context")), + context=_context_from_state(thread.context), sessions=sessions, - topic_title=topic_title, + topic_title=thread.topic_title, ) def _stat_mtime_ns(self) -> int | None: @@ -226,10 +218,10 @@ class TopicStateStore: self._loaded = True self._mtime_ns = self._stat_mtime_ns() if self._mtime_ns is None: - self._data = {"version": STATE_VERSION, "threads": {}} + self._state = _TopicState(version=STATE_VERSION, threads={}) return try: - payload = json.loads(self._path.read_text(encoding="utf-8")) + payload = msgspec.json.decode(self._path.read_bytes(), type=_TopicState) except Exception as exc: logger.warning( "telegram.topic_state.load_failed", @@ -237,29 +229,22 @@ class TopicStateStore: error=str(exc), error_type=exc.__class__.__name__, ) - self._data = {"version": STATE_VERSION, "threads": {}} + self._state = _TopicState(version=STATE_VERSION, threads={}) return - if not isinstance(payload, dict): - self._data = {"version": STATE_VERSION, "threads": {}} - return - version = payload.get("version") - if version != STATE_VERSION: + if payload.version != STATE_VERSION: logger.warning( "telegram.topic_state.version_mismatch", path=str(self._path), - version=version, + version=payload.version, expected=STATE_VERSION, ) - self._data = {"version": STATE_VERSION, "threads": {}} + self._state = _TopicState(version=STATE_VERSION, threads={}) return - threads = payload.get("threads") - if not isinstance(threads, dict): - threads = {} - self._data = {"version": STATE_VERSION, "threads": threads} + self._state = payload def _save_locked(self) -> None: self._path.parent.mkdir(parents=True, exist_ok=True) - payload = {"version": STATE_VERSION, "threads": self._data.get("threads", {})} + payload = msgspec.to_builtins(self._state) tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp") with open(tmp_path, "w", encoding="utf-8") as handle: json.dump(payload, handle, indent=2, sort_keys=True) @@ -267,22 +252,14 @@ class TopicStateStore: os.replace(tmp_path, self._path) self._mtime_ns = self._stat_mtime_ns() - def _get_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any] | None: - threads = self._data.get("threads") - if not isinstance(threads, dict): - return None - entry = threads.get(_thread_key(chat_id, thread_id)) - return entry if isinstance(entry, dict) else None + def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None: + return self._state.threads.get(_thread_key(chat_id, thread_id)) - def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any]: - threads = self._data.get("threads") - if not isinstance(threads, dict): - threads = {} - self._data["threads"] = threads + def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState: key = _thread_key(chat_id, thread_id) - entry = threads.get(key) - if isinstance(entry, dict): + entry = self._state.threads.get(key) + if entry is not None: return entry - entry = {} - threads[key] = entry + entry = _ThreadState() + self._state.threads[key] = entry return entry diff --git a/src/takopi/telegram/topics.py b/src/takopi/telegram/topics.py index 5174255..6ab102a 100644 --- a/src/takopi/telegram/topics.py +++ b/src/takopi/telegram/topics.py @@ -185,9 +185,9 @@ async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None: if not cfg.topics.enabled: return me = await cfg.bot.get_me() - bot_id = me.get("id") if isinstance(me, dict) else None - if not isinstance(bot_id, int): + if me is None: raise ConfigError("failed to fetch bot id for topics validation.") + bot_id = me.id scope, chat_ids = _resolve_topics_scope(cfg) if scope == "projects" and not chat_ids: raise ConfigError( @@ -197,37 +197,34 @@ async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None: for chat_id in chat_ids: chat = await cfg.bot.get_chat(chat_id) - if not isinstance(chat, dict): + if chat is None: raise ConfigError( f"failed to fetch chat info for topics validation ({chat_id})." ) - chat_type = chat.get("type") - is_forum = chat.get("is_forum") - if chat_type != "supergroup": + if chat.type != "supergroup": raise ConfigError( "topics enabled but chat is not a supergroup " f"(chat_id={chat_id}); convert the group and enable topics." ) - if is_forum is not True: + if chat.is_forum is not True: raise ConfigError( "topics enabled but chat does not have topics enabled " f"(chat_id={chat_id}); turn on topics in group settings." ) member = await cfg.bot.get_chat_member(chat_id, bot_id) - if not isinstance(member, dict): + if member is None: raise ConfigError( "failed to fetch bot permissions " f"(chat_id={chat_id}); promote the bot to admin with manage topics." ) - status = member.get("status") - if status == "creator": + if member.status == "creator": continue - if status != "administrator": + if member.status != "administrator": raise ConfigError( "topics enabled but bot is not an admin " f"(chat_id={chat_id}); promote it and grant manage topics." ) - if member.get("can_manage_topics") is not True: + if member.can_manage_topics is not True: raise ConfigError( "topics enabled but bot lacks manage topics permission " f"(chat_id={chat_id}); grant can_manage_topics." diff --git a/src/takopi/telegram/voice.py b/src/takopi/telegram/voice.py index 9d36981..70ffea7 100644 --- a/src/takopi/telegram/voice.py +++ b/src/takopi/telegram/voice.py @@ -2,7 +2,6 @@ from __future__ import annotations import io from collections.abc import Awaitable, Callable -from typing import cast from ..logging import get_logger from openai import AsyncOpenAI, OpenAIError @@ -29,6 +28,7 @@ async def transcribe_voice( bot: BotClient, msg: TelegramIncomingMessage, enabled: bool, + max_bytes: int | None = None, reply: Callable[..., Awaitable[None]], ) -> str | None: voice = msg.voice @@ -37,9 +37,24 @@ async def transcribe_voice( if not enabled: await reply(text=VOICE_TRANSCRIPTION_DISABLED_HINT) return None - file_info = cast(dict[str, object], await bot.get_file(voice.file_id)) - file_path = cast(str, file_info["file_path"]) - audio_bytes = cast(bytes, await bot.download_file(file_path)) + if ( + max_bytes is not None + and voice.file_size is not None + and voice.file_size > max_bytes + ): + await reply(text="voice message is too large to transcribe.") + return None + file_info = await bot.get_file(voice.file_id) + if file_info is None: + await reply(text="failed to fetch voice file.") + return None + audio_bytes = await bot.download_file(file_info.file_path) + if audio_bytes is None: + await reply(text="failed to download voice file.") + return None + if max_bytes is not None and len(audio_bytes) > max_bytes: + await reply(text="voice message is too large to transcribe.") + return None audio_file = io.BytesIO(audio_bytes) audio_file.name = "voice.ogg" async with AsyncOpenAI(timeout=120) as client: diff --git a/src/takopi/transport_runtime.py b/src/takopi/transport_runtime.py index 7d6282a..2d5768d 100644 --- a/src/takopi/transport_runtime.py +++ b/src/takopi/transport_runtime.py @@ -10,7 +10,7 @@ from .context import RunContext from .directives import format_context_line, parse_context_line, parse_directives from .model import EngineId, ResumeToken from .plugins import normalize_allowlist -from .router import AutoRouter +from .router import AutoRouter, EngineStatus from .runner import Runner from .worktrees import WorktreeError, resolve_run_cwd @@ -108,11 +108,14 @@ class TransportRuntime: def available_engine_ids(self) -> tuple[EngineId, ...]: return tuple(entry.engine for entry in self._router.available_entries) - def missing_engine_ids(self) -> tuple[EngineId, ...]: + def engine_ids_with_status(self, status: EngineStatus) -> tuple[EngineId, ...]: return tuple( - entry.engine for entry in self._router.entries if not entry.available + entry.engine for entry in self._router.entries if entry.status == status ) + def missing_engine_ids(self) -> tuple[EngineId, ...]: + return self.engine_ids_with_status("missing_cli") + def project_aliases(self) -> tuple[str, ...]: return tuple(project.alias for project in self._projects.projects.values()) diff --git a/src/takopi/transports.py b/src/takopi/transports.py index 111ddc5..4e948cd 100644 --- a/src/takopi/transports.py +++ b/src/takopi/transports.py @@ -41,13 +41,13 @@ class TransportBackend(Protocol): def interactive_setup(self, *, force: bool) -> bool: ... def lock_token( - self, *, transport_config: dict[str, object], config_path: Path + self, *, transport_config: object, config_path: Path ) -> str | None: ... def build_and_run( self, *, - transport_config: dict[str, object], + transport_config: object, config_path: Path, runtime: TransportRuntime, final_notify: bool, diff --git a/tests/test_onboarding_interactive.py b/tests/test_onboarding_interactive.py index 772b5a6..b904e1d 100644 --- a/tests/test_onboarding_interactive.py +++ b/tests/test_onboarding_interactive.py @@ -1,8 +1,9 @@ from __future__ import annotations +from takopi.backends import EngineBackend from takopi.config import dump_toml from takopi.telegram import onboarding -from takopi.backends import EngineBackend +from takopi.telegram.api_models import User def test_mask_token_short() -> None: @@ -91,7 +92,7 @@ def test_interactive_setup_writes_config(monkeypatch, tmp_path) -> None: def _fake_run(func, *args, **kwargs): if func is onboarding.get_bot_info: - return {"username": "my_bot"} + return User(id=1, username="my_bot") if func is onboarding.wait_for_chat: return onboarding.ChatInfo( chat_id=123, @@ -136,7 +137,7 @@ def test_interactive_setup_preserves_projects(monkeypatch, tmp_path) -> None: def _fake_run(func, *args, **kwargs): if func is onboarding.get_bot_info: - return {"username": "my_bot"} + return User(id=1, username="my_bot") if func is onboarding.wait_for_chat: return onboarding.ChatInfo( chat_id=123, @@ -173,7 +174,7 @@ def test_interactive_setup_no_agents_aborts(monkeypatch, tmp_path) -> None: def _fake_run(func, *args, **kwargs): if func is onboarding.get_bot_info: - return {"username": "my_bot"} + return User(id=1, username="my_bot") if func is onboarding.wait_for_chat: return onboarding.ChatInfo( chat_id=123, @@ -211,7 +212,7 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) - def _fake_run(func, *args, **kwargs): if func is onboarding.get_bot_info: - return {"username": "my_bot"} + return User(id=1, username="my_bot") if func is onboarding.wait_for_chat: return onboarding.ChatInfo( chat_id=123, @@ -239,7 +240,7 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) - def test_capture_chat_id_with_token(monkeypatch) -> None: def _fake_run(func, *args, **kwargs): if func is onboarding.get_bot_info: - return {"username": "my_bot"} + return User(id=1, username="my_bot") if func is onboarding.wait_for_chat: return onboarding.ChatInfo( chat_id=456, @@ -261,7 +262,9 @@ def test_capture_chat_id_with_token(monkeypatch) -> None: def test_capture_chat_id_prompts_for_token(monkeypatch) -> None: monkeypatch.setattr( - onboarding, "_prompt_token", lambda _console: ("token", {"username": "bot"}) + onboarding, + "_prompt_token", + lambda _console: ("token", User(id=1, username="bot")), ) def _fake_run(func, *args, **kwargs): diff --git a/tests/test_telegram_backend.py b/tests/test_telegram_backend.py index 36d3746..87b3ce2 100644 --- a/tests/test_telegram_backend.py +++ b/tests/test_telegram_backend.py @@ -9,6 +9,11 @@ from takopi.config import ProjectsConfig from takopi.model import EngineId from takopi.router import AutoRouter, RunnerEntry from takopi.runners.mock import Return, ScriptRunner +from takopi.settings import ( + TelegramFilesSettings, + TelegramTopicsSettings, + TelegramTransportSettings, +) from takopi.telegram import backend as telegram_backend from takopi.transport_runtime import TransportRuntime @@ -20,8 +25,13 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None: missing = ScriptRunner([Return(answer="ok")], engine=pi) router = AutoRouter( entries=[ - RunnerEntry(engine=codex, runner=runner, available=True), - RunnerEntry(engine=pi, runner=missing, available=False, issue="missing"), + RunnerEntry(engine=codex, runner=runner), + RunnerEntry( + engine=pi, + runner=missing, + status="missing_cli", + issue="missing", + ), ], default_engine=codex, ) @@ -40,6 +50,44 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None: assert "projects: `none`" in message +def test_build_startup_message_surfaces_unavailable_engine_reasons( + tmp_path: Path, +) -> None: + codex = EngineId("codex") + pi = EngineId("pi") + claude = EngineId("claude") + runner = ScriptRunner([Return(answer="ok")], engine=codex) + bad_cfg = ScriptRunner([Return(answer="ok")], engine=pi) + load_err = ScriptRunner([Return(answer="ok")], engine=claude) + + router = AutoRouter( + entries=[ + RunnerEntry(engine=codex, runner=runner), + RunnerEntry(engine=pi, runner=bad_cfg, status="bad_config", issue="bad"), + RunnerEntry( + engine=claude, + runner=load_err, + status="load_error", + issue="failed", + ), + ], + default_engine=codex, + ) + runtime = TransportRuntime( + router=router, + projects=ProjectsConfig(projects={}, default_project=None), + watch_config=True, + ) + + message = telegram_backend._build_startup_message( + runtime, startup_pwd=str(tmp_path) + ) + + assert "agents: `codex" in message + assert "misconfigured: pi" in message + assert "failed to load: claude" in message + + def test_telegram_backend_build_and_run_wires_config( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: @@ -55,7 +103,7 @@ def test_telegram_backend_build_and_run_wires_config( codex = EngineId("codex") runner = ScriptRunner([Return(answer="ok")], engine=codex) router = AutoRouter( - entries=[RunnerEntry(engine=codex, runner=runner, available=True)], + entries=[RunnerEntry(engine=codex, runner=runner)], default_engine=codex, ) runtime = TransportRuntime( @@ -80,13 +128,14 @@ def test_telegram_backend_build_and_run_wires_config( monkeypatch.setattr(telegram_backend, "run_main_loop", fake_run_main_loop) monkeypatch.setattr(telegram_backend, "TelegramClient", _FakeClient) - transport_config = { - "bot_token": "token", - "chat_id": 321, - "voice_transcription": True, - "files": {"enabled": True, "allowed_user_ids": [1, 2]}, - "topics": {"enabled": True, "scope": "main"}, - } + transport_config = TelegramTransportSettings( + bot_token="token", + chat_id=321, + voice_transcription=True, + voice_max_bytes=1234, + files=TelegramFilesSettings(enabled=True, allowed_user_ids=[1, 2]), + topics=TelegramTopicsSettings(enabled=True, scope="main"), + ) telegram_backend.TelegramBackend().build_and_run( transport_config=transport_config, @@ -100,18 +149,19 @@ def test_telegram_backend_build_and_run_wires_config( kwargs = captured["kwargs"] assert cfg.chat_id == 321 assert cfg.voice_transcription is True + assert cfg.voice_max_bytes == 1234 assert cfg.files.enabled is True - assert cfg.files.allowed_user_ids == frozenset({1, 2}) + assert cfg.files.allowed_user_ids == [1, 2] assert cfg.topics.enabled is True assert cfg.bot.token == "token" assert kwargs["watch_config"] is True assert kwargs["transport_id"] == "telegram" -def test_build_files_config_defaults() -> None: - cfg = telegram_backend._build_files_config({}) +def test_telegram_files_settings_defaults() -> None: + cfg = TelegramFilesSettings() assert cfg.enabled is False assert cfg.auto_put is True assert cfg.uploads_dir == "incoming" - assert cfg.allowed_user_ids == frozenset() + assert cfg.allowed_user_ids == [] diff --git a/tests/test_telegram_bridge.py b/tests/test_telegram_bridge.py index bff838f..965b537 100644 --- a/tests/test_telegram_bridge.py +++ b/tests/test_telegram_bridge.py @@ -6,22 +6,30 @@ import anyio import pytest from takopi import commands, plugins -import takopi.telegram.bridge as bridge import takopi.telegram.loop as telegram_loop import takopi.telegram.commands as telegram_commands import takopi.telegram.topics as telegram_topics from takopi.directives import parse_directives +from takopi.telegram.api_models import ( + Chat, + ChatMember, + File, + ForumTopic, + Message, + Update, + User, +) +from takopi.settings import TelegramFilesSettings, TelegramTopicsSettings from takopi.telegram.bridge import ( TelegramBridgeConfig, - TelegramFilesConfig, TelegramPresenter, TelegramTransport, build_bot_commands, handle_callback_cancel, handle_cancel, is_cancel_command, - send_with_resume, run_main_loop, + send_with_resume, ) from takopi.telegram.client import BotClient from takopi.telegram.topic_state import TopicStateStore, resolve_state_path @@ -122,13 +130,13 @@ class _FakeBot(BotClient): offset: int | None, timeout_s: int = 50, allowed_updates: list[str] | None = None, - ) -> list[dict[str, Any]] | None: + ) -> list[Update] | None: _ = offset _ = timeout_s _ = allowed_updates return [] - async def get_file(self, file_id: str) -> dict[str, Any] | None: + async def get_file(self, file_id: str) -> File | None: _ = file_id return None @@ -148,7 +156,7 @@ class _FakeBot(BotClient): reply_markup: dict | None = None, *, replace_message_id: int | None = None, - ) -> dict[str, Any]: + ) -> Message: self.send_calls.append( { "chat_id": chat_id, @@ -162,7 +170,7 @@ class _FakeBot(BotClient): "replace_message_id": replace_message_id, } ) - return {"message_id": 1} + return Message(message_id=1) async def send_document( self, @@ -173,7 +181,7 @@ class _FakeBot(BotClient): message_thread_id: int | None = None, disable_notification: bool | None = False, caption: str | None = None, - ) -> dict[str, Any]: + ) -> Message: self.document_calls.append( { "chat_id": chat_id, @@ -185,7 +193,7 @@ class _FakeBot(BotClient): "caption": caption, } ) - return {"message_id": 2} + return Message(message_id=2) async def edit_message_text( self, @@ -197,7 +205,7 @@ class _FakeBot(BotClient): reply_markup: dict | None = None, *, wait: bool = True, - ) -> dict[str, Any]: + ) -> Message: self.edit_calls.append( { "chat_id": chat_id, @@ -209,7 +217,7 @@ class _FakeBot(BotClient): "wait": wait, } ) - return {"message_id": message_id} + return Message(message_id=message_id) async def delete_message(self, chat_id: int, message_id: int) -> bool: self.delete_calls.append({"chat_id": chat_id, "message_id": message_id}) @@ -231,26 +239,22 @@ class _FakeBot(BotClient): ) return True - async def get_me(self) -> dict[str, Any] | None: - return {"id": 1} + async def get_me(self) -> User | None: + return User(id=1, username="bot") - async def get_chat(self, chat_id: int) -> dict[str, Any] | None: + async def get_chat(self, chat_id: int) -> Chat | None: _ = chat_id - return {"id": chat_id, "type": "supergroup", "is_forum": True} + return Chat(id=chat_id, type="supergroup", is_forum=True) - async def get_chat_member( - self, chat_id: int, user_id: int - ) -> dict[str, Any] | None: + async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None: _ = chat_id _ = user_id - return {"status": "administrator", "can_manage_topics": True} + return ChatMember(status="administrator", can_manage_topics=True) - async def create_forum_topic( - self, chat_id: int, name: str - ) -> dict[str, Any] | None: + async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None: _ = chat_id _ = name - return {"message_thread_id": 1} + return ForumTopic(message_thread_id=1) async def edit_forum_topic( self, chat_id: int, message_thread_id: int, name: str @@ -539,10 +543,10 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: offset: int | None, timeout_s: int = 50, allowed_updates: list[str] | None = None, - ) -> list[dict[str, Any]] | None: + ) -> list[Update] | None: return None - async def get_file(self, file_id: str) -> dict[str, Any] | None: + async def get_file(self, file_id: str) -> File | None: _ = file_id return None @@ -562,7 +566,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: reply_markup: dict | None = None, *, replace_message_id: int | None = None, - ) -> dict | None: + ) -> Message | None: _ = reply_markup return None @@ -575,7 +579,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: message_thread_id: int | None = None, disable_notification: bool | None = False, caption: str | None = None, - ) -> dict | None: + ) -> Message | None: _ = ( chat_id, filename, @@ -597,7 +601,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: reply_markup: dict | None = None, *, wait: bool = True, - ) -> dict | None: + ) -> Message | None: self.edit_calls.append( { "chat_id": chat_id, @@ -611,7 +615,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: ) if not wait: return None - return {"message_id": message_id} + return Message(message_id=message_id) async def delete_message( self, @@ -629,7 +633,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: ) -> bool: return False - async def get_me(self) -> dict[str, Any] | None: + async def get_me(self) -> User | None: return None async def close(self) -> None: @@ -778,9 +782,9 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None: payload = b"hello" class _FileBot(_FakeBot): - async def get_file(self, file_id: str) -> dict[str, Any] | None: + async def get_file(self, file_id: str) -> File | None: _ = file_id - return {"file_path": "files/hello.txt"} + return File(file_path="files/hello.txt") async def download_file(self, file_path: str) -> bytes | None: _ = file_path @@ -811,7 +815,7 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None: chat_id=123, startup_msg="", exec_cfg=exec_cfg, - files=TelegramFilesConfig(enabled=True), + files=TelegramFilesSettings(enabled=True), ) msg = TelegramIncomingMessage( transport="telegram", @@ -876,9 +880,9 @@ async def test_handle_file_get_sends_document_for_allowed_user( chat_id=123, startup_msg="", exec_cfg=exec_cfg, - files=TelegramFilesConfig( + files=TelegramFilesSettings( enabled=True, - allowed_user_ids=frozenset({42}), + allowed_user_ids=[42], ), ) msg = TelegramIncomingMessage( @@ -1006,7 +1010,7 @@ def test_topic_title_projects_scope_includes_project() -> None: transport = _FakeTransport() cfg = replace( _make_cfg(transport), - topics=bridge.TelegramTopicsConfig( + topics=TelegramTopicsSettings( enabled=True, scope="projects", ), @@ -1277,7 +1281,7 @@ async def test_run_main_loop_persists_topic_sessions_in_project_scope( chat_id=123, startup_msg="", exec_cfg=exec_cfg, - topics=bridge.TelegramTopicsConfig( + topics=TelegramTopicsSettings( enabled=True, scope="projects", ), @@ -1363,11 +1367,11 @@ async def test_run_main_loop_batches_media_group_upload( } class _MediaBot(_FakeBot): - async def get_file(self, file_id: str) -> dict[str, Any] | None: + async def get_file(self, file_id: str) -> File | None: file_path = file_map.get(file_id) if file_path is None: return None - return {"file_path": file_path} + return File(file_path=file_path) async def download_file(self, file_path: str) -> bytes | None: return payloads.get(file_path) @@ -1397,7 +1401,7 @@ async def test_run_main_loop_batches_media_group_upload( chat_id=123, startup_msg="", exec_cfg=exec_cfg, - files=TelegramFilesConfig(enabled=True, auto_put=True), + files=TelegramFilesSettings(enabled=True, auto_put=True), ) msg1 = TelegramIncomingMessage( transport="telegram", diff --git a/tests/test_telegram_client.py b/tests/test_telegram_client.py index 9b2da4e..f34d962 100644 --- a/tests/test_telegram_client.py +++ b/tests/test_telegram_client.py @@ -57,3 +57,175 @@ async def test_no_token_in_logs_on_http_error( out = capsys.readouterr().out assert token not in out assert "bot[REDACTED]" in out + + +@pytest.mark.anyio +async def test_telegram_429_no_retry_post_form() -> None: + calls: list[int] = [] + + def handler(request: httpx.Request) -> httpx.Response: + calls.append(1) + return httpx.Response( + 429, + json={ + "ok": False, + "description": "retry", + "parameters": {"retry_after": 2}, + }, + request=request, + ) + + transport = httpx.MockTransport(handler) + + client = httpx.AsyncClient(transport=transport) + try: + tg = TelegramClient("123:abcDEF_ghij", http_client=client) + with pytest.raises(TelegramRetryAfter) as exc: + await tg._post_form( + "sendDocument", + {"chat_id": 1}, + files={"document": ("note.txt", b"hi")}, + ) + finally: + await client.aclose() + + assert exc.value.retry_after == 2 + assert len(calls) == 1 + + +@pytest.mark.anyio +async def test_telegram_429_defaults_retry_after_on_bad_body() -> None: + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(429, text="nope", request=request) + + transport = httpx.MockTransport(handler) + + client = httpx.AsyncClient(transport=transport) + try: + tg = TelegramClient("123:abcDEF_ghij", http_client=client) + with pytest.raises(TelegramRetryAfter) as exc: + await tg._post("sendMessage", {"chat_id": 1, "text": "hi"}) + finally: + await client.aclose() + + assert exc.value.retry_after == 5.0 + + +@pytest.mark.anyio +async def test_telegram_ok_false_returns_none() -> None: + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={"ok": False, "error_code": 400, "description": "bad"}, + request=request, + ) + + transport = httpx.MockTransport(handler) + + client = httpx.AsyncClient(transport=transport) + try: + tg = TelegramClient("123:abcDEF_ghij", http_client=client) + result = await tg._post("getUpdates", {"timeout": 1}) + finally: + await client.aclose() + + assert result is None + + +@pytest.mark.anyio +async def test_telegram_invalid_payload_returns_none() -> None: + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json=["not", "a", "dict"], request=request) + + transport = httpx.MockTransport(handler) + + client = httpx.AsyncClient(transport=transport) + try: + tg = TelegramClient("123:abcDEF_ghij", http_client=client) + result = await tg._post("getUpdates", {"timeout": 1}) + finally: + await client.aclose() + + assert result is None + + +@pytest.mark.anyio +async def test_telegram_decode_failure_returns_none() -> None: + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={"ok": True, "result": {"username": "bot-only"}}, + request=request, + ) + + transport = httpx.MockTransport(handler) + + client = httpx.AsyncClient(transport=transport) + try: + tg = TelegramClient("123:abcDEF_ghij", http_client=client) + result = await tg.get_me() + finally: + await client.aclose() + + assert result is None + + +@pytest.mark.anyio +async def test_telegram_download_file_retries_on_429() -> None: + calls: list[int] = [] + sleeps: list[float] = [] + + async def sleep(delay: float) -> None: + sleeps.append(delay) + + def handler(request: httpx.Request) -> httpx.Response: + calls.append(1) + if len(calls) == 1: + return httpx.Response( + 429, + json={"ok": False, "parameters": {"retry_after": 3}}, + request=request, + ) + return httpx.Response(200, content=b"ok", request=request) + + transport = httpx.MockTransport(handler) + + client = httpx.AsyncClient(transport=transport) + try: + tg = TelegramClient("123:abcDEF_ghij", http_client=client, sleep=sleep) + payload = await tg.download_file("path/to/file") + finally: + await client.aclose() + + assert payload == b"ok" + assert sleeps == [3.0] + assert len(calls) == 2 + + +@pytest.mark.anyio +async def test_telegram_download_file_429_defaults_retry_after_on_bad_body() -> None: + sleeps: list[float] = [] + + async def sleep(delay: float) -> None: + sleeps.append(delay) + + calls: list[int] = [] + + def handler(request: httpx.Request) -> httpx.Response: + calls.append(1) + if len(calls) == 1: + return httpx.Response(429, text="nope", request=request) + return httpx.Response(200, content=b"ok", request=request) + + transport = httpx.MockTransport(handler) + + client = httpx.AsyncClient(transport=transport) + try: + tg = TelegramClient("123:abcDEF_ghij", http_client=client, sleep=sleep) + payload = await tg.download_file("path") + finally: + await client.aclose() + + assert payload == b"ok" + assert sleeps == [5.0] + assert len(calls) == 2 diff --git a/tests/test_telegram_queue.py b/tests/test_telegram_queue.py index 55279c1..30fe15e 100644 --- a/tests/test_telegram_queue.py +++ b/tests/test_telegram_queue.py @@ -3,6 +3,7 @@ from typing import Any import anyio import pytest +from takopi.telegram.api_models import File, Message, Update, User from takopi.telegram.client import BotClient, TelegramClient, TelegramRetryAfter @@ -29,7 +30,7 @@ class _FakeBot(BotClient): reply_markup: dict | None = None, *, replace_message_id: int | None = None, - ) -> dict[str, Any]: + ) -> Message | None: _ = reply_to_message_id _ = disable_notification _ = message_thread_id @@ -38,7 +39,7 @@ class _FakeBot(BotClient): _ = reply_markup _ = replace_message_id self.calls.append("send_message") - return {"message_id": 1} + return Message(message_id=1) async def send_document( self, @@ -49,7 +50,7 @@ class _FakeBot(BotClient): message_thread_id: int | None = None, disable_notification: bool | None = False, caption: str | None = None, - ) -> dict[str, Any]: + ) -> Message | None: _ = ( chat_id, filename, @@ -60,7 +61,7 @@ class _FakeBot(BotClient): caption, ) self.calls.append("send_document") - return {"message_id": 1} + return Message(message_id=1) async def edit_message_text( self, @@ -72,7 +73,7 @@ class _FakeBot(BotClient): reply_markup: dict | None = None, *, wait: bool = True, - ) -> dict[str, Any]: + ) -> Message | None: _ = chat_id _ = message_id _ = entities @@ -85,7 +86,7 @@ class _FakeBot(BotClient): self._edit_attempts += 1 raise TelegramRetryAfter(self.retry_after) self._edit_attempts += 1 - return {"message_id": message_id} + return Message(message_id=message_id) async def delete_message( self, @@ -113,7 +114,7 @@ class _FakeBot(BotClient): offset: int | None, timeout_s: int = 50, allowed_updates: list[str] | None = None, - ) -> list[dict[str, Any]] | None: + ) -> list[Update] | None: _ = offset _ = timeout_s _ = allowed_updates @@ -123,7 +124,7 @@ class _FakeBot(BotClient): self._updates_attempts += 1 return [] - async def get_file(self, file_id: str) -> dict[str, Any] | None: + async def get_file(self, file_id: str) -> File | None: _ = file_id return None @@ -134,8 +135,8 @@ class _FakeBot(BotClient): async def close(self) -> None: return None - async def get_me(self) -> dict[str, Any] | None: - return {"id": 1} + async def get_me(self) -> User | None: + return User(id=1) async def answer_callback_query( self, @@ -187,7 +188,7 @@ async def test_edits_coalesce_latest() -> None: reply_markup: dict | None = None, *, wait: bool = True, - ) -> dict: + ) -> Message | None: if self._block_first: self._block_first = False self.edit_started.set() @@ -305,7 +306,8 @@ async def test_retry_after_retries_once() -> None: text="retry", ) - assert result == {"message_id": 1} + assert result is not None + assert result.message_id == 1 assert bot._edit_attempts == 2 diff --git a/tests/test_telegram_voice.py b/tests/test_telegram_voice.py new file mode 100644 index 0000000..2dc64bb --- /dev/null +++ b/tests/test_telegram_voice.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import pytest + +from takopi.telegram.api_models import ( + Chat, + ChatMember, + File, + ForumTopic, + Message, + Update, + User, +) +from takopi.telegram.client import BotClient +from takopi.telegram.types import TelegramIncomingMessage, TelegramVoice +from takopi.telegram.voice import transcribe_voice + + +class _Bot(BotClient): + def __init__(self, *, file_info: File | None, audio: bytes | None) -> None: + self._file_info = file_info + self._audio = audio + + async def close(self) -> None: + return None + + async def get_updates( + self, + offset: int | None, + timeout_s: int = 50, + allowed_updates: list[str] | None = None, + ) -> list[Update] | None: + _ = offset, timeout_s, allowed_updates + return [] + + async def get_file(self, file_id: str) -> File | None: + _ = file_id + return self._file_info + + async def download_file(self, file_path: str) -> bytes | None: + _ = file_path + return self._audio + + async def send_message( + self, + chat_id: int, + text: str, + reply_to_message_id: int | None = None, + disable_notification: bool | None = False, + message_thread_id: int | None = None, + entities: list[dict] | None = None, + parse_mode: str | None = None, + reply_markup: dict | None = None, + *, + replace_message_id: int | None = None, + ) -> Message | None: + _ = ( + chat_id, + text, + reply_to_message_id, + disable_notification, + message_thread_id, + entities, + parse_mode, + reply_markup, + replace_message_id, + ) + raise AssertionError("send_message should not be called") + + async def send_document( + self, + chat_id: int, + filename: str, + content: bytes, + reply_to_message_id: int | None = None, + message_thread_id: int | None = None, + disable_notification: bool | None = False, + caption: str | None = None, + ) -> Message | None: + _ = ( + chat_id, + filename, + content, + reply_to_message_id, + message_thread_id, + disable_notification, + caption, + ) + raise AssertionError("send_document should not be called") + + async def edit_message_text( + self, + chat_id: int, + message_id: int, + text: str, + entities: list[dict] | None = None, + parse_mode: str | None = None, + reply_markup: dict | None = None, + *, + wait: bool = True, + ) -> Message | None: + _ = ( + chat_id, + message_id, + text, + entities, + parse_mode, + reply_markup, + wait, + ) + raise AssertionError("edit_message_text should not be called") + + async def delete_message(self, chat_id: int, message_id: int) -> bool: + _ = chat_id, message_id + raise AssertionError("delete_message should not be called") + + async def set_my_commands( + self, + commands: list[dict], + *, + scope: dict | None = None, + language_code: str | None = None, + ) -> bool: + _ = commands, scope, language_code + raise AssertionError("set_my_commands should not be called") + + async def get_me(self) -> User | None: + raise AssertionError("get_me should not be called") + + async def answer_callback_query( + self, + callback_query_id: str, + text: str | None = None, + show_alert: bool | None = None, + ) -> bool: + _ = callback_query_id, text, show_alert + raise AssertionError("answer_callback_query should not be called") + + async def get_chat(self, chat_id: int) -> Chat | None: + _ = chat_id + raise AssertionError("get_chat should not be called") + + async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None: + _ = chat_id, user_id + raise AssertionError("get_chat_member should not be called") + + async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None: + _ = chat_id, name + raise AssertionError("create_forum_topic should not be called") + + async def edit_forum_topic( + self, chat_id: int, message_thread_id: int, name: str + ) -> bool: + _ = chat_id, message_thread_id, name + raise AssertionError("edit_forum_topic should not be called") + + +def _voice_message(*, file_size: int = 123) -> TelegramIncomingMessage: + voice = TelegramVoice( + file_id="voice-id", + mime_type="audio/ogg", + file_size=file_size, + duration=1, + raw={}, + ) + return TelegramIncomingMessage( + transport="telegram", + chat_id=1, + message_id=1, + text="", + reply_to_message_id=None, + reply_to_text=None, + sender_id=1, + voice=voice, + raw={}, + ) + + +@pytest.mark.anyio +async def test_transcribe_voice_handles_missing_file() -> None: + replies: list[str] = [] + + async def reply(**kwargs) -> None: + replies.append(kwargs["text"]) + + bot = _Bot(file_info=None, audio=None) + result = await transcribe_voice( + bot=bot, + msg=_voice_message(), + enabled=True, + reply=reply, + ) + + assert result is None + assert replies[-1] == "failed to fetch voice file." + + +@pytest.mark.anyio +async def test_transcribe_voice_handles_missing_download() -> None: + replies: list[str] = [] + + async def reply(**kwargs) -> None: + replies.append(kwargs["text"]) + + bot = _Bot(file_info=File(file_path="voice.ogg"), audio=None) + result = await transcribe_voice( + bot=bot, + msg=_voice_message(), + enabled=True, + reply=reply, + ) + + assert result is None + assert replies[-1] == "failed to download voice file." + + +@pytest.mark.anyio +async def test_transcribe_voice_rejects_large_voice_without_downloading() -> None: + replies: list[str] = [] + + async def reply(**kwargs) -> None: + replies.append(kwargs["text"]) + + class _NoFetchBot(_Bot): + async def get_file(self, file_id: str) -> File | None: # type: ignore[override] + _ = file_id + raise AssertionError("get_file should not be called") + + async def download_file(self, file_path: str) -> bytes | None: # type: ignore[override] + _ = file_path + raise AssertionError("download_file should not be called") + + bot = _NoFetchBot(file_info=None, audio=None) + result = await transcribe_voice( + bot=bot, + msg=_voice_message(file_size=10_000), + enabled=True, + max_bytes=100, + reply=reply, + ) + + assert result is None + assert replies[-1] == "voice message is too large to transcribe." diff --git a/tests/test_transport_registry.py b/tests/test_transport_registry.py index 5bbc64a..bad214a 100644 --- a/tests/test_transport_registry.py +++ b/tests/test_transport_registry.py @@ -15,14 +15,14 @@ class DummyTransport: def interactive_setup(self, *, force: bool) -> bool: raise NotImplementedError - def lock_token(self, *, transport_config: dict[str, object], config_path): + def lock_token(self, *, transport_config: object, config_path): _ = transport_config, config_path raise NotImplementedError def build_and_run( self, *, - transport_config: dict[str, object], + transport_config: object, config_path, runtime, final_notify: bool,