refactor(telegram): boundary types (#90)

This commit is contained in:
banteg
2026-01-11 21:36:07 +04:00
committed by GitHub
parent c6c34ac17f
commit e8c478d786
23 changed files with 1116 additions and 581 deletions
+4 -5
View File
@@ -105,11 +105,7 @@ def _default_engine_for_setup(
if settings is None or config_path is None: if settings is None or config_path is None:
return "codex" return "codex"
value = settings.default_engine value = settings.default_engine
if not isinstance(value, str) or not value.strip(): return value
raise ConfigError(
f"Invalid `default_engine` in {config_path}; expected a non-empty string."
)
return value.strip()
def _config_path_display(path: Path) -> str: def _config_path_display(path: Path) -> str:
@@ -235,6 +231,9 @@ def _run_auto_router(
default_engine_override=default_engine_override, default_engine_override=default_engine_override,
reserved=("cancel",), reserved=("cancel",),
) )
if settings.transport == "telegram":
transport_config = settings.transports.telegram
else:
transport_config = settings.transport_config( transport_config = settings.transport_config(
settings.transport, config_path=config_path settings.transport, config_path=config_path
) )
+11 -2
View File
@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable from typing import Iterable, Literal, TypeAlias
from .model import EngineId, ResumeToken from .model import EngineId, ResumeToken
from .runner import Runner from .runner import Runner
@@ -17,13 +17,22 @@ class RunnerUnavailableError(RuntimeError):
self.issue = issue self.issue = issue
EngineStatus: TypeAlias = Literal["ok", "missing_cli", "bad_config", "load_error"]
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
class RunnerEntry: class RunnerEntry:
engine: EngineId engine: EngineId
runner: Runner runner: Runner
available: bool = True status: EngineStatus = "ok"
issue: str | None = None issue: str | None = None
@property
def available(self) -> bool:
# "bad_config" means we ignored user config and built the runner with defaults.
# The engine is still runnable, but a warning should be surfaced to the user.
return self.status in {"ok", "bad_config"}
class AutoRouter: class AutoRouter:
def __init__( def __init__(
+13 -6
View File
@@ -9,7 +9,7 @@ from .backends import EngineBackend
from .config import ConfigError, ProjectsConfig from .config import ConfigError, ProjectsConfig
from .engines import get_backend, list_backend_ids from .engines import get_backend, list_backend_ids
from .logging import get_logger from .logging import get_logger
from .router import AutoRouter, RunnerEntry from .router import AutoRouter, EngineStatus, RunnerEntry
from .settings import TakopiSettings from .settings import TakopiSettings
from .transport_runtime import TransportRuntime from .transport_runtime import TransportRuntime
@@ -83,6 +83,7 @@ def build_router(
for backend in backends: for backend in backends:
engine_id = backend.id engine_id = backend.id
issue: str | None = None issue: str | None = None
status: EngineStatus = "ok"
engine_cfg: dict engine_cfg: dict
try: try:
engine_cfg = settings.engine_config(engine_id, config_path=config_path) engine_cfg = settings.engine_config(engine_id, config_path=config_path)
@@ -90,6 +91,7 @@ def build_router(
if engine_id == default_engine: if engine_id == default_engine:
raise raise
issue = str(exc) issue = str(exc)
status = "bad_config"
engine_cfg = {} engine_cfg = {}
try: try:
@@ -104,26 +106,31 @@ def build_router(
except Exception as fallback_exc: except Exception as fallback_exc:
warnings.append(f"{engine_id}: {issue or str(fallback_exc)}") warnings.append(f"{engine_id}: {issue or str(fallback_exc)}")
continue continue
status = "bad_config"
else: else:
status = "load_error"
warnings.append(f"{engine_id}: {issue}") warnings.append(f"{engine_id}: {issue}")
continue continue
cmd = backend.cli_cmd or backend.id cmd = backend.cli_cmd or backend.id
if shutil.which(cmd) is None: if shutil.which(cmd) is None:
issue = issue or f"{cmd} not found on PATH" status = "missing_cli"
if issue:
issue = f"{issue}; {cmd} not found on PATH"
else:
issue = f"{cmd} not found on PATH"
if issue and engine_id == default_engine: if status != "ok" and engine_id == default_engine:
raise ConfigError(f"Default engine {engine_id!r} unavailable: {issue}") raise ConfigError(f"Default engine {engine_id!r} unavailable: {issue}")
available = issue is None if status != "ok" and engine_id != default_engine:
if issue and engine_id != default_engine:
warnings.append(f"{engine_id}: {issue}") warnings.append(f"{engine_id}: {issue}")
entries.append( entries.append(
RunnerEntry( RunnerEntry(
engine=engine_id, engine=engine_id,
runner=runner, runner=runner,
available=available, status=status,
issue=issue, issue=issue,
) )
) )
+11 -6
View File
@@ -83,6 +83,14 @@ class TelegramFilesSettings(BaseModel):
raise ValueError("files.uploads_dir must be a relative path") raise ValueError("files.uploads_dir must be a relative path")
return value return value
@property
def max_upload_bytes(self) -> int:
return 20 * 1024 * 1024
@property
def max_download_bytes(self) -> int:
return 50 * 1024 * 1024
class TelegramTransportSettings(BaseModel): class TelegramTransportSettings(BaseModel):
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True) model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
@@ -90,14 +98,13 @@ class TelegramTransportSettings(BaseModel):
bot_token: NonEmptyStr bot_token: NonEmptyStr
chat_id: StrictInt chat_id: StrictInt
voice_transcription: bool = False voice_transcription: bool = False
voice_max_bytes: StrictInt = 10 * 1024 * 1024
topics: TelegramTopicsSettings = Field(default_factory=TelegramTopicsSettings) topics: TelegramTopicsSettings = Field(default_factory=TelegramTopicsSettings)
files: TelegramFilesSettings = Field(default_factory=TelegramFilesSettings) files: TelegramFilesSettings = Field(default_factory=TelegramFilesSettings)
class TransportsSettings(BaseModel): class TransportsSettings(BaseModel):
telegram: TelegramTransportSettings = Field( telegram: TelegramTransportSettings
default_factory=TelegramTransportSettings
)
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@@ -132,7 +139,7 @@ class TakopiSettings(BaseSettings):
projects: dict[str, ProjectSettings] = Field(default_factory=dict) projects: dict[str, ProjectSettings] = Field(default_factory=dict)
transport: NonEmptyStr = "telegram" transport: NonEmptyStr = "telegram"
transports: TransportsSettings = Field(default_factory=TransportsSettings) transports: TransportsSettings
plugins: PluginsSettings = Field(default_factory=PluginsSettings) plugins: PluginsSettings = Field(default_factory=PluginsSettings)
@@ -310,8 +317,6 @@ def require_telegram(settings: TakopiSettings, config_path: Path) -> tuple[str,
"(telegram only for now)." "(telegram only for now)."
) )
tg = settings.transports.telegram tg = settings.transports.telegram
if not tg.bot_token:
raise ConfigError(f"Missing bot token in {config_path}.")
return tg.bot_token, tg.chat_id return tg.bot_token, tg.chat_id
+53
View File
@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import Any
import msgspec
__all__ = [
"Chat",
"ChatMember",
"File",
"ForumTopic",
"Message",
"Update",
"User",
]
class User(msgspec.Struct, forbid_unknown_fields=False):
id: int
username: str | None = None
first_name: str | None = None
last_name: str | None = None
class Chat(msgspec.Struct, forbid_unknown_fields=False):
id: int
type: str
is_forum: bool | None = None
class ChatMember(msgspec.Struct, forbid_unknown_fields=False):
status: str
can_manage_topics: bool | None = None
class Message(msgspec.Struct, forbid_unknown_fields=False):
message_id: int
message_thread_id: int | None = None
text: str | None = None
class File(msgspec.Struct, forbid_unknown_fields=False):
file_path: str
class ForumTopic(msgspec.Struct, forbid_unknown_fields=False):
message_thread_id: int
class Update(msgspec.Struct, forbid_unknown_fields=False):
update_id: int
message: dict[str, Any] | None = None
callback_query: dict[str, Any] | None = None
+33 -50
View File
@@ -2,22 +2,19 @@ from __future__ import annotations
import os import os
from pathlib import Path from pathlib import Path
from typing import cast
import anyio import anyio
from ..backends import EngineBackend from ..backends import EngineBackend
from ..runner_bridge import ExecBridgeConfig
from ..logging import get_logger from ..logging import get_logger
from ..runner_bridge import ExecBridgeConfig
from ..transports import SetupResult, TransportBackend from ..settings import TelegramTransportSettings
from ..transport_runtime import TransportRuntime from ..transport_runtime import TransportRuntime
from ..transports import SetupResult, TransportBackend
from .bridge import ( from .bridge import (
TelegramBridgeConfig, TelegramBridgeConfig,
TelegramPresenter, TelegramPresenter,
TelegramTransport, TelegramTransport,
TelegramFilesConfig,
TelegramTopicsConfig,
run_main_loop, run_main_loop,
) )
from .client import TelegramClient from .client import TelegramClient
@@ -26,6 +23,12 @@ from .onboarding import check_setup, interactive_setup
logger = get_logger(__name__) logger = get_logger(__name__)
def _expect_transport_settings(transport_config: object) -> TelegramTransportSettings:
if isinstance(transport_config, TelegramTransportSettings):
return transport_config
raise TypeError("transport_config must be TelegramTransportSettings")
def _build_startup_message( def _build_startup_message(
runtime: TransportRuntime, runtime: TransportRuntime,
*, *,
@@ -33,9 +36,20 @@ def _build_startup_message(
) -> str: ) -> str:
available_engines = list(runtime.available_engine_ids()) available_engines = list(runtime.available_engine_ids())
missing_engines = list(runtime.missing_engine_ids()) missing_engines = list(runtime.missing_engine_ids())
misconfigured_engines = list(runtime.engine_ids_with_status("bad_config"))
failed_engines = list(runtime.engine_ids_with_status("load_error"))
engine_list = ", ".join(available_engines) if available_engines else "none" engine_list = ", ".join(available_engines) if available_engines else "none"
notes: list[str] = []
if missing_engines: if missing_engines:
engine_list = f"{engine_list} (not installed: {', '.join(missing_engines)})" notes.append(f"not installed: {', '.join(missing_engines)}")
if misconfigured_engines:
notes.append(f"misconfigured: {', '.join(misconfigured_engines)}")
if failed_engines:
notes.append(f"failed to load: {', '.join(failed_engines)}")
if notes:
engine_list = f"{engine_list} ({'; '.join(notes)})"
project_aliases = sorted( project_aliases = sorted(
{alias for alias in runtime.project_aliases()}, key=str.lower {alias for alias in runtime.project_aliases()}, key=str.lower
) )
@@ -49,34 +63,6 @@ def _build_startup_message(
) )
def _build_topics_config(transport_config: dict[str, object]) -> TelegramTopicsConfig:
raw = cast(dict[str, object], transport_config.get("topics", {}))
return TelegramTopicsConfig(
enabled=cast(bool, raw.get("enabled", False)),
scope=cast(str, raw.get("scope", "auto")),
)
def _build_files_config(transport_config: dict[str, object]) -> TelegramFilesConfig:
defaults = TelegramFilesConfig()
raw = cast(dict[str, object], transport_config.get("files", {}))
return TelegramFilesConfig(
enabled=cast(bool, raw.get("enabled", defaults.enabled)),
auto_put=cast(bool, raw.get("auto_put", defaults.auto_put)),
uploads_dir=cast(str, raw.get("uploads_dir", defaults.uploads_dir)),
max_upload_bytes=defaults.max_upload_bytes,
max_download_bytes=defaults.max_download_bytes,
allowed_user_ids=frozenset(
cast(
list[int], raw.get("allowed_user_ids", list(defaults.allowed_user_ids))
)
),
deny_globs=tuple(
cast(list[str], raw.get("deny_globs", list(defaults.deny_globs)))
),
)
class TelegramBackend(TransportBackend): class TelegramBackend(TransportBackend):
id = "telegram" id = "telegram"
description = "Telegram bot" description = "Telegram bot"
@@ -92,23 +78,23 @@ class TelegramBackend(TransportBackend):
def interactive_setup(self, *, force: bool) -> bool: def interactive_setup(self, *, force: bool) -> bool:
return interactive_setup(force=force) return interactive_setup(force=force)
def lock_token( def lock_token(self, *, transport_config: object, config_path: Path) -> str | None:
self, *, transport_config: dict[str, object], config_path: Path
) -> str | None:
_ = config_path _ = config_path
return cast(str, transport_config.get("bot_token")) settings = _expect_transport_settings(transport_config)
return settings.bot_token
def build_and_run( def build_and_run(
self, self,
*, *,
transport_config: dict[str, object], transport_config: object,
config_path: Path, config_path: Path,
runtime: TransportRuntime, runtime: TransportRuntime,
final_notify: bool, final_notify: bool,
default_engine_override: str | None, default_engine_override: str | None,
) -> None: ) -> None:
token = cast(str, transport_config.get("bot_token")) settings = _expect_transport_settings(transport_config)
chat_id = cast(int, transport_config.get("chat_id")) token = settings.bot_token
chat_id = settings.chat_id
startup_msg = _build_startup_message( startup_msg = _build_startup_message(
runtime, runtime,
startup_pwd=os.getcwd(), startup_pwd=os.getcwd(),
@@ -121,19 +107,16 @@ class TelegramBackend(TransportBackend):
presenter=presenter, presenter=presenter,
final_notify=final_notify, final_notify=final_notify,
) )
topics = _build_topics_config(transport_config)
files = _build_files_config(transport_config)
cfg = TelegramBridgeConfig( cfg = TelegramBridgeConfig(
bot=bot, bot=bot,
runtime=runtime, runtime=runtime,
chat_id=chat_id, chat_id=chat_id,
startup_msg=startup_msg, startup_msg=startup_msg,
exec_cfg=exec_cfg, exec_cfg=exec_cfg,
voice_transcription=cast( voice_transcription=settings.voice_transcription,
bool, transport_config.get("voice_transcription", False) voice_max_bytes=int(settings.voice_max_bytes),
), topics=settings.topics,
topics=topics, files=settings.files,
files=files,
) )
async def run_loop() -> None: async def run_loop() -> None:
@@ -142,7 +125,7 @@ class TelegramBackend(TransportBackend):
watch_config=runtime.watch_config, watch_config=runtime.watch_config,
default_engine_override=default_engine_override, default_engine_override=default_engine_override,
transport_id=self.id, transport_id=self.id,
transport_config=transport_config, transport_config=settings,
) )
anyio.run(run_loop) anyio.run(run_loop)
+12 -31
View File
@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import cast from typing import cast
from ..logging import get_logger from ..logging import get_logger
@@ -12,6 +12,11 @@ from ..transport import MessageRef, RenderedMessage, SendOptions, Transport
from ..transport_runtime import TransportRuntime from ..transport_runtime import TransportRuntime
from ..context import RunContext from ..context import RunContext
from ..model import ResumeToken from ..model import ResumeToken
from ..settings import (
TelegramFilesSettings,
TelegramTopicsSettings,
TelegramTransportSettings,
)
from .client import BotClient from .client import BotClient
from .render import prepare_telegram from .render import prepare_telegram
from .types import TelegramCallbackQuery, TelegramIncomingMessage from .types import TelegramCallbackQuery, TelegramIncomingMessage
@@ -20,8 +25,6 @@ logger = get_logger(__name__)
__all__ = [ __all__ = [
"TelegramBridgeConfig", "TelegramBridgeConfig",
"TelegramFilesConfig",
"TelegramTopicsConfig",
"TelegramPresenter", "TelegramPresenter",
"TelegramTransport", "TelegramTransport",
"build_bot_commands", "build_bot_commands",
@@ -85,29 +88,6 @@ def _is_cancelled_label(label: str) -> bool:
return stripped.lower() == "cancelled" return stripped.lower() == "cancelled"
@dataclass(frozen=True)
class TelegramFilesConfig:
enabled: bool = False
auto_put: bool = True
uploads_dir: str = "incoming"
max_upload_bytes: int = 20 * 1024 * 1024
max_download_bytes: int = 50 * 1024 * 1024
allowed_user_ids: frozenset[int] = frozenset()
deny_globs: tuple[str, ...] = (
".git/**",
".env",
".envrc",
"**/*.pem",
"**/.ssh/**",
)
@dataclass(frozen=True)
class TelegramTopicsConfig:
enabled: bool = False
scope: str = "auto"
@dataclass(frozen=True) @dataclass(frozen=True)
class TelegramBridgeConfig: class TelegramBridgeConfig:
bot: BotClient bot: BotClient
@@ -116,9 +96,10 @@ class TelegramBridgeConfig:
startup_msg: str startup_msg: str
exec_cfg: ExecBridgeConfig exec_cfg: ExecBridgeConfig
voice_transcription: bool = False voice_transcription: bool = False
files: TelegramFilesConfig = TelegramFilesConfig() voice_max_bytes: int = 10 * 1024 * 1024
files: TelegramFilesSettings = field(default_factory=TelegramFilesSettings)
chat_ids: tuple[int, ...] | None = None chat_ids: tuple[int, ...] | None = None
topics: TelegramTopicsConfig = TelegramTopicsConfig() topics: TelegramTopicsSettings = field(default_factory=TelegramTopicsSettings)
class TelegramTransport: class TelegramTransport:
@@ -166,7 +147,7 @@ class TelegramTransport:
) )
if sent is None: if sent is None:
return None return None
message_id = cast(int, sent["message_id"]) message_id = sent.message_id
return MessageRef( return MessageRef(
channel_id=chat_id, channel_id=chat_id,
message_id=message_id, message_id=message_id,
@@ -192,7 +173,7 @@ class TelegramTransport:
) )
if edited is None: if edited is None:
return ref if not wait else None return ref if not wait else None
message_id = cast(int, edited.get("message_id", message_id)) message_id = edited.message_id
return MessageRef( return MessageRef(
channel_id=chat_id, channel_id=chat_id,
message_id=message_id, message_id=message_id,
@@ -287,7 +268,7 @@ async def run_main_loop(
watch_config: bool | None = None, watch_config: bool | None = None,
default_engine_override: str | None = None, default_engine_override: str | None = None,
transport_id: str | None = None, transport_id: str | None = None,
transport_config: dict[str, object] | None = None, transport_config: TelegramTransportSettings | None = None,
) -> None: ) -> None:
from .loop import run_main_loop as _run_main_loop from .loop import run_main_loop as _run_main_loop
+247 -205
View File
@@ -12,13 +12,16 @@ from typing import (
Iterable, Iterable,
Protocol, Protocol,
TYPE_CHECKING, TYPE_CHECKING,
TypeVar,
) )
import msgspec
import httpx import httpx
import anyio import anyio
from ..logging import get_logger from ..logging import get_logger
from .api_models import Chat, ChatMember, File, ForumTopic, Message, Update, User
from .types import ( from .types import (
TelegramCallbackQuery, TelegramCallbackQuery,
TelegramDocument, TelegramDocument,
@@ -29,6 +32,8 @@ from .types import (
logger = get_logger(__name__) logger = get_logger(__name__)
T = TypeVar("T")
SEND_PRIORITY = 0 SEND_PRIORITY = 0
DELETE_PRIORITY = 1 DELETE_PRIORITY = 1
@@ -51,15 +56,20 @@ def is_group_chat_id(chat_id: int) -> bool:
def parse_incoming_update( def parse_incoming_update(
update: dict[str, Any], update: Update | dict[str, Any],
*, *,
chat_id: int | None = None, chat_id: int | None = None,
chat_ids: set[int] | None = None, chat_ids: set[int] | None = None,
) -> TelegramIncomingUpdate | None: ) -> TelegramIncomingUpdate | None:
if isinstance(update, Update):
msg = update.message
callback_query = update.callback_query
else:
msg = update.get("message") msg = update.get("message")
callback_query = update.get("callback_query")
if isinstance(msg, dict): if isinstance(msg, dict):
return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids) return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids)
callback_query = update.get("callback_query")
if isinstance(callback_query, dict): if isinstance(callback_query, dict):
return _parse_callback_query( return _parse_callback_query(
callback_query, callback_query,
@@ -303,7 +313,7 @@ async def poll_incoming(
if allowed is None and chat_id is not None: if allowed is None and chat_id is not None:
allowed = {chat_id} allowed = {chat_id}
for upd in updates: for upd in updates:
offset = upd["update_id"] + 1 offset = upd.update_id + 1
msg = parse_incoming_update(upd, chat_ids=allowed) msg = parse_incoming_update(upd, chat_ids=allowed)
if msg is not None: if msg is not None:
yield msg yield msg
@@ -317,9 +327,9 @@ class BotClient(Protocol):
offset: int | None, offset: int | None,
timeout_s: int = 50, timeout_s: int = 50,
allowed_updates: list[str] | None = None, allowed_updates: list[str] | None = None,
) -> list[dict] | None: ... ) -> list[Update] | None: ...
async def get_file(self, file_id: str) -> dict | None: ... async def get_file(self, file_id: str) -> File | None: ...
async def download_file(self, file_path: str) -> bytes | None: ... async def download_file(self, file_path: str) -> bytes | None: ...
@@ -335,7 +345,7 @@ class BotClient(Protocol):
reply_markup: dict[str, Any] | None = None, reply_markup: dict[str, Any] | None = None,
*, *,
replace_message_id: int | None = None, replace_message_id: int | None = None,
) -> dict | None: ... ) -> Message | None: ...
async def send_document( async def send_document(
self, self,
@@ -346,7 +356,7 @@ class BotClient(Protocol):
message_thread_id: int | None = None, message_thread_id: int | None = None,
disable_notification: bool | None = False, disable_notification: bool | None = False,
caption: str | None = None, caption: str | None = None,
) -> dict | None: ... ) -> Message | None: ...
async def edit_message_text( async def edit_message_text(
self, self,
@@ -358,7 +368,7 @@ class BotClient(Protocol):
reply_markup: dict[str, Any] | None = None, reply_markup: dict[str, Any] | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict | None: ... ) -> Message | None: ...
async def delete_message( async def delete_message(
self, self,
@@ -374,7 +384,7 @@ class BotClient(Protocol):
language_code: str | None = None, language_code: str | None = None,
) -> bool: ... ) -> bool: ...
async def get_me(self) -> dict | None: ... async def get_me(self) -> User | None: ...
async def answer_callback_query( async def answer_callback_query(
self, self,
@@ -383,15 +393,17 @@ class BotClient(Protocol):
show_alert: bool | None = None, show_alert: bool | None = None,
) -> bool: ... ) -> bool: ...
async def get_chat(self, chat_id: int) -> dict | None: ... async def get_chat(self, chat_id: int) -> Chat | None: ...
async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None: ... async def get_chat_member(
self, chat_id: int, user_id: int
) -> ChatMember | None: ...
async def create_forum_topic( async def create_forum_topic(
self, self,
chat_id: int, chat_id: int,
name: str, name: str,
) -> dict | None: ... ) -> ForumTopic | None: ...
async def edit_forum_topic( async def edit_forum_topic(
self, self,
@@ -672,71 +684,13 @@ class TelegramClient:
if self._owns_http_client and self._http_client is not None: if self._owns_http_client and self._http_client is not None:
await self._http_client.aclose() await self._http_client.aclose()
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None: def _parse_telegram_envelope(
if self._http_client is None or self._base is None: self,
raise RuntimeError("TelegramClient is configured without an HTTP client.") *,
logger.debug("telegram.request", method=method, payload=json_data) method: str,
try: resp: httpx.Response,
resp = await self._http_client.post( payload: Any,
f"{self._base}/{method}", json=json_data ) -> Any | None:
)
except httpx.HTTPError as e:
url = getattr(e.request, "url", None)
logger.error(
"telegram.network_error",
method=method,
url=str(url) if url is not None else None,
error=str(e),
error_type=e.__class__.__name__,
)
return None
try:
resp.raise_for_status()
except httpx.HTTPStatusError as e:
if resp.status_code == 429:
retry_after: float | None = None
try:
payload = resp.json()
except Exception:
payload = None
if isinstance(payload, dict):
retry_after = retry_after_from_payload(payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method=method,
status=resp.status_code,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after) from e
body = resp.text
logger.error(
"telegram.http_error",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(e),
body=body,
)
return None
try:
payload = resp.json()
except Exception as e:
body = resp.text
logger.error(
"telegram.bad_response",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(e),
error_type=e.__class__.__name__,
body=body,
)
return None
if not isinstance(payload, dict): if not isinstance(payload, dict):
logger.error( logger.error(
"telegram.invalid_payload", "telegram.invalid_payload",
@@ -768,145 +722,193 @@ class TelegramClient:
logger.debug("telegram.response", method=method, payload=payload) logger.debug("telegram.response", method=method, payload=payload)
return payload.get("result") return payload.get("result")
async def _request(
self,
method: str,
*,
json: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
files: dict[str, Any] | None = None,
) -> Any | None:
if self._http_client is None or self._base is None:
raise RuntimeError("TelegramClient is configured without an HTTP client.")
request_payload = json if json is not None else data
logger.debug("telegram.request", method=method, payload=request_payload)
try:
if json is not None:
resp = await self._http_client.post(f"{self._base}/{method}", json=json)
else:
resp = await self._http_client.post(
f"{self._base}/{method}", data=data, files=files
)
except httpx.HTTPError as exc:
url = getattr(exc.request, "url", None)
logger.error(
"telegram.network_error",
method=method,
url=str(url) if url is not None else None,
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
try:
resp.raise_for_status()
except httpx.HTTPStatusError as exc:
if resp.status_code == 429:
retry_after: float | None = None
try:
response_payload = resp.json()
except Exception:
response_payload = None
if isinstance(response_payload, dict):
retry_after = retry_after_from_payload(response_payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method=method,
status=resp.status_code,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after) from exc
body = resp.text
logger.error(
"telegram.http_error",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(exc),
body=body,
)
return None
try:
response_payload = resp.json()
except Exception as exc:
body = resp.text
logger.error(
"telegram.bad_response",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(exc),
error_type=exc.__class__.__name__,
body=body,
)
return None
return self._parse_telegram_envelope(
method=method,
resp=resp,
payload=response_payload,
)
def _decode_result(
self,
*,
method: str,
payload: Any,
model: type[T],
) -> T | None:
if payload is None:
return None
try:
return msgspec.convert(payload, type=model)
except Exception as exc:
logger.error(
"telegram.decode_error",
method=method,
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
async def _call_with_retry_after(
self,
fn: Callable[[], Awaitable[T]],
) -> T:
while True:
try:
return await fn()
except TelegramRetryAfter as exc:
await self._sleep(exc.retry_after)
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
return await self._request(method, json=json_data)
async def _post_form( async def _post_form(
self, self,
method: str, method: str,
data: dict[str, Any], data: dict[str, Any],
files: dict[str, Any], files: dict[str, Any],
) -> Any | None: ) -> Any | None:
if self._http_client is None or self._base is None: return await self._request(method, data=data, files=files)
raise RuntimeError("TelegramClient is configured without an HTTP client.")
logger.debug("telegram.request", method=method, payload=data)
try:
resp = await self._http_client.post(
f"{self._base}/{method}", data=data, files=files
)
except httpx.HTTPError as e:
url = getattr(e.request, "url", None)
logger.error(
"telegram.network_error",
method=method,
url=str(url) if url is not None else None,
error=str(e),
error_type=e.__class__.__name__,
)
return None
try:
resp.raise_for_status()
except httpx.HTTPStatusError as e:
if resp.status_code == 429:
retry_after: float | None = None
try:
payload = resp.json()
except Exception:
payload = None
if isinstance(payload, dict):
retry_after = retry_after_from_payload(payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method=method,
status=resp.status_code,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after) from e
body = resp.text
logger.error(
"telegram.http_error",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(e),
body=body,
)
return None
try:
payload = resp.json()
except Exception as e:
body = resp.text
logger.error(
"telegram.bad_response",
method=method,
status=resp.status_code,
error=str(e),
error_type=e.__class__.__name__,
body=body,
)
return None
if not isinstance(payload, dict):
logger.error(
"telegram.invalid_payload",
method=method,
url=str(resp.request.url),
payload=payload,
)
return None
if not payload.get("ok"):
if payload.get("error_code") == 429:
retry_after = retry_after_from_payload(payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method=method,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after)
logger.error(
"telegram.api_error",
method=method,
url=str(resp.request.url),
payload=payload,
)
return None
logger.debug("telegram.response", method=method, payload=payload)
return payload.get("result")
async def get_updates( async def get_updates(
self, self,
offset: int | None, offset: int | None,
timeout_s: int = 50, timeout_s: int = 50,
allowed_updates: list[str] | None = None, allowed_updates: list[str] | None = None,
) -> list[dict] | None: ) -> list[Update] | None:
while True: async def execute() -> list[Update] | None:
try:
if self._client_override is not None: if self._client_override is not None:
return await self._client_override.get_updates( raw = await self._client_override.get_updates(
offset=offset, offset=offset,
timeout_s=timeout_s, timeout_s=timeout_s,
allowed_updates=allowed_updates, allowed_updates=allowed_updates,
) )
if raw is None:
return None
try:
return msgspec.convert(raw, type=list[Update])
except Exception as exc:
logger.error(
"telegram.decode_error",
method="getUpdates",
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
params: dict[str, Any] = {"timeout": timeout_s} params: dict[str, Any] = {"timeout": timeout_s}
if offset is not None: if offset is not None:
params["offset"] = offset params["offset"] = offset
if allowed_updates is not None: if allowed_updates is not None:
params["allowed_updates"] = allowed_updates params["allowed_updates"] = allowed_updates
result = await self._post("getUpdates", params) result = await self._post("getUpdates", params)
return result if isinstance(result, list) else None if result is None or not isinstance(result, list):
except TelegramRetryAfter as exc: return None
await self._sleep(exc.retry_after)
async def get_file(self, file_id: str) -> dict | None:
while True:
try: try:
return msgspec.convert(result, type=list[Update])
except Exception as exc:
logger.error(
"telegram.decode_error",
method="getUpdates",
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
return await self._call_with_retry_after(execute)
async def get_file(self, file_id: str) -> File | None:
async def execute() -> File | None:
if self._client_override is not None: if self._client_override is not None:
return await self._client_override.get_file(file_id) return await self._client_override.get_file(file_id)
result = await self._post("getFile", {"file_id": file_id}) result = await self._post("getFile", {"file_id": file_id})
return result if isinstance(result, dict) else None return self._decode_result(method="getFile", payload=result, model=File)
except TelegramRetryAfter as exc:
await self._sleep(exc.retry_after) return await self._call_with_retry_after(execute)
async def download_file(self, file_path: str) -> bytes | None: async def download_file(self, file_path: str) -> bytes | None:
async def execute() -> bytes | None:
if self._client_override is not None: if self._client_override is not None:
return await self._client_override.download_file(file_path) return await self._client_override.download_file(file_path)
if self._http_client is None or self._file_base is None: if self._http_client is None or self._file_base is None:
raise RuntimeError("TelegramClient is configured without an HTTP client.") raise RuntimeError(
"TelegramClient is configured without an HTTP client."
)
url = f"{self._file_base}/{file_path}" url = f"{self._file_base}/{file_path}"
try: try:
resp = await self._http_client.get(url) resp = await self._http_client.get(url)
@@ -922,6 +924,24 @@ class TelegramClient:
try: try:
resp.raise_for_status() resp.raise_for_status()
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
if resp.status_code == 429:
retry_after: float | None = None
try:
response_payload = resp.json()
except Exception:
response_payload = None
if isinstance(response_payload, dict):
retry_after = retry_after_from_payload(response_payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method="download_file",
status=resp.status_code,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after) from exc
logger.error( logger.error(
"telegram.file_http_error", "telegram.file_http_error",
status=resp.status_code, status=resp.status_code,
@@ -932,6 +952,8 @@ class TelegramClient:
return None return None
return resp.content return resp.content
return await self._call_with_retry_after(execute)
async def send_message( async def send_message(
self, self,
chat_id: int, chat_id: int,
@@ -944,8 +966,8 @@ class TelegramClient:
reply_markup: dict[str, Any] | None = None, reply_markup: dict[str, Any] | None = None,
*, *,
replace_message_id: int | None = None, replace_message_id: int | None = None,
) -> dict | None: ) -> Message | None:
async def execute() -> dict | None: async def execute() -> Message | None:
if self._client_override is not None: if self._client_override is not None:
return await self._client_override.send_message( return await self._client_override.send_message(
chat_id=chat_id, chat_id=chat_id,
@@ -972,7 +994,11 @@ class TelegramClient:
if reply_markup is not None: if reply_markup is not None:
params["reply_markup"] = reply_markup params["reply_markup"] = reply_markup
result = await self._post("sendMessage", params) result = await self._post("sendMessage", params)
return result if isinstance(result, dict) else None return self._decode_result(
method="sendMessage",
payload=result,
model=Message,
)
if replace_message_id is not None: if replace_message_id is not None:
await self._outbox.drop_pending(key=("edit", chat_id, replace_message_id)) await self._outbox.drop_pending(key=("edit", chat_id, replace_message_id))
@@ -1000,8 +1026,8 @@ class TelegramClient:
message_thread_id: int | None = None, message_thread_id: int | None = None,
disable_notification: bool | None = False, disable_notification: bool | None = False,
caption: str | None = None, caption: str | None = None,
) -> dict | None: ) -> Message | None:
async def execute() -> dict | None: async def execute() -> Message | None:
if self._client_override is not None: if self._client_override is not None:
return await self._client_override.send_document( return await self._client_override.send_document(
chat_id=chat_id, chat_id=chat_id,
@@ -1026,7 +1052,11 @@ class TelegramClient:
params, params,
files={"document": (filename, content)}, files={"document": (filename, content)},
) )
return result if isinstance(result, dict) else None return self._decode_result(
method="sendDocument",
payload=result,
model=Message,
)
return await self.enqueue_op( return await self.enqueue_op(
key=self.unique_key("send_document"), key=self.unique_key("send_document"),
@@ -1046,8 +1076,8 @@ class TelegramClient:
reply_markup: dict[str, Any] | None = None, reply_markup: dict[str, Any] | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict | None: ) -> Message | None:
async def execute() -> dict | None: async def execute() -> Message | None:
if self._client_override is not None: if self._client_override is not None:
return await self._client_override.edit_message_text( return await self._client_override.edit_message_text(
chat_id=chat_id, chat_id=chat_id,
@@ -1070,7 +1100,11 @@ class TelegramClient:
if reply_markup is not None: if reply_markup is not None:
params["reply_markup"] = reply_markup params["reply_markup"] = reply_markup
result = await self._post("editMessageText", params) result = await self._post("editMessageText", params)
return result if isinstance(result, dict) else None return self._decode_result(
method="editMessageText",
payload=result,
model=Message,
)
return await self.enqueue_op( return await self.enqueue_op(
key=("edit", chat_id, message_id), key=("edit", chat_id, message_id),
@@ -1142,12 +1176,12 @@ class TelegramClient:
) )
) )
async def get_me(self) -> dict | None: async def get_me(self) -> User | None:
async def execute() -> dict | None: async def execute() -> User | None:
if self._client_override is not None: if self._client_override is not None:
return await self._client_override.get_me() return await self._client_override.get_me()
result = await self._post("getMe", {}) result = await self._post("getMe", {})
return result if isinstance(result, dict) else None return self._decode_result(method="getMe", payload=result, model=User)
return await self.enqueue_op( return await self.enqueue_op(
key=self.unique_key("get_me"), key=self.unique_key("get_me"),
@@ -1188,12 +1222,12 @@ class TelegramClient:
) )
) )
async def get_chat(self, chat_id: int) -> dict | None: async def get_chat(self, chat_id: int) -> Chat | None:
async def execute() -> dict | None: async def execute() -> Chat | None:
if self._client_override is not None: if self._client_override is not None:
return await self._client_override.get_chat(chat_id) return await self._client_override.get_chat(chat_id)
result = await self._post("getChat", {"chat_id": chat_id}) result = await self._post("getChat", {"chat_id": chat_id})
return result if isinstance(result, dict) else None return self._decode_result(method="getChat", payload=result, model=Chat)
return await self.enqueue_op( return await self.enqueue_op(
key=self.unique_key("get_chat"), key=self.unique_key("get_chat"),
@@ -1203,14 +1237,18 @@ class TelegramClient:
chat_id=chat_id, chat_id=chat_id,
) )
async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None: async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
async def execute() -> dict | None: async def execute() -> ChatMember | None:
if self._client_override is not None: if self._client_override is not None:
return await self._client_override.get_chat_member(chat_id, user_id) return await self._client_override.get_chat_member(chat_id, user_id)
result = await self._post( result = await self._post(
"getChatMember", {"chat_id": chat_id, "user_id": user_id} "getChatMember", {"chat_id": chat_id, "user_id": user_id}
) )
return result if isinstance(result, dict) else None return self._decode_result(
method="getChatMember",
payload=result,
model=ChatMember,
)
return await self.enqueue_op( return await self.enqueue_op(
key=self.unique_key("get_chat_member"), key=self.unique_key("get_chat_member"),
@@ -1220,14 +1258,18 @@ class TelegramClient:
chat_id=chat_id, chat_id=chat_id,
) )
async def create_forum_topic(self, chat_id: int, name: str) -> dict | None: async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
async def execute() -> dict | None: async def execute() -> ForumTopic | None:
if self._client_override is not None: if self._client_override is not None:
return await self._client_override.create_forum_topic(chat_id, name) return await self._client_override.create_forum_topic(chat_id, name)
result = await self._post( result = await self._post(
"createForumTopic", {"chat_id": chat_id, "name": name} "createForumTopic", {"chat_id": chat_id, "name": name}
) )
return result if isinstance(result, dict) else None return self._decode_result(
method="createForumTopic",
payload=result,
model=ForumTopic,
)
return await self.enqueue_op( return await self.enqueue_op(
key=self.unique_key("create_forum_topic"), key=self.unique_key("create_forum_topic"),
+6 -14
View File
@@ -289,11 +289,10 @@ async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool:
if is_private: if is_private:
return True return True
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
if not isinstance(member, dict): if member is None:
await reply(text="failed to verify file transfer permissions.") await reply(text="failed to verify file transfer permissions.")
return False return False
status = member.get("status") if member.status in {"creator", "administrator"}:
if status in {"creator", "administrator"}:
return True return True
await reply(text="file transfer is restricted to group admins.") await reply(text="file transfer is restricted to group admins.")
return False return False
@@ -377,21 +376,14 @@ async def _save_document_payload(
error="file is too large to upload.", error="file is too large to upload.",
) )
file_info = await cfg.bot.get_file(document.file_id) file_info = await cfg.bot.get_file(document.file_id)
if not isinstance(file_info, dict): if file_info is None:
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error="failed to fetch file metadata.",
)
file_path = file_info.get("file_path")
if not isinstance(file_path, str) or not file_path:
return _FilePutResult( return _FilePutResult(
name=name, name=name,
rel_path=None, rel_path=None,
size=None, size=None,
error="failed to fetch file metadata.", error="failed to fetch file metadata.",
) )
file_path = file_info.file_path
name = default_upload_name(document.file_name, file_path) name = default_upload_name(document.file_name, file_path)
resolved_path = rel_path resolved_path = rel_path
if resolved_path is None: if resolved_path is None:
@@ -972,10 +964,10 @@ async def _handle_topic_command(
return return
title = _topic_title(runtime=cfg.runtime, context=context) title = _topic_title(runtime=cfg.runtime, context=context)
created = await cfg.bot.create_forum_topic(msg.chat_id, title) created = await cfg.bot.create_forum_topic(msg.chat_id, title)
thread_id = created.get("message_thread_id") if isinstance(created, dict) else None if created is None:
if isinstance(thread_id, bool) or not isinstance(thread_id, int):
await reply(text="failed to create topic.") await reply(text="failed to create topic.")
return return
thread_id = created.message_thread_id
await store.set_context( await store.set_context(
msg.chat_id, msg.chat_id,
thread_id, thread_id,
+5 -3
View File
@@ -14,6 +14,7 @@ from ..directives import DirectiveError
from ..logging import get_logger from ..logging import get_logger
from ..model import EngineId, ResumeToken from ..model import EngineId, ResumeToken
from ..scheduler import ThreadJob, ThreadScheduler from ..scheduler import ThreadJob, ThreadScheduler
from ..settings import TelegramTransportSettings
from ..transport import MessageRef from ..transport import MessageRef
from ..context import RunContext from ..context import RunContext
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
@@ -169,7 +170,7 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int |
if drained: if drained:
logger.info("startup.backlog.drained", count=drained) logger.info("startup.backlog.drained", count=drained)
return offset return offset
offset = updates[-1]["update_id"] + 1 offset = updates[-1].update_id + 1
drained += len(updates) drained += len(updates)
@@ -266,7 +267,7 @@ async def run_main_loop(
watch_config: bool | None = None, watch_config: bool | None = None,
default_engine_override: str | None = None, default_engine_override: str | None = None,
transport_id: str | None = None, transport_id: str | None = None,
transport_config: dict[str, object] | None = None, transport_config: TelegramTransportSettings | None = None,
) -> None: ) -> None:
from ..runner_bridge import RunningTasks from ..runner_bridge import RunningTasks
@@ -277,7 +278,7 @@ async def run_main_loop(
} }
reserved_commands = _reserved_commands(cfg.runtime) reserved_commands = _reserved_commands(cfg.runtime)
transport_snapshot = ( transport_snapshot = (
dict(transport_config) if transport_config is not None else None transport_config.model_dump() if transport_config is not None else None
) )
topic_store: TopicStateStore | None = None topic_store: TopicStateStore | None = None
media_groups: dict[tuple[int, str], _MediaGroupState] = {} media_groups: dict[tuple[int, str], _MediaGroupState] = {}
@@ -476,6 +477,7 @@ async def run_main_loop(
bot=cfg.bot, bot=cfg.bot,
msg=msg, msg=msg,
enabled=cfg.voice_transcription, enabled=cfg.voice_transcription,
max_bytes=cfg.voice_max_bytes,
reply=reply, reply=reply,
) )
if text is None: if text is None:
+11 -15
View File
@@ -33,6 +33,7 @@ from ..engines import list_backends
from ..logging import suppress_logs from ..logging import suppress_logs
from ..settings import HOME_CONFIG_PATH, load_settings, require_telegram from ..settings import HOME_CONFIG_PATH, load_settings, require_telegram
from ..transports import SetupResult from ..transports import SetupResult
from .api_models import User
from .client import TelegramClient, TelegramRetryAfter from .client import TelegramClient, TelegramRetryAfter
__all__ = [ __all__ = [
@@ -131,7 +132,7 @@ def mask_token(token: str) -> str:
return f"{token[:9]}...{token[-5:]}" return f"{token[:9]}...{token[-5:]}"
async def get_bot_info(token: str) -> dict[str, Any] | None: async def get_bot_info(token: str) -> User | None:
bot = TelegramClient(token) bot = TelegramClient(token)
try: try:
for _ in range(3): for _ in range(3):
@@ -153,23 +154,19 @@ async def wait_for_chat(token: str) -> ChatInfo:
offset=None, timeout_s=0, allowed_updates=allowed_updates offset=None, timeout_s=0, allowed_updates=allowed_updates
) )
if drained: if drained:
offset = drained[-1]["update_id"] + 1 offset = drained[-1].update_id + 1
while True: while True:
try:
updates = await bot.get_updates( updates = await bot.get_updates(
offset=offset, timeout_s=50, allowed_updates=allowed_updates offset=offset, timeout_s=50, allowed_updates=allowed_updates
) )
except TelegramRetryAfter as exc:
await anyio.sleep(exc.retry_after)
continue
if updates is None: if updates is None:
await anyio.sleep(1) await anyio.sleep(1)
continue continue
if not updates: if not updates:
continue continue
offset = updates[-1]["update_id"] + 1 offset = updates[-1].update_id + 1
update = updates[-1] update = updates[-1]
msg = update.get("message") msg = update.message
if not isinstance(msg, dict): if not isinstance(msg, dict):
continue continue
sender = msg.get("from") sender = msg.get("from")
@@ -298,7 +295,7 @@ def _confirm(message: str, *, default: bool = True) -> bool | None:
return question.ask() return question.ask()
def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None: def _prompt_token(console: Console) -> tuple[str, User] | None:
while True: while True:
token = questionary.password("paste your bot token:").ask() token = questionary.password("paste your bot token:").ask()
if token is None: if token is None:
@@ -310,11 +307,10 @@ def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None:
console.print(" validating...") console.print(" validating...")
info = anyio.run(get_bot_info, token) info = anyio.run(get_bot_info, token)
if info: if info:
username = info.get("username") if info.username:
if isinstance(username, str) and username: console.print(f" connected to @{info.username}")
console.print(f" connected to @{username}")
else: else:
name = info.get("first_name") or "your bot" name = info.first_name or "your bot"
console.print(f" connected to {name}") console.print(f" connected to {name}")
return token, info return token, info
console.print(" failed to connect, check the token and try again") console.print(" failed to connect, check the token and try again")
@@ -342,7 +338,7 @@ def capture_chat_id(*, token: str | None = None) -> ChatInfo | None:
return None return None
token, info = token_info token, info = token_info
bot_ref = f"@{info['username']}" bot_ref = f"@{info.username}"
console.print("") console.print("")
console.print(f" send /start to {bot_ref} (works in groups too)") console.print(f" send /start to {bot_ref} (works in groups too)")
console.print(" waiting...") console.print(" waiting...")
@@ -401,7 +397,7 @@ def interactive_setup(*, force: bool) -> bool:
if token_info is None: if token_info is None:
return False return False
token, info = token_info token, info = token_info
bot_ref = f"@{info['username']}" bot_ref = f"@{info.username}"
console.print("") console.print("")
console.print(f" send /start to {bot_ref} (works in groups too)") console.print(f" send /start to {bot_ref} (works in groups too)")
+79 -102
View File
@@ -4,9 +4,9 @@ import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, cast
import anyio import anyio
import msgspec
from ..context import RunContext from ..context import RunContext
from ..logging import get_logger from ..logging import get_logger
@@ -27,6 +27,26 @@ class TopicThreadSnapshot:
topic_title: str | None topic_title: str | None
class _ContextState(msgspec.Struct, forbid_unknown_fields=False):
project: str | None = None
branch: str | None = None
class _SessionState(msgspec.Struct, forbid_unknown_fields=False):
resume: str
class _ThreadState(msgspec.Struct, forbid_unknown_fields=False):
context: _ContextState | None = None
sessions: dict[str, _SessionState] = msgspec.field(default_factory=dict)
topic_title: str | None = None
class _TopicState(msgspec.Struct, forbid_unknown_fields=False):
version: int
threads: dict[str, _ThreadState] = msgspec.field(default_factory=dict)
def resolve_state_path(config_path: Path) -> Path: def resolve_state_path(config_path: Path) -> Path:
return config_path.with_name(STATE_FILENAME) return config_path.with_name(STATE_FILENAME)
@@ -35,34 +55,31 @@ def _thread_key(chat_id: int, thread_id: int) -> str:
return f"{chat_id}:{thread_id}" return f"{chat_id}:{thread_id}"
def _parse_context(raw: object) -> RunContext | None: def _normalize_text(value: str | None) -> str | None:
if not isinstance(raw, dict): if value is None:
return None return None
payload = cast(dict[str, object], raw) value = value.strip()
project = payload.get("project") return value or None
branch = payload.get("branch")
if project is not None and not isinstance(project, str):
project = None def _context_from_state(state: _ContextState | None) -> RunContext | None:
if isinstance(project, str): if state is None:
project = project.strip() or None return None
if branch is not None and not isinstance(branch, str): project = _normalize_text(state.project)
branch = None branch = _normalize_text(state.branch)
if isinstance(branch, str):
branch = branch.strip() or None
if project is None and branch is None: if project is None and branch is None:
return None return None
return RunContext(project=project, branch=branch) return RunContext(project=project, branch=branch)
def _dump_context(context: RunContext | None) -> dict[str, str] | None: def _context_to_state(context: RunContext | None) -> _ContextState | None:
if context is None or (context.project is None and context.branch is None): if context is None:
return None return None
payload: dict[str, str] = {} project = _normalize_text(context.project)
if context.project is not None: branch = _normalize_text(context.branch)
payload["project"] = context.project if project is None and branch is None:
if context.branch is not None: return None
payload["branch"] = context.branch return _ContextState(project=project, branch=branch)
return payload or None
class TopicStateStore: class TopicStateStore:
@@ -71,10 +88,7 @@ class TopicStateStore:
self._lock = anyio.Lock() self._lock = anyio.Lock()
self._loaded = False self._loaded = False
self._mtime_ns: int | None = None self._mtime_ns: int | None = None
self._data: dict[str, Any] = { self._state = _TopicState(version=STATE_VERSION, threads={})
"version": STATE_VERSION,
"threads": {},
}
async def get_thread( async def get_thread(
self, chat_id: int, thread_id: int self, chat_id: int, thread_id: int
@@ -92,7 +106,7 @@ class TopicStateStore:
thread = self._get_thread_locked(chat_id, thread_id) thread = self._get_thread_locked(chat_id, thread_id)
if thread is None: if thread is None:
return None return None
return _parse_context(thread.get("context")) return _context_from_state(thread.context)
async def set_context( async def set_context(
self, self,
@@ -105,9 +119,9 @@ class TopicStateStore:
async with self._lock: async with self._lock:
self._reload_locked_if_needed() self._reload_locked_if_needed()
thread = self._ensure_thread_locked(chat_id, thread_id) thread = self._ensure_thread_locked(chat_id, thread_id)
thread["context"] = _dump_context(context) thread.context = _context_to_state(context)
if topic_title is not None: if topic_title is not None:
thread["topic_title"] = topic_title thread.topic_title = topic_title
self._save_locked() self._save_locked()
async def clear_context(self, chat_id: int, thread_id: int) -> None: async def clear_context(self, chat_id: int, thread_id: int) -> None:
@@ -116,7 +130,7 @@ class TopicStateStore:
thread = self._get_thread_locked(chat_id, thread_id) thread = self._get_thread_locked(chat_id, thread_id)
if thread is None: if thread is None:
return return
thread.pop("context", None) thread.context = None
self._save_locked() self._save_locked()
async def get_session_resume( async def get_session_resume(
@@ -127,16 +141,10 @@ class TopicStateStore:
thread = self._get_thread_locked(chat_id, thread_id) thread = self._get_thread_locked(chat_id, thread_id)
if thread is None: if thread is None:
return None return None
sessions = thread.get("sessions") entry = thread.sessions.get(engine)
if not isinstance(sessions, dict): if entry is None or not entry.resume:
return None return None
entry = sessions.get(engine) return ResumeToken(engine=engine, value=entry.resume)
if not isinstance(entry, dict):
return None
value = entry.get("resume")
if not isinstance(value, str) or not value:
return None
return ResumeToken(engine=engine, value=value)
async def set_session_resume( async def set_session_resume(
self, chat_id: int, thread_id: int, token: ResumeToken self, chat_id: int, thread_id: int, token: ResumeToken
@@ -144,13 +152,7 @@ class TopicStateStore:
async with self._lock: async with self._lock:
self._reload_locked_if_needed() self._reload_locked_if_needed()
thread = self._ensure_thread_locked(chat_id, thread_id) thread = self._ensure_thread_locked(chat_id, thread_id)
sessions = thread.get("sessions") thread.sessions[token.engine] = _SessionState(resume=token.value)
if not isinstance(sessions, dict):
sessions = {}
thread["sessions"] = sessions
sessions[token.engine] = {
"resume": token.value,
}
self._save_locked() self._save_locked()
async def clear_sessions(self, chat_id: int, thread_id: int) -> None: async def clear_sessions(self, chat_id: int, thread_id: int) -> None:
@@ -159,7 +161,7 @@ class TopicStateStore:
thread = self._get_thread_locked(chat_id, thread_id) thread = self._get_thread_locked(chat_id, thread_id)
if thread is None: if thread is None:
return return
thread.pop("sessions", None) thread.sessions = {}
self._save_locked() self._save_locked()
async def find_thread_for_context( async def find_thread_for_context(
@@ -167,47 +169,37 @@ class TopicStateStore:
) -> int | None: ) -> int | None:
async with self._lock: async with self._lock:
self._reload_locked_if_needed() self._reload_locked_if_needed()
threads = self._data.get("threads") target_project = _normalize_text(context.project)
if not isinstance(threads, dict): target_branch = _normalize_text(context.branch)
return None for raw_key, thread in self._state.threads.items():
for raw_key, payload in threads.items(): if not raw_key.startswith(f"{chat_id}:"):
if not isinstance(raw_key, str) or not isinstance(payload, dict):
continue continue
parsed = _parse_context(payload.get("context")) parsed = _context_from_state(thread.context)
if parsed is None: if parsed is None:
continue continue
if parsed.project != context.project or parsed.branch != context.branch: if parsed.project != target_project or parsed.branch != target_branch:
continue
if not raw_key.startswith(f"{chat_id}:"):
continue continue
try: try:
_, thread_str = raw_key.split(":", 1) _, thread_str = raw_key.split(":", 1)
return int(thread_str) return int(thread_str)
except (ValueError, TypeError): except ValueError:
continue continue
return None return None
def _snapshot_locked( def _snapshot_locked(
self, thread: dict[str, Any], chat_id: int, thread_id: int self, thread: _ThreadState, chat_id: int, thread_id: int
) -> TopicThreadSnapshot: ) -> TopicThreadSnapshot:
sessions: dict[str, str] = {} sessions = {
raw_sessions = thread.get("sessions") engine: entry.resume
if isinstance(raw_sessions, dict): for engine, entry in thread.sessions.items()
for engine, entry in raw_sessions.items(): if entry.resume
if not isinstance(engine, str) or not isinstance(entry, dict): }
continue
value = entry.get("resume")
if isinstance(value, str) and value:
sessions[engine] = value
topic_title = thread.get("topic_title")
if not isinstance(topic_title, str):
topic_title = None
return TopicThreadSnapshot( return TopicThreadSnapshot(
chat_id=chat_id, chat_id=chat_id,
thread_id=thread_id, thread_id=thread_id,
context=_parse_context(thread.get("context")), context=_context_from_state(thread.context),
sessions=sessions, sessions=sessions,
topic_title=topic_title, topic_title=thread.topic_title,
) )
def _stat_mtime_ns(self) -> int | None: def _stat_mtime_ns(self) -> int | None:
@@ -226,10 +218,10 @@ class TopicStateStore:
self._loaded = True self._loaded = True
self._mtime_ns = self._stat_mtime_ns() self._mtime_ns = self._stat_mtime_ns()
if self._mtime_ns is None: if self._mtime_ns is None:
self._data = {"version": STATE_VERSION, "threads": {}} self._state = _TopicState(version=STATE_VERSION, threads={})
return return
try: try:
payload = json.loads(self._path.read_text(encoding="utf-8")) payload = msgspec.json.decode(self._path.read_bytes(), type=_TopicState)
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(
"telegram.topic_state.load_failed", "telegram.topic_state.load_failed",
@@ -237,29 +229,22 @@ class TopicStateStore:
error=str(exc), error=str(exc),
error_type=exc.__class__.__name__, error_type=exc.__class__.__name__,
) )
self._data = {"version": STATE_VERSION, "threads": {}} self._state = _TopicState(version=STATE_VERSION, threads={})
return return
if not isinstance(payload, dict): if payload.version != STATE_VERSION:
self._data = {"version": STATE_VERSION, "threads": {}}
return
version = payload.get("version")
if version != STATE_VERSION:
logger.warning( logger.warning(
"telegram.topic_state.version_mismatch", "telegram.topic_state.version_mismatch",
path=str(self._path), path=str(self._path),
version=version, version=payload.version,
expected=STATE_VERSION, expected=STATE_VERSION,
) )
self._data = {"version": STATE_VERSION, "threads": {}} self._state = _TopicState(version=STATE_VERSION, threads={})
return return
threads = payload.get("threads") self._state = payload
if not isinstance(threads, dict):
threads = {}
self._data = {"version": STATE_VERSION, "threads": threads}
def _save_locked(self) -> None: def _save_locked(self) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True) self._path.parent.mkdir(parents=True, exist_ok=True)
payload = {"version": STATE_VERSION, "threads": self._data.get("threads", {})} payload = msgspec.to_builtins(self._state)
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp") tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
with open(tmp_path, "w", encoding="utf-8") as handle: with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2, sort_keys=True) json.dump(payload, handle, indent=2, sort_keys=True)
@@ -267,22 +252,14 @@ class TopicStateStore:
os.replace(tmp_path, self._path) os.replace(tmp_path, self._path)
self._mtime_ns = self._stat_mtime_ns() self._mtime_ns = self._stat_mtime_ns()
def _get_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any] | None: def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None:
threads = self._data.get("threads") return self._state.threads.get(_thread_key(chat_id, thread_id))
if not isinstance(threads, dict):
return None
entry = threads.get(_thread_key(chat_id, thread_id))
return entry if isinstance(entry, dict) else None
def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any]: def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState:
threads = self._data.get("threads")
if not isinstance(threads, dict):
threads = {}
self._data["threads"] = threads
key = _thread_key(chat_id, thread_id) key = _thread_key(chat_id, thread_id)
entry = threads.get(key) entry = self._state.threads.get(key)
if isinstance(entry, dict): if entry is not None:
return entry return entry
entry = {} entry = _ThreadState()
threads[key] = entry self._state.threads[key] = entry
return entry return entry
+9 -12
View File
@@ -185,9 +185,9 @@ async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None:
if not cfg.topics.enabled: if not cfg.topics.enabled:
return return
me = await cfg.bot.get_me() me = await cfg.bot.get_me()
bot_id = me.get("id") if isinstance(me, dict) else None if me is None:
if not isinstance(bot_id, int):
raise ConfigError("failed to fetch bot id for topics validation.") raise ConfigError("failed to fetch bot id for topics validation.")
bot_id = me.id
scope, chat_ids = _resolve_topics_scope(cfg) scope, chat_ids = _resolve_topics_scope(cfg)
if scope == "projects" and not chat_ids: if scope == "projects" and not chat_ids:
raise ConfigError( raise ConfigError(
@@ -197,37 +197,34 @@ async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None:
for chat_id in chat_ids: for chat_id in chat_ids:
chat = await cfg.bot.get_chat(chat_id) chat = await cfg.bot.get_chat(chat_id)
if not isinstance(chat, dict): if chat is None:
raise ConfigError( raise ConfigError(
f"failed to fetch chat info for topics validation ({chat_id})." f"failed to fetch chat info for topics validation ({chat_id})."
) )
chat_type = chat.get("type") if chat.type != "supergroup":
is_forum = chat.get("is_forum")
if chat_type != "supergroup":
raise ConfigError( raise ConfigError(
"topics enabled but chat is not a supergroup " "topics enabled but chat is not a supergroup "
f"(chat_id={chat_id}); convert the group and enable topics." f"(chat_id={chat_id}); convert the group and enable topics."
) )
if is_forum is not True: if chat.is_forum is not True:
raise ConfigError( raise ConfigError(
"topics enabled but chat does not have topics enabled " "topics enabled but chat does not have topics enabled "
f"(chat_id={chat_id}); turn on topics in group settings." f"(chat_id={chat_id}); turn on topics in group settings."
) )
member = await cfg.bot.get_chat_member(chat_id, bot_id) member = await cfg.bot.get_chat_member(chat_id, bot_id)
if not isinstance(member, dict): if member is None:
raise ConfigError( raise ConfigError(
"failed to fetch bot permissions " "failed to fetch bot permissions "
f"(chat_id={chat_id}); promote the bot to admin with manage topics." f"(chat_id={chat_id}); promote the bot to admin with manage topics."
) )
status = member.get("status") if member.status == "creator":
if status == "creator":
continue continue
if status != "administrator": if member.status != "administrator":
raise ConfigError( raise ConfigError(
"topics enabled but bot is not an admin " "topics enabled but bot is not an admin "
f"(chat_id={chat_id}); promote it and grant manage topics." f"(chat_id={chat_id}); promote it and grant manage topics."
) )
if member.get("can_manage_topics") is not True: if member.can_manage_topics is not True:
raise ConfigError( raise ConfigError(
"topics enabled but bot lacks manage topics permission " "topics enabled but bot lacks manage topics permission "
f"(chat_id={chat_id}); grant can_manage_topics." f"(chat_id={chat_id}); grant can_manage_topics."
+19 -4
View File
@@ -2,7 +2,6 @@ from __future__ import annotations
import io import io
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import cast
from ..logging import get_logger from ..logging import get_logger
from openai import AsyncOpenAI, OpenAIError from openai import AsyncOpenAI, OpenAIError
@@ -29,6 +28,7 @@ async def transcribe_voice(
bot: BotClient, bot: BotClient,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
enabled: bool, enabled: bool,
max_bytes: int | None = None,
reply: Callable[..., Awaitable[None]], reply: Callable[..., Awaitable[None]],
) -> str | None: ) -> str | None:
voice = msg.voice voice = msg.voice
@@ -37,9 +37,24 @@ async def transcribe_voice(
if not enabled: if not enabled:
await reply(text=VOICE_TRANSCRIPTION_DISABLED_HINT) await reply(text=VOICE_TRANSCRIPTION_DISABLED_HINT)
return None return None
file_info = cast(dict[str, object], await bot.get_file(voice.file_id)) if (
file_path = cast(str, file_info["file_path"]) max_bytes is not None
audio_bytes = cast(bytes, await bot.download_file(file_path)) and voice.file_size is not None
and voice.file_size > max_bytes
):
await reply(text="voice message is too large to transcribe.")
return None
file_info = await bot.get_file(voice.file_id)
if file_info is None:
await reply(text="failed to fetch voice file.")
return None
audio_bytes = await bot.download_file(file_info.file_path)
if audio_bytes is None:
await reply(text="failed to download voice file.")
return None
if max_bytes is not None and len(audio_bytes) > max_bytes:
await reply(text="voice message is too large to transcribe.")
return None
audio_file = io.BytesIO(audio_bytes) audio_file = io.BytesIO(audio_bytes)
audio_file.name = "voice.ogg" audio_file.name = "voice.ogg"
async with AsyncOpenAI(timeout=120) as client: async with AsyncOpenAI(timeout=120) as client:
+6 -3
View File
@@ -10,7 +10,7 @@ from .context import RunContext
from .directives import format_context_line, parse_context_line, parse_directives from .directives import format_context_line, parse_context_line, parse_directives
from .model import EngineId, ResumeToken from .model import EngineId, ResumeToken
from .plugins import normalize_allowlist from .plugins import normalize_allowlist
from .router import AutoRouter from .router import AutoRouter, EngineStatus
from .runner import Runner from .runner import Runner
from .worktrees import WorktreeError, resolve_run_cwd from .worktrees import WorktreeError, resolve_run_cwd
@@ -108,11 +108,14 @@ class TransportRuntime:
def available_engine_ids(self) -> tuple[EngineId, ...]: def available_engine_ids(self) -> tuple[EngineId, ...]:
return tuple(entry.engine for entry in self._router.available_entries) return tuple(entry.engine for entry in self._router.available_entries)
def missing_engine_ids(self) -> tuple[EngineId, ...]: def engine_ids_with_status(self, status: EngineStatus) -> tuple[EngineId, ...]:
return tuple( return tuple(
entry.engine for entry in self._router.entries if not entry.available entry.engine for entry in self._router.entries if entry.status == status
) )
def missing_engine_ids(self) -> tuple[EngineId, ...]:
return self.engine_ids_with_status("missing_cli")
def project_aliases(self) -> tuple[str, ...]: def project_aliases(self) -> tuple[str, ...]:
return tuple(project.alias for project in self._projects.projects.values()) return tuple(project.alias for project in self._projects.projects.values())
+2 -2
View File
@@ -41,13 +41,13 @@ class TransportBackend(Protocol):
def interactive_setup(self, *, force: bool) -> bool: ... def interactive_setup(self, *, force: bool) -> bool: ...
def lock_token( def lock_token(
self, *, transport_config: dict[str, object], config_path: Path self, *, transport_config: object, config_path: Path
) -> str | None: ... ) -> str | None: ...
def build_and_run( def build_and_run(
self, self,
*, *,
transport_config: dict[str, object], transport_config: object,
config_path: Path, config_path: Path,
runtime: TransportRuntime, runtime: TransportRuntime,
final_notify: bool, final_notify: bool,
+10 -7
View File
@@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
from takopi.backends import EngineBackend
from takopi.config import dump_toml from takopi.config import dump_toml
from takopi.telegram import onboarding from takopi.telegram import onboarding
from takopi.backends import EngineBackend from takopi.telegram.api_models import User
def test_mask_token_short() -> None: def test_mask_token_short() -> None:
@@ -91,7 +92,7 @@ def test_interactive_setup_writes_config(monkeypatch, tmp_path) -> None:
def _fake_run(func, *args, **kwargs): def _fake_run(func, *args, **kwargs):
if func is onboarding.get_bot_info: if func is onboarding.get_bot_info:
return {"username": "my_bot"} return User(id=1, username="my_bot")
if func is onboarding.wait_for_chat: if func is onboarding.wait_for_chat:
return onboarding.ChatInfo( return onboarding.ChatInfo(
chat_id=123, chat_id=123,
@@ -136,7 +137,7 @@ def test_interactive_setup_preserves_projects(monkeypatch, tmp_path) -> None:
def _fake_run(func, *args, **kwargs): def _fake_run(func, *args, **kwargs):
if func is onboarding.get_bot_info: if func is onboarding.get_bot_info:
return {"username": "my_bot"} return User(id=1, username="my_bot")
if func is onboarding.wait_for_chat: if func is onboarding.wait_for_chat:
return onboarding.ChatInfo( return onboarding.ChatInfo(
chat_id=123, chat_id=123,
@@ -173,7 +174,7 @@ def test_interactive_setup_no_agents_aborts(monkeypatch, tmp_path) -> None:
def _fake_run(func, *args, **kwargs): def _fake_run(func, *args, **kwargs):
if func is onboarding.get_bot_info: if func is onboarding.get_bot_info:
return {"username": "my_bot"} return User(id=1, username="my_bot")
if func is onboarding.wait_for_chat: if func is onboarding.wait_for_chat:
return onboarding.ChatInfo( return onboarding.ChatInfo(
chat_id=123, chat_id=123,
@@ -211,7 +212,7 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) -
def _fake_run(func, *args, **kwargs): def _fake_run(func, *args, **kwargs):
if func is onboarding.get_bot_info: if func is onboarding.get_bot_info:
return {"username": "my_bot"} return User(id=1, username="my_bot")
if func is onboarding.wait_for_chat: if func is onboarding.wait_for_chat:
return onboarding.ChatInfo( return onboarding.ChatInfo(
chat_id=123, chat_id=123,
@@ -239,7 +240,7 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) -
def test_capture_chat_id_with_token(monkeypatch) -> None: def test_capture_chat_id_with_token(monkeypatch) -> None:
def _fake_run(func, *args, **kwargs): def _fake_run(func, *args, **kwargs):
if func is onboarding.get_bot_info: if func is onboarding.get_bot_info:
return {"username": "my_bot"} return User(id=1, username="my_bot")
if func is onboarding.wait_for_chat: if func is onboarding.wait_for_chat:
return onboarding.ChatInfo( return onboarding.ChatInfo(
chat_id=456, chat_id=456,
@@ -261,7 +262,9 @@ def test_capture_chat_id_with_token(monkeypatch) -> None:
def test_capture_chat_id_prompts_for_token(monkeypatch) -> None: def test_capture_chat_id_prompts_for_token(monkeypatch) -> None:
monkeypatch.setattr( monkeypatch.setattr(
onboarding, "_prompt_token", lambda _console: ("token", {"username": "bot"}) onboarding,
"_prompt_token",
lambda _console: ("token", User(id=1, username="bot")),
) )
def _fake_run(func, *args, **kwargs): def _fake_run(func, *args, **kwargs):
+64 -14
View File
@@ -9,6 +9,11 @@ from takopi.config import ProjectsConfig
from takopi.model import EngineId from takopi.model import EngineId
from takopi.router import AutoRouter, RunnerEntry from takopi.router import AutoRouter, RunnerEntry
from takopi.runners.mock import Return, ScriptRunner from takopi.runners.mock import Return, ScriptRunner
from takopi.settings import (
TelegramFilesSettings,
TelegramTopicsSettings,
TelegramTransportSettings,
)
from takopi.telegram import backend as telegram_backend from takopi.telegram import backend as telegram_backend
from takopi.transport_runtime import TransportRuntime from takopi.transport_runtime import TransportRuntime
@@ -20,8 +25,13 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
missing = ScriptRunner([Return(answer="ok")], engine=pi) missing = ScriptRunner([Return(answer="ok")], engine=pi)
router = AutoRouter( router = AutoRouter(
entries=[ entries=[
RunnerEntry(engine=codex, runner=runner, available=True), RunnerEntry(engine=codex, runner=runner),
RunnerEntry(engine=pi, runner=missing, available=False, issue="missing"), RunnerEntry(
engine=pi,
runner=missing,
status="missing_cli",
issue="missing",
),
], ],
default_engine=codex, default_engine=codex,
) )
@@ -40,6 +50,44 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
assert "projects: `none`" in message assert "projects: `none`" in message
def test_build_startup_message_surfaces_unavailable_engine_reasons(
tmp_path: Path,
) -> None:
codex = EngineId("codex")
pi = EngineId("pi")
claude = EngineId("claude")
runner = ScriptRunner([Return(answer="ok")], engine=codex)
bad_cfg = ScriptRunner([Return(answer="ok")], engine=pi)
load_err = ScriptRunner([Return(answer="ok")], engine=claude)
router = AutoRouter(
entries=[
RunnerEntry(engine=codex, runner=runner),
RunnerEntry(engine=pi, runner=bad_cfg, status="bad_config", issue="bad"),
RunnerEntry(
engine=claude,
runner=load_err,
status="load_error",
issue="failed",
),
],
default_engine=codex,
)
runtime = TransportRuntime(
router=router,
projects=ProjectsConfig(projects={}, default_project=None),
watch_config=True,
)
message = telegram_backend._build_startup_message(
runtime, startup_pwd=str(tmp_path)
)
assert "agents: `codex" in message
assert "misconfigured: pi" in message
assert "failed to load: claude" in message
def test_telegram_backend_build_and_run_wires_config( def test_telegram_backend_build_and_run_wires_config(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None: ) -> None:
@@ -55,7 +103,7 @@ def test_telegram_backend_build_and_run_wires_config(
codex = EngineId("codex") codex = EngineId("codex")
runner = ScriptRunner([Return(answer="ok")], engine=codex) runner = ScriptRunner([Return(answer="ok")], engine=codex)
router = AutoRouter( router = AutoRouter(
entries=[RunnerEntry(engine=codex, runner=runner, available=True)], entries=[RunnerEntry(engine=codex, runner=runner)],
default_engine=codex, default_engine=codex,
) )
runtime = TransportRuntime( runtime = TransportRuntime(
@@ -80,13 +128,14 @@ def test_telegram_backend_build_and_run_wires_config(
monkeypatch.setattr(telegram_backend, "run_main_loop", fake_run_main_loop) monkeypatch.setattr(telegram_backend, "run_main_loop", fake_run_main_loop)
monkeypatch.setattr(telegram_backend, "TelegramClient", _FakeClient) monkeypatch.setattr(telegram_backend, "TelegramClient", _FakeClient)
transport_config = { transport_config = TelegramTransportSettings(
"bot_token": "token", bot_token="token",
"chat_id": 321, chat_id=321,
"voice_transcription": True, voice_transcription=True,
"files": {"enabled": True, "allowed_user_ids": [1, 2]}, voice_max_bytes=1234,
"topics": {"enabled": True, "scope": "main"}, files=TelegramFilesSettings(enabled=True, allowed_user_ids=[1, 2]),
} topics=TelegramTopicsSettings(enabled=True, scope="main"),
)
telegram_backend.TelegramBackend().build_and_run( telegram_backend.TelegramBackend().build_and_run(
transport_config=transport_config, transport_config=transport_config,
@@ -100,18 +149,19 @@ def test_telegram_backend_build_and_run_wires_config(
kwargs = captured["kwargs"] kwargs = captured["kwargs"]
assert cfg.chat_id == 321 assert cfg.chat_id == 321
assert cfg.voice_transcription is True assert cfg.voice_transcription is True
assert cfg.voice_max_bytes == 1234
assert cfg.files.enabled is True assert cfg.files.enabled is True
assert cfg.files.allowed_user_ids == frozenset({1, 2}) assert cfg.files.allowed_user_ids == [1, 2]
assert cfg.topics.enabled is True assert cfg.topics.enabled is True
assert cfg.bot.token == "token" assert cfg.bot.token == "token"
assert kwargs["watch_config"] is True assert kwargs["watch_config"] is True
assert kwargs["transport_id"] == "telegram" assert kwargs["transport_id"] == "telegram"
def test_build_files_config_defaults() -> None: def test_telegram_files_settings_defaults() -> None:
cfg = telegram_backend._build_files_config({}) cfg = TelegramFilesSettings()
assert cfg.enabled is False assert cfg.enabled is False
assert cfg.auto_put is True assert cfg.auto_put is True
assert cfg.uploads_dir == "incoming" assert cfg.uploads_dir == "incoming"
assert cfg.allowed_user_ids == frozenset() assert cfg.allowed_user_ids == []
+44 -40
View File
@@ -6,22 +6,30 @@ import anyio
import pytest import pytest
from takopi import commands, plugins from takopi import commands, plugins
import takopi.telegram.bridge as bridge
import takopi.telegram.loop as telegram_loop import takopi.telegram.loop as telegram_loop
import takopi.telegram.commands as telegram_commands import takopi.telegram.commands as telegram_commands
import takopi.telegram.topics as telegram_topics import takopi.telegram.topics as telegram_topics
from takopi.directives import parse_directives from takopi.directives import parse_directives
from takopi.telegram.api_models import (
Chat,
ChatMember,
File,
ForumTopic,
Message,
Update,
User,
)
from takopi.settings import TelegramFilesSettings, TelegramTopicsSettings
from takopi.telegram.bridge import ( from takopi.telegram.bridge import (
TelegramBridgeConfig, TelegramBridgeConfig,
TelegramFilesConfig,
TelegramPresenter, TelegramPresenter,
TelegramTransport, TelegramTransport,
build_bot_commands, build_bot_commands,
handle_callback_cancel, handle_callback_cancel,
handle_cancel, handle_cancel,
is_cancel_command, is_cancel_command,
send_with_resume,
run_main_loop, run_main_loop,
send_with_resume,
) )
from takopi.telegram.client import BotClient from takopi.telegram.client import BotClient
from takopi.telegram.topic_state import TopicStateStore, resolve_state_path from takopi.telegram.topic_state import TopicStateStore, resolve_state_path
@@ -122,13 +130,13 @@ class _FakeBot(BotClient):
offset: int | None, offset: int | None,
timeout_s: int = 50, timeout_s: int = 50,
allowed_updates: list[str] | None = None, allowed_updates: list[str] | None = None,
) -> list[dict[str, Any]] | None: ) -> list[Update] | None:
_ = offset _ = offset
_ = timeout_s _ = timeout_s
_ = allowed_updates _ = allowed_updates
return [] return []
async def get_file(self, file_id: str) -> dict[str, Any] | None: async def get_file(self, file_id: str) -> File | None:
_ = file_id _ = file_id
return None return None
@@ -148,7 +156,7 @@ class _FakeBot(BotClient):
reply_markup: dict | None = None, reply_markup: dict | None = None,
*, *,
replace_message_id: int | None = None, replace_message_id: int | None = None,
) -> dict[str, Any]: ) -> Message:
self.send_calls.append( self.send_calls.append(
{ {
"chat_id": chat_id, "chat_id": chat_id,
@@ -162,7 +170,7 @@ class _FakeBot(BotClient):
"replace_message_id": replace_message_id, "replace_message_id": replace_message_id,
} }
) )
return {"message_id": 1} return Message(message_id=1)
async def send_document( async def send_document(
self, self,
@@ -173,7 +181,7 @@ class _FakeBot(BotClient):
message_thread_id: int | None = None, message_thread_id: int | None = None,
disable_notification: bool | None = False, disable_notification: bool | None = False,
caption: str | None = None, caption: str | None = None,
) -> dict[str, Any]: ) -> Message:
self.document_calls.append( self.document_calls.append(
{ {
"chat_id": chat_id, "chat_id": chat_id,
@@ -185,7 +193,7 @@ class _FakeBot(BotClient):
"caption": caption, "caption": caption,
} }
) )
return {"message_id": 2} return Message(message_id=2)
async def edit_message_text( async def edit_message_text(
self, self,
@@ -197,7 +205,7 @@ class _FakeBot(BotClient):
reply_markup: dict | None = None, reply_markup: dict | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict[str, Any]: ) -> Message:
self.edit_calls.append( self.edit_calls.append(
{ {
"chat_id": chat_id, "chat_id": chat_id,
@@ -209,7 +217,7 @@ class _FakeBot(BotClient):
"wait": wait, "wait": wait,
} }
) )
return {"message_id": message_id} return Message(message_id=message_id)
async def delete_message(self, chat_id: int, message_id: int) -> bool: async def delete_message(self, chat_id: int, message_id: int) -> bool:
self.delete_calls.append({"chat_id": chat_id, "message_id": message_id}) self.delete_calls.append({"chat_id": chat_id, "message_id": message_id})
@@ -231,26 +239,22 @@ class _FakeBot(BotClient):
) )
return True return True
async def get_me(self) -> dict[str, Any] | None: async def get_me(self) -> User | None:
return {"id": 1} return User(id=1, username="bot")
async def get_chat(self, chat_id: int) -> dict[str, Any] | None: async def get_chat(self, chat_id: int) -> Chat | None:
_ = chat_id _ = chat_id
return {"id": chat_id, "type": "supergroup", "is_forum": True} return Chat(id=chat_id, type="supergroup", is_forum=True)
async def get_chat_member( async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
self, chat_id: int, user_id: int
) -> dict[str, Any] | None:
_ = chat_id _ = chat_id
_ = user_id _ = user_id
return {"status": "administrator", "can_manage_topics": True} return ChatMember(status="administrator", can_manage_topics=True)
async def create_forum_topic( async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
self, chat_id: int, name: str
) -> dict[str, Any] | None:
_ = chat_id _ = chat_id
_ = name _ = name
return {"message_thread_id": 1} return ForumTopic(message_thread_id=1)
async def edit_forum_topic( async def edit_forum_topic(
self, chat_id: int, message_thread_id: int, name: str self, chat_id: int, message_thread_id: int, name: str
@@ -539,10 +543,10 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
offset: int | None, offset: int | None,
timeout_s: int = 50, timeout_s: int = 50,
allowed_updates: list[str] | None = None, allowed_updates: list[str] | None = None,
) -> list[dict[str, Any]] | None: ) -> list[Update] | None:
return None return None
async def get_file(self, file_id: str) -> dict[str, Any] | None: async def get_file(self, file_id: str) -> File | None:
_ = file_id _ = file_id
return None return None
@@ -562,7 +566,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
reply_markup: dict | None = None, reply_markup: dict | None = None,
*, *,
replace_message_id: int | None = None, replace_message_id: int | None = None,
) -> dict | None: ) -> Message | None:
_ = reply_markup _ = reply_markup
return None return None
@@ -575,7 +579,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
message_thread_id: int | None = None, message_thread_id: int | None = None,
disable_notification: bool | None = False, disable_notification: bool | None = False,
caption: str | None = None, caption: str | None = None,
) -> dict | None: ) -> Message | None:
_ = ( _ = (
chat_id, chat_id,
filename, filename,
@@ -597,7 +601,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
reply_markup: dict | None = None, reply_markup: dict | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict | None: ) -> Message | None:
self.edit_calls.append( self.edit_calls.append(
{ {
"chat_id": chat_id, "chat_id": chat_id,
@@ -611,7 +615,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
) )
if not wait: if not wait:
return None return None
return {"message_id": message_id} return Message(message_id=message_id)
async def delete_message( async def delete_message(
self, self,
@@ -629,7 +633,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
) -> bool: ) -> bool:
return False return False
async def get_me(self) -> dict[str, Any] | None: async def get_me(self) -> User | None:
return None return None
async def close(self) -> None: async def close(self) -> None:
@@ -778,9 +782,9 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
payload = b"hello" payload = b"hello"
class _FileBot(_FakeBot): class _FileBot(_FakeBot):
async def get_file(self, file_id: str) -> dict[str, Any] | None: async def get_file(self, file_id: str) -> File | None:
_ = file_id _ = file_id
return {"file_path": "files/hello.txt"} return File(file_path="files/hello.txt")
async def download_file(self, file_path: str) -> bytes | None: async def download_file(self, file_path: str) -> bytes | None:
_ = file_path _ = file_path
@@ -811,7 +815,7 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
chat_id=123, chat_id=123,
startup_msg="", startup_msg="",
exec_cfg=exec_cfg, exec_cfg=exec_cfg,
files=TelegramFilesConfig(enabled=True), files=TelegramFilesSettings(enabled=True),
) )
msg = TelegramIncomingMessage( msg = TelegramIncomingMessage(
transport="telegram", transport="telegram",
@@ -876,9 +880,9 @@ async def test_handle_file_get_sends_document_for_allowed_user(
chat_id=123, chat_id=123,
startup_msg="", startup_msg="",
exec_cfg=exec_cfg, exec_cfg=exec_cfg,
files=TelegramFilesConfig( files=TelegramFilesSettings(
enabled=True, enabled=True,
allowed_user_ids=frozenset({42}), allowed_user_ids=[42],
), ),
) )
msg = TelegramIncomingMessage( msg = TelegramIncomingMessage(
@@ -1006,7 +1010,7 @@ def test_topic_title_projects_scope_includes_project() -> None:
transport = _FakeTransport() transport = _FakeTransport()
cfg = replace( cfg = replace(
_make_cfg(transport), _make_cfg(transport),
topics=bridge.TelegramTopicsConfig( topics=TelegramTopicsSettings(
enabled=True, enabled=True,
scope="projects", scope="projects",
), ),
@@ -1277,7 +1281,7 @@ async def test_run_main_loop_persists_topic_sessions_in_project_scope(
chat_id=123, chat_id=123,
startup_msg="", startup_msg="",
exec_cfg=exec_cfg, exec_cfg=exec_cfg,
topics=bridge.TelegramTopicsConfig( topics=TelegramTopicsSettings(
enabled=True, enabled=True,
scope="projects", scope="projects",
), ),
@@ -1363,11 +1367,11 @@ async def test_run_main_loop_batches_media_group_upload(
} }
class _MediaBot(_FakeBot): class _MediaBot(_FakeBot):
async def get_file(self, file_id: str) -> dict[str, Any] | None: async def get_file(self, file_id: str) -> File | None:
file_path = file_map.get(file_id) file_path = file_map.get(file_id)
if file_path is None: if file_path is None:
return None return None
return {"file_path": file_path} return File(file_path=file_path)
async def download_file(self, file_path: str) -> bytes | None: async def download_file(self, file_path: str) -> bytes | None:
return payloads.get(file_path) return payloads.get(file_path)
@@ -1397,7 +1401,7 @@ async def test_run_main_loop_batches_media_group_upload(
chat_id=123, chat_id=123,
startup_msg="", startup_msg="",
exec_cfg=exec_cfg, exec_cfg=exec_cfg,
files=TelegramFilesConfig(enabled=True, auto_put=True), files=TelegramFilesSettings(enabled=True, auto_put=True),
) )
msg1 = TelegramIncomingMessage( msg1 = TelegramIncomingMessage(
transport="telegram", transport="telegram",
+172
View File
@@ -57,3 +57,175 @@ async def test_no_token_in_logs_on_http_error(
out = capsys.readouterr().out out = capsys.readouterr().out
assert token not in out assert token not in out
assert "bot[REDACTED]" in out assert "bot[REDACTED]" in out
@pytest.mark.anyio
async def test_telegram_429_no_retry_post_form() -> None:
calls: list[int] = []
def handler(request: httpx.Request) -> httpx.Response:
calls.append(1)
return httpx.Response(
429,
json={
"ok": False,
"description": "retry",
"parameters": {"retry_after": 2},
},
request=request,
)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
with pytest.raises(TelegramRetryAfter) as exc:
await tg._post_form(
"sendDocument",
{"chat_id": 1},
files={"document": ("note.txt", b"hi")},
)
finally:
await client.aclose()
assert exc.value.retry_after == 2
assert len(calls) == 1
@pytest.mark.anyio
async def test_telegram_429_defaults_retry_after_on_bad_body() -> None:
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(429, text="nope", request=request)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
with pytest.raises(TelegramRetryAfter) as exc:
await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
finally:
await client.aclose()
assert exc.value.retry_after == 5.0
@pytest.mark.anyio
async def test_telegram_ok_false_returns_none() -> None:
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(
200,
json={"ok": False, "error_code": 400, "description": "bad"},
request=request,
)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
result = await tg._post("getUpdates", {"timeout": 1})
finally:
await client.aclose()
assert result is None
@pytest.mark.anyio
async def test_telegram_invalid_payload_returns_none() -> None:
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json=["not", "a", "dict"], request=request)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
result = await tg._post("getUpdates", {"timeout": 1})
finally:
await client.aclose()
assert result is None
@pytest.mark.anyio
async def test_telegram_decode_failure_returns_none() -> None:
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(
200,
json={"ok": True, "result": {"username": "bot-only"}},
request=request,
)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
result = await tg.get_me()
finally:
await client.aclose()
assert result is None
@pytest.mark.anyio
async def test_telegram_download_file_retries_on_429() -> None:
calls: list[int] = []
sleeps: list[float] = []
async def sleep(delay: float) -> None:
sleeps.append(delay)
def handler(request: httpx.Request) -> httpx.Response:
calls.append(1)
if len(calls) == 1:
return httpx.Response(
429,
json={"ok": False, "parameters": {"retry_after": 3}},
request=request,
)
return httpx.Response(200, content=b"ok", request=request)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client, sleep=sleep)
payload = await tg.download_file("path/to/file")
finally:
await client.aclose()
assert payload == b"ok"
assert sleeps == [3.0]
assert len(calls) == 2
@pytest.mark.anyio
async def test_telegram_download_file_429_defaults_retry_after_on_bad_body() -> None:
sleeps: list[float] = []
async def sleep(delay: float) -> None:
sleeps.append(delay)
calls: list[int] = []
def handler(request: httpx.Request) -> httpx.Response:
calls.append(1)
if len(calls) == 1:
return httpx.Response(429, text="nope", request=request)
return httpx.Response(200, content=b"ok", request=request)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client, sleep=sleep)
payload = await tg.download_file("path")
finally:
await client.aclose()
assert payload == b"ok"
assert sleeps == [5.0]
assert len(calls) == 2
+14 -12
View File
@@ -3,6 +3,7 @@ from typing import Any
import anyio import anyio
import pytest import pytest
from takopi.telegram.api_models import File, Message, Update, User
from takopi.telegram.client import BotClient, TelegramClient, TelegramRetryAfter from takopi.telegram.client import BotClient, TelegramClient, TelegramRetryAfter
@@ -29,7 +30,7 @@ class _FakeBot(BotClient):
reply_markup: dict | None = None, reply_markup: dict | None = None,
*, *,
replace_message_id: int | None = None, replace_message_id: int | None = None,
) -> dict[str, Any]: ) -> Message | None:
_ = reply_to_message_id _ = reply_to_message_id
_ = disable_notification _ = disable_notification
_ = message_thread_id _ = message_thread_id
@@ -38,7 +39,7 @@ class _FakeBot(BotClient):
_ = reply_markup _ = reply_markup
_ = replace_message_id _ = replace_message_id
self.calls.append("send_message") self.calls.append("send_message")
return {"message_id": 1} return Message(message_id=1)
async def send_document( async def send_document(
self, self,
@@ -49,7 +50,7 @@ class _FakeBot(BotClient):
message_thread_id: int | None = None, message_thread_id: int | None = None,
disable_notification: bool | None = False, disable_notification: bool | None = False,
caption: str | None = None, caption: str | None = None,
) -> dict[str, Any]: ) -> Message | None:
_ = ( _ = (
chat_id, chat_id,
filename, filename,
@@ -60,7 +61,7 @@ class _FakeBot(BotClient):
caption, caption,
) )
self.calls.append("send_document") self.calls.append("send_document")
return {"message_id": 1} return Message(message_id=1)
async def edit_message_text( async def edit_message_text(
self, self,
@@ -72,7 +73,7 @@ class _FakeBot(BotClient):
reply_markup: dict | None = None, reply_markup: dict | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict[str, Any]: ) -> Message | None:
_ = chat_id _ = chat_id
_ = message_id _ = message_id
_ = entities _ = entities
@@ -85,7 +86,7 @@ class _FakeBot(BotClient):
self._edit_attempts += 1 self._edit_attempts += 1
raise TelegramRetryAfter(self.retry_after) raise TelegramRetryAfter(self.retry_after)
self._edit_attempts += 1 self._edit_attempts += 1
return {"message_id": message_id} return Message(message_id=message_id)
async def delete_message( async def delete_message(
self, self,
@@ -113,7 +114,7 @@ class _FakeBot(BotClient):
offset: int | None, offset: int | None,
timeout_s: int = 50, timeout_s: int = 50,
allowed_updates: list[str] | None = None, allowed_updates: list[str] | None = None,
) -> list[dict[str, Any]] | None: ) -> list[Update] | None:
_ = offset _ = offset
_ = timeout_s _ = timeout_s
_ = allowed_updates _ = allowed_updates
@@ -123,7 +124,7 @@ class _FakeBot(BotClient):
self._updates_attempts += 1 self._updates_attempts += 1
return [] return []
async def get_file(self, file_id: str) -> dict[str, Any] | None: async def get_file(self, file_id: str) -> File | None:
_ = file_id _ = file_id
return None return None
@@ -134,8 +135,8 @@ class _FakeBot(BotClient):
async def close(self) -> None: async def close(self) -> None:
return None return None
async def get_me(self) -> dict[str, Any] | None: async def get_me(self) -> User | None:
return {"id": 1} return User(id=1)
async def answer_callback_query( async def answer_callback_query(
self, self,
@@ -187,7 +188,7 @@ async def test_edits_coalesce_latest() -> None:
reply_markup: dict | None = None, reply_markup: dict | None = None,
*, *,
wait: bool = True, wait: bool = True,
) -> dict: ) -> Message | None:
if self._block_first: if self._block_first:
self._block_first = False self._block_first = False
self.edit_started.set() self.edit_started.set()
@@ -305,7 +306,8 @@ async def test_retry_after_retries_once() -> None:
text="retry", text="retry",
) )
assert result == {"message_id": 1} assert result is not None
assert result.message_id == 1
assert bot._edit_attempts == 2 assert bot._edit_attempts == 2
+243
View File
@@ -0,0 +1,243 @@
from __future__ import annotations
import pytest
from takopi.telegram.api_models import (
Chat,
ChatMember,
File,
ForumTopic,
Message,
Update,
User,
)
from takopi.telegram.client import BotClient
from takopi.telegram.types import TelegramIncomingMessage, TelegramVoice
from takopi.telegram.voice import transcribe_voice
class _Bot(BotClient):
def __init__(self, *, file_info: File | None, audio: bytes | None) -> None:
self._file_info = file_info
self._audio = audio
async def close(self) -> None:
return None
async def get_updates(
self,
offset: int | None,
timeout_s: int = 50,
allowed_updates: list[str] | None = None,
) -> list[Update] | None:
_ = offset, timeout_s, allowed_updates
return []
async def get_file(self, file_id: str) -> File | None:
_ = file_id
return self._file_info
async def download_file(self, file_path: str) -> bytes | None:
_ = file_path
return self._audio
async def send_message(
self,
chat_id: int,
text: str,
reply_to_message_id: int | None = None,
disable_notification: bool | None = False,
message_thread_id: int | None = None,
entities: list[dict] | None = None,
parse_mode: str | None = None,
reply_markup: dict | None = None,
*,
replace_message_id: int | None = None,
) -> Message | None:
_ = (
chat_id,
text,
reply_to_message_id,
disable_notification,
message_thread_id,
entities,
parse_mode,
reply_markup,
replace_message_id,
)
raise AssertionError("send_message should not be called")
async def send_document(
self,
chat_id: int,
filename: str,
content: bytes,
reply_to_message_id: int | None = None,
message_thread_id: int | None = None,
disable_notification: bool | None = False,
caption: str | None = None,
) -> Message | None:
_ = (
chat_id,
filename,
content,
reply_to_message_id,
message_thread_id,
disable_notification,
caption,
)
raise AssertionError("send_document should not be called")
async def edit_message_text(
self,
chat_id: int,
message_id: int,
text: str,
entities: list[dict] | None = None,
parse_mode: str | None = None,
reply_markup: dict | None = None,
*,
wait: bool = True,
) -> Message | None:
_ = (
chat_id,
message_id,
text,
entities,
parse_mode,
reply_markup,
wait,
)
raise AssertionError("edit_message_text should not be called")
async def delete_message(self, chat_id: int, message_id: int) -> bool:
_ = chat_id, message_id
raise AssertionError("delete_message should not be called")
async def set_my_commands(
self,
commands: list[dict],
*,
scope: dict | None = None,
language_code: str | None = None,
) -> bool:
_ = commands, scope, language_code
raise AssertionError("set_my_commands should not be called")
async def get_me(self) -> User | None:
raise AssertionError("get_me should not be called")
async def answer_callback_query(
self,
callback_query_id: str,
text: str | None = None,
show_alert: bool | None = None,
) -> bool:
_ = callback_query_id, text, show_alert
raise AssertionError("answer_callback_query should not be called")
async def get_chat(self, chat_id: int) -> Chat | None:
_ = chat_id
raise AssertionError("get_chat should not be called")
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
_ = chat_id, user_id
raise AssertionError("get_chat_member should not be called")
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
_ = chat_id, name
raise AssertionError("create_forum_topic should not be called")
async def edit_forum_topic(
self, chat_id: int, message_thread_id: int, name: str
) -> bool:
_ = chat_id, message_thread_id, name
raise AssertionError("edit_forum_topic should not be called")
def _voice_message(*, file_size: int = 123) -> TelegramIncomingMessage:
voice = TelegramVoice(
file_id="voice-id",
mime_type="audio/ogg",
file_size=file_size,
duration=1,
raw={},
)
return TelegramIncomingMessage(
transport="telegram",
chat_id=1,
message_id=1,
text="",
reply_to_message_id=None,
reply_to_text=None,
sender_id=1,
voice=voice,
raw={},
)
@pytest.mark.anyio
async def test_transcribe_voice_handles_missing_file() -> None:
replies: list[str] = []
async def reply(**kwargs) -> None:
replies.append(kwargs["text"])
bot = _Bot(file_info=None, audio=None)
result = await transcribe_voice(
bot=bot,
msg=_voice_message(),
enabled=True,
reply=reply,
)
assert result is None
assert replies[-1] == "failed to fetch voice file."
@pytest.mark.anyio
async def test_transcribe_voice_handles_missing_download() -> None:
replies: list[str] = []
async def reply(**kwargs) -> None:
replies.append(kwargs["text"])
bot = _Bot(file_info=File(file_path="voice.ogg"), audio=None)
result = await transcribe_voice(
bot=bot,
msg=_voice_message(),
enabled=True,
reply=reply,
)
assert result is None
assert replies[-1] == "failed to download voice file."
@pytest.mark.anyio
async def test_transcribe_voice_rejects_large_voice_without_downloading() -> None:
replies: list[str] = []
async def reply(**kwargs) -> None:
replies.append(kwargs["text"])
class _NoFetchBot(_Bot):
async def get_file(self, file_id: str) -> File | None: # type: ignore[override]
_ = file_id
raise AssertionError("get_file should not be called")
async def download_file(self, file_path: str) -> bytes | None: # type: ignore[override]
_ = file_path
raise AssertionError("download_file should not be called")
bot = _NoFetchBot(file_info=None, audio=None)
result = await transcribe_voice(
bot=bot,
msg=_voice_message(file_size=10_000),
enabled=True,
max_bytes=100,
reply=reply,
)
assert result is None
assert replies[-1] == "voice message is too large to transcribe."
+2 -2
View File
@@ -15,14 +15,14 @@ class DummyTransport:
def interactive_setup(self, *, force: bool) -> bool: def interactive_setup(self, *, force: bool) -> bool:
raise NotImplementedError raise NotImplementedError
def lock_token(self, *, transport_config: dict[str, object], config_path): def lock_token(self, *, transport_config: object, config_path):
_ = transport_config, config_path _ = transport_config, config_path
raise NotImplementedError raise NotImplementedError
def build_and_run( def build_and_run(
self, self,
*, *,
transport_config: dict[str, object], transport_config: object,
config_path, config_path,
runtime, runtime,
final_notify: bool, final_notify: bool,