refactor(telegram): boundary types (#90)
This commit is contained in:
+7
-8
@@ -105,11 +105,7 @@ def _default_engine_for_setup(
|
||||
if settings is None or config_path is None:
|
||||
return "codex"
|
||||
value = settings.default_engine
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
raise ConfigError(
|
||||
f"Invalid `default_engine` in {config_path}; expected a non-empty string."
|
||||
)
|
||||
return value.strip()
|
||||
return value
|
||||
|
||||
|
||||
def _config_path_display(path: Path) -> str:
|
||||
@@ -235,9 +231,12 @@ def _run_auto_router(
|
||||
default_engine_override=default_engine_override,
|
||||
reserved=("cancel",),
|
||||
)
|
||||
transport_config = settings.transport_config(
|
||||
settings.transport, config_path=config_path
|
||||
)
|
||||
if settings.transport == "telegram":
|
||||
transport_config = settings.transports.telegram
|
||||
else:
|
||||
transport_config = settings.transport_config(
|
||||
settings.transport, config_path=config_path
|
||||
)
|
||||
lock_token = transport_backend.lock_token(
|
||||
transport_config=transport_config,
|
||||
config_path=config_path,
|
||||
|
||||
+11
-2
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
from typing import Iterable, Literal, TypeAlias
|
||||
|
||||
from .model import EngineId, ResumeToken
|
||||
from .runner import Runner
|
||||
@@ -17,13 +17,22 @@ class RunnerUnavailableError(RuntimeError):
|
||||
self.issue = issue
|
||||
|
||||
|
||||
EngineStatus: TypeAlias = Literal["ok", "missing_cli", "bad_config", "load_error"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RunnerEntry:
|
||||
engine: EngineId
|
||||
runner: Runner
|
||||
available: bool = True
|
||||
status: EngineStatus = "ok"
|
||||
issue: str | None = None
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
# "bad_config" means we ignored user config and built the runner with defaults.
|
||||
# The engine is still runnable, but a warning should be surfaced to the user.
|
||||
return self.status in {"ok", "bad_config"}
|
||||
|
||||
|
||||
class AutoRouter:
|
||||
def __init__(
|
||||
|
||||
@@ -9,7 +9,7 @@ from .backends import EngineBackend
|
||||
from .config import ConfigError, ProjectsConfig
|
||||
from .engines import get_backend, list_backend_ids
|
||||
from .logging import get_logger
|
||||
from .router import AutoRouter, RunnerEntry
|
||||
from .router import AutoRouter, EngineStatus, RunnerEntry
|
||||
from .settings import TakopiSettings
|
||||
from .transport_runtime import TransportRuntime
|
||||
|
||||
@@ -83,6 +83,7 @@ def build_router(
|
||||
for backend in backends:
|
||||
engine_id = backend.id
|
||||
issue: str | None = None
|
||||
status: EngineStatus = "ok"
|
||||
engine_cfg: dict
|
||||
try:
|
||||
engine_cfg = settings.engine_config(engine_id, config_path=config_path)
|
||||
@@ -90,6 +91,7 @@ def build_router(
|
||||
if engine_id == default_engine:
|
||||
raise
|
||||
issue = str(exc)
|
||||
status = "bad_config"
|
||||
engine_cfg = {}
|
||||
|
||||
try:
|
||||
@@ -104,26 +106,31 @@ def build_router(
|
||||
except Exception as fallback_exc:
|
||||
warnings.append(f"{engine_id}: {issue or str(fallback_exc)}")
|
||||
continue
|
||||
status = "bad_config"
|
||||
else:
|
||||
status = "load_error"
|
||||
warnings.append(f"{engine_id}: {issue}")
|
||||
continue
|
||||
|
||||
cmd = backend.cli_cmd or backend.id
|
||||
if shutil.which(cmd) is None:
|
||||
issue = issue or f"{cmd} not found on PATH"
|
||||
status = "missing_cli"
|
||||
if issue:
|
||||
issue = f"{issue}; {cmd} not found on PATH"
|
||||
else:
|
||||
issue = f"{cmd} not found on PATH"
|
||||
|
||||
if issue and engine_id == default_engine:
|
||||
if status != "ok" and engine_id == default_engine:
|
||||
raise ConfigError(f"Default engine {engine_id!r} unavailable: {issue}")
|
||||
|
||||
available = issue is None
|
||||
if issue and engine_id != default_engine:
|
||||
if status != "ok" and engine_id != default_engine:
|
||||
warnings.append(f"{engine_id}: {issue}")
|
||||
|
||||
entries.append(
|
||||
RunnerEntry(
|
||||
engine=engine_id,
|
||||
runner=runner,
|
||||
available=available,
|
||||
status=status,
|
||||
issue=issue,
|
||||
)
|
||||
)
|
||||
|
||||
+11
-6
@@ -83,6 +83,14 @@ class TelegramFilesSettings(BaseModel):
|
||||
raise ValueError("files.uploads_dir must be a relative path")
|
||||
return value
|
||||
|
||||
@property
|
||||
def max_upload_bytes(self) -> int:
|
||||
return 20 * 1024 * 1024
|
||||
|
||||
@property
|
||||
def max_download_bytes(self) -> int:
|
||||
return 50 * 1024 * 1024
|
||||
|
||||
|
||||
class TelegramTransportSettings(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
|
||||
@@ -90,14 +98,13 @@ class TelegramTransportSettings(BaseModel):
|
||||
bot_token: NonEmptyStr
|
||||
chat_id: StrictInt
|
||||
voice_transcription: bool = False
|
||||
voice_max_bytes: StrictInt = 10 * 1024 * 1024
|
||||
topics: TelegramTopicsSettings = Field(default_factory=TelegramTopicsSettings)
|
||||
files: TelegramFilesSettings = Field(default_factory=TelegramFilesSettings)
|
||||
|
||||
|
||||
class TransportsSettings(BaseModel):
|
||||
telegram: TelegramTransportSettings = Field(
|
||||
default_factory=TelegramTransportSettings
|
||||
)
|
||||
telegram: TelegramTransportSettings
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -132,7 +139,7 @@ class TakopiSettings(BaseSettings):
|
||||
projects: dict[str, ProjectSettings] = Field(default_factory=dict)
|
||||
|
||||
transport: NonEmptyStr = "telegram"
|
||||
transports: TransportsSettings = Field(default_factory=TransportsSettings)
|
||||
transports: TransportsSettings
|
||||
|
||||
plugins: PluginsSettings = Field(default_factory=PluginsSettings)
|
||||
|
||||
@@ -310,8 +317,6 @@ def require_telegram(settings: TakopiSettings, config_path: Path) -> tuple[str,
|
||||
"(telegram only for now)."
|
||||
)
|
||||
tg = settings.transports.telegram
|
||||
if not tg.bot_token:
|
||||
raise ConfigError(f"Missing bot token in {config_path}.")
|
||||
return tg.bot_token, tg.chat_id
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import msgspec
|
||||
|
||||
__all__ = [
|
||||
"Chat",
|
||||
"ChatMember",
|
||||
"File",
|
||||
"ForumTopic",
|
||||
"Message",
|
||||
"Update",
|
||||
"User",
|
||||
]
|
||||
|
||||
|
||||
class User(msgspec.Struct, forbid_unknown_fields=False):
|
||||
id: int
|
||||
username: str | None = None
|
||||
first_name: str | None = None
|
||||
last_name: str | None = None
|
||||
|
||||
|
||||
class Chat(msgspec.Struct, forbid_unknown_fields=False):
|
||||
id: int
|
||||
type: str
|
||||
is_forum: bool | None = None
|
||||
|
||||
|
||||
class ChatMember(msgspec.Struct, forbid_unknown_fields=False):
|
||||
status: str
|
||||
can_manage_topics: bool | None = None
|
||||
|
||||
|
||||
class Message(msgspec.Struct, forbid_unknown_fields=False):
|
||||
message_id: int
|
||||
message_thread_id: int | None = None
|
||||
text: str | None = None
|
||||
|
||||
|
||||
class File(msgspec.Struct, forbid_unknown_fields=False):
|
||||
file_path: str
|
||||
|
||||
|
||||
class ForumTopic(msgspec.Struct, forbid_unknown_fields=False):
|
||||
message_thread_id: int
|
||||
|
||||
|
||||
class Update(msgspec.Struct, forbid_unknown_fields=False):
|
||||
update_id: int
|
||||
message: dict[str, Any] | None = None
|
||||
callback_query: dict[str, Any] | None = None
|
||||
@@ -2,22 +2,19 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
|
||||
from ..backends import EngineBackend
|
||||
from ..runner_bridge import ExecBridgeConfig
|
||||
from ..logging import get_logger
|
||||
|
||||
from ..transports import SetupResult, TransportBackend
|
||||
from ..runner_bridge import ExecBridgeConfig
|
||||
from ..settings import TelegramTransportSettings
|
||||
from ..transport_runtime import TransportRuntime
|
||||
from ..transports import SetupResult, TransportBackend
|
||||
from .bridge import (
|
||||
TelegramBridgeConfig,
|
||||
TelegramPresenter,
|
||||
TelegramTransport,
|
||||
TelegramFilesConfig,
|
||||
TelegramTopicsConfig,
|
||||
run_main_loop,
|
||||
)
|
||||
from .client import TelegramClient
|
||||
@@ -26,6 +23,12 @@ from .onboarding import check_setup, interactive_setup
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _expect_transport_settings(transport_config: object) -> TelegramTransportSettings:
|
||||
if isinstance(transport_config, TelegramTransportSettings):
|
||||
return transport_config
|
||||
raise TypeError("transport_config must be TelegramTransportSettings")
|
||||
|
||||
|
||||
def _build_startup_message(
|
||||
runtime: TransportRuntime,
|
||||
*,
|
||||
@@ -33,9 +36,20 @@ def _build_startup_message(
|
||||
) -> str:
|
||||
available_engines = list(runtime.available_engine_ids())
|
||||
missing_engines = list(runtime.missing_engine_ids())
|
||||
misconfigured_engines = list(runtime.engine_ids_with_status("bad_config"))
|
||||
failed_engines = list(runtime.engine_ids_with_status("load_error"))
|
||||
|
||||
engine_list = ", ".join(available_engines) if available_engines else "none"
|
||||
|
||||
notes: list[str] = []
|
||||
if missing_engines:
|
||||
engine_list = f"{engine_list} (not installed: {', '.join(missing_engines)})"
|
||||
notes.append(f"not installed: {', '.join(missing_engines)}")
|
||||
if misconfigured_engines:
|
||||
notes.append(f"misconfigured: {', '.join(misconfigured_engines)}")
|
||||
if failed_engines:
|
||||
notes.append(f"failed to load: {', '.join(failed_engines)}")
|
||||
if notes:
|
||||
engine_list = f"{engine_list} ({'; '.join(notes)})"
|
||||
project_aliases = sorted(
|
||||
{alias for alias in runtime.project_aliases()}, key=str.lower
|
||||
)
|
||||
@@ -49,34 +63,6 @@ def _build_startup_message(
|
||||
)
|
||||
|
||||
|
||||
def _build_topics_config(transport_config: dict[str, object]) -> TelegramTopicsConfig:
|
||||
raw = cast(dict[str, object], transport_config.get("topics", {}))
|
||||
return TelegramTopicsConfig(
|
||||
enabled=cast(bool, raw.get("enabled", False)),
|
||||
scope=cast(str, raw.get("scope", "auto")),
|
||||
)
|
||||
|
||||
|
||||
def _build_files_config(transport_config: dict[str, object]) -> TelegramFilesConfig:
|
||||
defaults = TelegramFilesConfig()
|
||||
raw = cast(dict[str, object], transport_config.get("files", {}))
|
||||
return TelegramFilesConfig(
|
||||
enabled=cast(bool, raw.get("enabled", defaults.enabled)),
|
||||
auto_put=cast(bool, raw.get("auto_put", defaults.auto_put)),
|
||||
uploads_dir=cast(str, raw.get("uploads_dir", defaults.uploads_dir)),
|
||||
max_upload_bytes=defaults.max_upload_bytes,
|
||||
max_download_bytes=defaults.max_download_bytes,
|
||||
allowed_user_ids=frozenset(
|
||||
cast(
|
||||
list[int], raw.get("allowed_user_ids", list(defaults.allowed_user_ids))
|
||||
)
|
||||
),
|
||||
deny_globs=tuple(
|
||||
cast(list[str], raw.get("deny_globs", list(defaults.deny_globs)))
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TelegramBackend(TransportBackend):
|
||||
id = "telegram"
|
||||
description = "Telegram bot"
|
||||
@@ -92,23 +78,23 @@ class TelegramBackend(TransportBackend):
|
||||
def interactive_setup(self, *, force: bool) -> bool:
|
||||
return interactive_setup(force=force)
|
||||
|
||||
def lock_token(
|
||||
self, *, transport_config: dict[str, object], config_path: Path
|
||||
) -> str | None:
|
||||
def lock_token(self, *, transport_config: object, config_path: Path) -> str | None:
|
||||
_ = config_path
|
||||
return cast(str, transport_config.get("bot_token"))
|
||||
settings = _expect_transport_settings(transport_config)
|
||||
return settings.bot_token
|
||||
|
||||
def build_and_run(
|
||||
self,
|
||||
*,
|
||||
transport_config: dict[str, object],
|
||||
transport_config: object,
|
||||
config_path: Path,
|
||||
runtime: TransportRuntime,
|
||||
final_notify: bool,
|
||||
default_engine_override: str | None,
|
||||
) -> None:
|
||||
token = cast(str, transport_config.get("bot_token"))
|
||||
chat_id = cast(int, transport_config.get("chat_id"))
|
||||
settings = _expect_transport_settings(transport_config)
|
||||
token = settings.bot_token
|
||||
chat_id = settings.chat_id
|
||||
startup_msg = _build_startup_message(
|
||||
runtime,
|
||||
startup_pwd=os.getcwd(),
|
||||
@@ -121,19 +107,16 @@ class TelegramBackend(TransportBackend):
|
||||
presenter=presenter,
|
||||
final_notify=final_notify,
|
||||
)
|
||||
topics = _build_topics_config(transport_config)
|
||||
files = _build_files_config(transport_config)
|
||||
cfg = TelegramBridgeConfig(
|
||||
bot=bot,
|
||||
runtime=runtime,
|
||||
chat_id=chat_id,
|
||||
startup_msg=startup_msg,
|
||||
exec_cfg=exec_cfg,
|
||||
voice_transcription=cast(
|
||||
bool, transport_config.get("voice_transcription", False)
|
||||
),
|
||||
topics=topics,
|
||||
files=files,
|
||||
voice_transcription=settings.voice_transcription,
|
||||
voice_max_bytes=int(settings.voice_max_bytes),
|
||||
topics=settings.topics,
|
||||
files=settings.files,
|
||||
)
|
||||
|
||||
async def run_loop() -> None:
|
||||
@@ -142,7 +125,7 @@ class TelegramBackend(TransportBackend):
|
||||
watch_config=runtime.watch_config,
|
||||
default_engine_override=default_engine_override,
|
||||
transport_id=self.id,
|
||||
transport_config=transport_config,
|
||||
transport_config=settings,
|
||||
)
|
||||
|
||||
anyio.run(run_loop)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import cast
|
||||
|
||||
from ..logging import get_logger
|
||||
@@ -12,6 +12,11 @@ from ..transport import MessageRef, RenderedMessage, SendOptions, Transport
|
||||
from ..transport_runtime import TransportRuntime
|
||||
from ..context import RunContext
|
||||
from ..model import ResumeToken
|
||||
from ..settings import (
|
||||
TelegramFilesSettings,
|
||||
TelegramTopicsSettings,
|
||||
TelegramTransportSettings,
|
||||
)
|
||||
from .client import BotClient
|
||||
from .render import prepare_telegram
|
||||
from .types import TelegramCallbackQuery, TelegramIncomingMessage
|
||||
@@ -20,8 +25,6 @@ logger = get_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"TelegramBridgeConfig",
|
||||
"TelegramFilesConfig",
|
||||
"TelegramTopicsConfig",
|
||||
"TelegramPresenter",
|
||||
"TelegramTransport",
|
||||
"build_bot_commands",
|
||||
@@ -85,29 +88,6 @@ def _is_cancelled_label(label: str) -> bool:
|
||||
return stripped.lower() == "cancelled"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TelegramFilesConfig:
|
||||
enabled: bool = False
|
||||
auto_put: bool = True
|
||||
uploads_dir: str = "incoming"
|
||||
max_upload_bytes: int = 20 * 1024 * 1024
|
||||
max_download_bytes: int = 50 * 1024 * 1024
|
||||
allowed_user_ids: frozenset[int] = frozenset()
|
||||
deny_globs: tuple[str, ...] = (
|
||||
".git/**",
|
||||
".env",
|
||||
".envrc",
|
||||
"**/*.pem",
|
||||
"**/.ssh/**",
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TelegramTopicsConfig:
|
||||
enabled: bool = False
|
||||
scope: str = "auto"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TelegramBridgeConfig:
|
||||
bot: BotClient
|
||||
@@ -116,9 +96,10 @@ class TelegramBridgeConfig:
|
||||
startup_msg: str
|
||||
exec_cfg: ExecBridgeConfig
|
||||
voice_transcription: bool = False
|
||||
files: TelegramFilesConfig = TelegramFilesConfig()
|
||||
voice_max_bytes: int = 10 * 1024 * 1024
|
||||
files: TelegramFilesSettings = field(default_factory=TelegramFilesSettings)
|
||||
chat_ids: tuple[int, ...] | None = None
|
||||
topics: TelegramTopicsConfig = TelegramTopicsConfig()
|
||||
topics: TelegramTopicsSettings = field(default_factory=TelegramTopicsSettings)
|
||||
|
||||
|
||||
class TelegramTransport:
|
||||
@@ -166,7 +147,7 @@ class TelegramTransport:
|
||||
)
|
||||
if sent is None:
|
||||
return None
|
||||
message_id = cast(int, sent["message_id"])
|
||||
message_id = sent.message_id
|
||||
return MessageRef(
|
||||
channel_id=chat_id,
|
||||
message_id=message_id,
|
||||
@@ -192,7 +173,7 @@ class TelegramTransport:
|
||||
)
|
||||
if edited is None:
|
||||
return ref if not wait else None
|
||||
message_id = cast(int, edited.get("message_id", message_id))
|
||||
message_id = edited.message_id
|
||||
return MessageRef(
|
||||
channel_id=chat_id,
|
||||
message_id=message_id,
|
||||
@@ -287,7 +268,7 @@ async def run_main_loop(
|
||||
watch_config: bool | None = None,
|
||||
default_engine_override: str | None = None,
|
||||
transport_id: str | None = None,
|
||||
transport_config: dict[str, object] | None = None,
|
||||
transport_config: TelegramTransportSettings | None = None,
|
||||
) -> None:
|
||||
from .loop import run_main_loop as _run_main_loop
|
||||
|
||||
|
||||
+287
-245
@@ -12,13 +12,16 @@ from typing import (
|
||||
Iterable,
|
||||
Protocol,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import msgspec
|
||||
import httpx
|
||||
|
||||
import anyio
|
||||
|
||||
from ..logging import get_logger
|
||||
from .api_models import Chat, ChatMember, File, ForumTopic, Message, Update, User
|
||||
from .types import (
|
||||
TelegramCallbackQuery,
|
||||
TelegramDocument,
|
||||
@@ -29,6 +32,8 @@ from .types import (
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
SEND_PRIORITY = 0
|
||||
DELETE_PRIORITY = 1
|
||||
@@ -51,15 +56,20 @@ def is_group_chat_id(chat_id: int) -> bool:
|
||||
|
||||
|
||||
def parse_incoming_update(
|
||||
update: dict[str, Any],
|
||||
update: Update | dict[str, Any],
|
||||
*,
|
||||
chat_id: int | None = None,
|
||||
chat_ids: set[int] | None = None,
|
||||
) -> TelegramIncomingUpdate | None:
|
||||
msg = update.get("message")
|
||||
if isinstance(update, Update):
|
||||
msg = update.message
|
||||
callback_query = update.callback_query
|
||||
else:
|
||||
msg = update.get("message")
|
||||
callback_query = update.get("callback_query")
|
||||
|
||||
if isinstance(msg, dict):
|
||||
return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids)
|
||||
callback_query = update.get("callback_query")
|
||||
if isinstance(callback_query, dict):
|
||||
return _parse_callback_query(
|
||||
callback_query,
|
||||
@@ -303,7 +313,7 @@ async def poll_incoming(
|
||||
if allowed is None and chat_id is not None:
|
||||
allowed = {chat_id}
|
||||
for upd in updates:
|
||||
offset = upd["update_id"] + 1
|
||||
offset = upd.update_id + 1
|
||||
msg = parse_incoming_update(upd, chat_ids=allowed)
|
||||
if msg is not None:
|
||||
yield msg
|
||||
@@ -317,9 +327,9 @@ class BotClient(Protocol):
|
||||
offset: int | None,
|
||||
timeout_s: int = 50,
|
||||
allowed_updates: list[str] | None = None,
|
||||
) -> list[dict] | None: ...
|
||||
) -> list[Update] | None: ...
|
||||
|
||||
async def get_file(self, file_id: str) -> dict | None: ...
|
||||
async def get_file(self, file_id: str) -> File | None: ...
|
||||
|
||||
async def download_file(self, file_path: str) -> bytes | None: ...
|
||||
|
||||
@@ -335,7 +345,7 @@ class BotClient(Protocol):
|
||||
reply_markup: dict[str, Any] | None = None,
|
||||
*,
|
||||
replace_message_id: int | None = None,
|
||||
) -> dict | None: ...
|
||||
) -> Message | None: ...
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
@@ -346,7 +356,7 @@ class BotClient(Protocol):
|
||||
message_thread_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
caption: str | None = None,
|
||||
) -> dict | None: ...
|
||||
) -> Message | None: ...
|
||||
|
||||
async def edit_message_text(
|
||||
self,
|
||||
@@ -358,7 +368,7 @@ class BotClient(Protocol):
|
||||
reply_markup: dict[str, Any] | None = None,
|
||||
*,
|
||||
wait: bool = True,
|
||||
) -> dict | None: ...
|
||||
) -> Message | None: ...
|
||||
|
||||
async def delete_message(
|
||||
self,
|
||||
@@ -374,7 +384,7 @@ class BotClient(Protocol):
|
||||
language_code: str | None = None,
|
||||
) -> bool: ...
|
||||
|
||||
async def get_me(self) -> dict | None: ...
|
||||
async def get_me(self) -> User | None: ...
|
||||
|
||||
async def answer_callback_query(
|
||||
self,
|
||||
@@ -383,15 +393,17 @@ class BotClient(Protocol):
|
||||
show_alert: bool | None = None,
|
||||
) -> bool: ...
|
||||
|
||||
async def get_chat(self, chat_id: int) -> dict | None: ...
|
||||
async def get_chat(self, chat_id: int) -> Chat | None: ...
|
||||
|
||||
async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None: ...
|
||||
async def get_chat_member(
|
||||
self, chat_id: int, user_id: int
|
||||
) -> ChatMember | None: ...
|
||||
|
||||
async def create_forum_topic(
|
||||
self,
|
||||
chat_id: int,
|
||||
name: str,
|
||||
) -> dict | None: ...
|
||||
) -> ForumTopic | None: ...
|
||||
|
||||
async def edit_forum_topic(
|
||||
self,
|
||||
@@ -672,71 +684,13 @@ class TelegramClient:
|
||||
if self._owns_http_client and self._http_client is not None:
|
||||
await self._http_client.aclose()
|
||||
|
||||
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
|
||||
if self._http_client is None or self._base is None:
|
||||
raise RuntimeError("TelegramClient is configured without an HTTP client.")
|
||||
logger.debug("telegram.request", method=method, payload=json_data)
|
||||
try:
|
||||
resp = await self._http_client.post(
|
||||
f"{self._base}/{method}", json=json_data
|
||||
)
|
||||
except httpx.HTTPError as e:
|
||||
url = getattr(e.request, "url", None)
|
||||
logger.error(
|
||||
"telegram.network_error",
|
||||
method=method,
|
||||
url=str(url) if url is not None else None,
|
||||
error=str(e),
|
||||
error_type=e.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if resp.status_code == 429:
|
||||
retry_after: float | None = None
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
if isinstance(payload, dict):
|
||||
retry_after = retry_after_from_payload(payload)
|
||||
retry_after = 5.0 if retry_after is None else retry_after
|
||||
logger.warning(
|
||||
"telegram.rate_limited",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
retry_after=retry_after,
|
||||
)
|
||||
raise TelegramRetryAfter(retry_after) from e
|
||||
body = resp.text
|
||||
logger.error(
|
||||
"telegram.http_error",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
error=str(e),
|
||||
body=body,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception as e:
|
||||
body = resp.text
|
||||
logger.error(
|
||||
"telegram.bad_response",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
error=str(e),
|
||||
error_type=e.__class__.__name__,
|
||||
body=body,
|
||||
)
|
||||
return None
|
||||
|
||||
def _parse_telegram_envelope(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
resp: httpx.Response,
|
||||
payload: Any,
|
||||
) -> Any | None:
|
||||
if not isinstance(payload, dict):
|
||||
logger.error(
|
||||
"telegram.invalid_payload",
|
||||
@@ -768,169 +722,237 @@ class TelegramClient:
|
||||
logger.debug("telegram.response", method=method, payload=payload)
|
||||
return payload.get("result")
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
*,
|
||||
json: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
) -> Any | None:
|
||||
if self._http_client is None or self._base is None:
|
||||
raise RuntimeError("TelegramClient is configured without an HTTP client.")
|
||||
request_payload = json if json is not None else data
|
||||
logger.debug("telegram.request", method=method, payload=request_payload)
|
||||
try:
|
||||
if json is not None:
|
||||
resp = await self._http_client.post(f"{self._base}/{method}", json=json)
|
||||
else:
|
||||
resp = await self._http_client.post(
|
||||
f"{self._base}/{method}", data=data, files=files
|
||||
)
|
||||
except httpx.HTTPError as exc:
|
||||
url = getattr(exc.request, "url", None)
|
||||
logger.error(
|
||||
"telegram.network_error",
|
||||
method=method,
|
||||
url=str(url) if url is not None else None,
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if resp.status_code == 429:
|
||||
retry_after: float | None = None
|
||||
try:
|
||||
response_payload = resp.json()
|
||||
except Exception:
|
||||
response_payload = None
|
||||
if isinstance(response_payload, dict):
|
||||
retry_after = retry_after_from_payload(response_payload)
|
||||
retry_after = 5.0 if retry_after is None else retry_after
|
||||
logger.warning(
|
||||
"telegram.rate_limited",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
retry_after=retry_after,
|
||||
)
|
||||
raise TelegramRetryAfter(retry_after) from exc
|
||||
body = resp.text
|
||||
logger.error(
|
||||
"telegram.http_error",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
error=str(exc),
|
||||
body=body,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
response_payload = resp.json()
|
||||
except Exception as exc:
|
||||
body = resp.text
|
||||
logger.error(
|
||||
"telegram.bad_response",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
body=body,
|
||||
)
|
||||
return None
|
||||
|
||||
return self._parse_telegram_envelope(
|
||||
method=method,
|
||||
resp=resp,
|
||||
payload=response_payload,
|
||||
)
|
||||
|
||||
def _decode_result(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
payload: Any,
|
||||
model: type[T],
|
||||
) -> T | None:
|
||||
if payload is None:
|
||||
return None
|
||||
try:
|
||||
return msgspec.convert(payload, type=model)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"telegram.decode_error",
|
||||
method=method,
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _call_with_retry_after(
|
||||
self,
|
||||
fn: Callable[[], Awaitable[T]],
|
||||
) -> T:
|
||||
while True:
|
||||
try:
|
||||
return await fn()
|
||||
except TelegramRetryAfter as exc:
|
||||
await self._sleep(exc.retry_after)
|
||||
|
||||
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
|
||||
return await self._request(method, json=json_data)
|
||||
|
||||
async def _post_form(
|
||||
self,
|
||||
method: str,
|
||||
data: dict[str, Any],
|
||||
files: dict[str, Any],
|
||||
) -> Any | None:
|
||||
if self._http_client is None or self._base is None:
|
||||
raise RuntimeError("TelegramClient is configured without an HTTP client.")
|
||||
logger.debug("telegram.request", method=method, payload=data)
|
||||
try:
|
||||
resp = await self._http_client.post(
|
||||
f"{self._base}/{method}", data=data, files=files
|
||||
)
|
||||
except httpx.HTTPError as e:
|
||||
url = getattr(e.request, "url", None)
|
||||
logger.error(
|
||||
"telegram.network_error",
|
||||
method=method,
|
||||
url=str(url) if url is not None else None,
|
||||
error=str(e),
|
||||
error_type=e.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if resp.status_code == 429:
|
||||
retry_after: float | None = None
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
if isinstance(payload, dict):
|
||||
retry_after = retry_after_from_payload(payload)
|
||||
retry_after = 5.0 if retry_after is None else retry_after
|
||||
logger.warning(
|
||||
"telegram.rate_limited",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
retry_after=retry_after,
|
||||
)
|
||||
raise TelegramRetryAfter(retry_after) from e
|
||||
body = resp.text
|
||||
logger.error(
|
||||
"telegram.http_error",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
error=str(e),
|
||||
body=body,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception as e:
|
||||
body = resp.text
|
||||
logger.error(
|
||||
"telegram.bad_response",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
error=str(e),
|
||||
error_type=e.__class__.__name__,
|
||||
body=body,
|
||||
)
|
||||
return None
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
logger.error(
|
||||
"telegram.invalid_payload",
|
||||
method=method,
|
||||
url=str(resp.request.url),
|
||||
payload=payload,
|
||||
)
|
||||
return None
|
||||
|
||||
if not payload.get("ok"):
|
||||
if payload.get("error_code") == 429:
|
||||
retry_after = retry_after_from_payload(payload)
|
||||
retry_after = 5.0 if retry_after is None else retry_after
|
||||
logger.warning(
|
||||
"telegram.rate_limited",
|
||||
method=method,
|
||||
url=str(resp.request.url),
|
||||
retry_after=retry_after,
|
||||
)
|
||||
raise TelegramRetryAfter(retry_after)
|
||||
logger.error(
|
||||
"telegram.api_error",
|
||||
method=method,
|
||||
url=str(resp.request.url),
|
||||
payload=payload,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.debug("telegram.response", method=method, payload=payload)
|
||||
return payload.get("result")
|
||||
return await self._request(method, data=data, files=files)
|
||||
|
||||
async def get_updates(
|
||||
self,
|
||||
offset: int | None,
|
||||
timeout_s: int = 50,
|
||||
allowed_updates: list[str] | None = None,
|
||||
) -> list[dict] | None:
|
||||
while True:
|
||||
try:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.get_updates(
|
||||
offset=offset,
|
||||
timeout_s=timeout_s,
|
||||
allowed_updates=allowed_updates,
|
||||
) -> list[Update] | None:
|
||||
async def execute() -> list[Update] | None:
|
||||
if self._client_override is not None:
|
||||
raw = await self._client_override.get_updates(
|
||||
offset=offset,
|
||||
timeout_s=timeout_s,
|
||||
allowed_updates=allowed_updates,
|
||||
)
|
||||
if raw is None:
|
||||
return None
|
||||
try:
|
||||
return msgspec.convert(raw, type=list[Update])
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"telegram.decode_error",
|
||||
method="getUpdates",
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
params: dict[str, Any] = {"timeout": timeout_s}
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
if allowed_updates is not None:
|
||||
params["allowed_updates"] = allowed_updates
|
||||
result = await self._post("getUpdates", params)
|
||||
return result if isinstance(result, list) else None
|
||||
except TelegramRetryAfter as exc:
|
||||
await self._sleep(exc.retry_after)
|
||||
return None
|
||||
|
||||
async def get_file(self, file_id: str) -> dict | None:
|
||||
while True:
|
||||
params: dict[str, Any] = {"timeout": timeout_s}
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
if allowed_updates is not None:
|
||||
params["allowed_updates"] = allowed_updates
|
||||
result = await self._post("getUpdates", params)
|
||||
if result is None or not isinstance(result, list):
|
||||
return None
|
||||
try:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.get_file(file_id)
|
||||
result = await self._post("getFile", {"file_id": file_id})
|
||||
return result if isinstance(result, dict) else None
|
||||
except TelegramRetryAfter as exc:
|
||||
await self._sleep(exc.retry_after)
|
||||
return msgspec.convert(result, type=list[Update])
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"telegram.decode_error",
|
||||
method="getUpdates",
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
|
||||
return await self._call_with_retry_after(execute)
|
||||
|
||||
async def get_file(self, file_id: str) -> File | None:
|
||||
async def execute() -> File | None:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.get_file(file_id)
|
||||
result = await self._post("getFile", {"file_id": file_id})
|
||||
return self._decode_result(method="getFile", payload=result, model=File)
|
||||
|
||||
return await self._call_with_retry_after(execute)
|
||||
|
||||
async def download_file(self, file_path: str) -> bytes | None:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.download_file(file_path)
|
||||
if self._http_client is None or self._file_base is None:
|
||||
raise RuntimeError("TelegramClient is configured without an HTTP client.")
|
||||
url = f"{self._file_base}/{file_path}"
|
||||
try:
|
||||
resp = await self._http_client.get(url)
|
||||
except httpx.HTTPError as exc:
|
||||
request_url = getattr(exc.request, "url", None)
|
||||
logger.error(
|
||||
"telegram.file_network_error",
|
||||
url=str(request_url) if request_url is not None else None,
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error(
|
||||
"telegram.file_http_error",
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
error=str(exc),
|
||||
body=resp.text,
|
||||
)
|
||||
return None
|
||||
return resp.content
|
||||
async def execute() -> bytes | None:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.download_file(file_path)
|
||||
if self._http_client is None or self._file_base is None:
|
||||
raise RuntimeError(
|
||||
"TelegramClient is configured without an HTTP client."
|
||||
)
|
||||
url = f"{self._file_base}/{file_path}"
|
||||
try:
|
||||
resp = await self._http_client.get(url)
|
||||
except httpx.HTTPError as exc:
|
||||
request_url = getattr(exc.request, "url", None)
|
||||
logger.error(
|
||||
"telegram.file_network_error",
|
||||
url=str(request_url) if request_url is not None else None,
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if resp.status_code == 429:
|
||||
retry_after: float | None = None
|
||||
try:
|
||||
response_payload = resp.json()
|
||||
except Exception:
|
||||
response_payload = None
|
||||
if isinstance(response_payload, dict):
|
||||
retry_after = retry_after_from_payload(response_payload)
|
||||
retry_after = 5.0 if retry_after is None else retry_after
|
||||
logger.warning(
|
||||
"telegram.rate_limited",
|
||||
method="download_file",
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
retry_after=retry_after,
|
||||
)
|
||||
raise TelegramRetryAfter(retry_after) from exc
|
||||
|
||||
logger.error(
|
||||
"telegram.file_http_error",
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
error=str(exc),
|
||||
body=resp.text,
|
||||
)
|
||||
return None
|
||||
return resp.content
|
||||
|
||||
return await self._call_with_retry_after(execute)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
@@ -944,8 +966,8 @@ class TelegramClient:
|
||||
reply_markup: dict[str, Any] | None = None,
|
||||
*,
|
||||
replace_message_id: int | None = None,
|
||||
) -> dict | None:
|
||||
async def execute() -> dict | None:
|
||||
) -> Message | None:
|
||||
async def execute() -> Message | None:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.send_message(
|
||||
chat_id=chat_id,
|
||||
@@ -972,7 +994,11 @@ class TelegramClient:
|
||||
if reply_markup is not None:
|
||||
params["reply_markup"] = reply_markup
|
||||
result = await self._post("sendMessage", params)
|
||||
return result if isinstance(result, dict) else None
|
||||
return self._decode_result(
|
||||
method="sendMessage",
|
||||
payload=result,
|
||||
model=Message,
|
||||
)
|
||||
|
||||
if replace_message_id is not None:
|
||||
await self._outbox.drop_pending(key=("edit", chat_id, replace_message_id))
|
||||
@@ -1000,8 +1026,8 @@ class TelegramClient:
|
||||
message_thread_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
caption: str | None = None,
|
||||
) -> dict | None:
|
||||
async def execute() -> dict | None:
|
||||
) -> Message | None:
|
||||
async def execute() -> Message | None:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.send_document(
|
||||
chat_id=chat_id,
|
||||
@@ -1026,7 +1052,11 @@ class TelegramClient:
|
||||
params,
|
||||
files={"document": (filename, content)},
|
||||
)
|
||||
return result if isinstance(result, dict) else None
|
||||
return self._decode_result(
|
||||
method="sendDocument",
|
||||
payload=result,
|
||||
model=Message,
|
||||
)
|
||||
|
||||
return await self.enqueue_op(
|
||||
key=self.unique_key("send_document"),
|
||||
@@ -1046,8 +1076,8 @@ class TelegramClient:
|
||||
reply_markup: dict[str, Any] | None = None,
|
||||
*,
|
||||
wait: bool = True,
|
||||
) -> dict | None:
|
||||
async def execute() -> dict | None:
|
||||
) -> Message | None:
|
||||
async def execute() -> Message | None:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
@@ -1070,7 +1100,11 @@ class TelegramClient:
|
||||
if reply_markup is not None:
|
||||
params["reply_markup"] = reply_markup
|
||||
result = await self._post("editMessageText", params)
|
||||
return result if isinstance(result, dict) else None
|
||||
return self._decode_result(
|
||||
method="editMessageText",
|
||||
payload=result,
|
||||
model=Message,
|
||||
)
|
||||
|
||||
return await self.enqueue_op(
|
||||
key=("edit", chat_id, message_id),
|
||||
@@ -1142,12 +1176,12 @@ class TelegramClient:
|
||||
)
|
||||
)
|
||||
|
||||
async def get_me(self) -> dict | None:
|
||||
async def execute() -> dict | None:
|
||||
async def get_me(self) -> User | None:
|
||||
async def execute() -> User | None:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.get_me()
|
||||
result = await self._post("getMe", {})
|
||||
return result if isinstance(result, dict) else None
|
||||
return self._decode_result(method="getMe", payload=result, model=User)
|
||||
|
||||
return await self.enqueue_op(
|
||||
key=self.unique_key("get_me"),
|
||||
@@ -1188,12 +1222,12 @@ class TelegramClient:
|
||||
)
|
||||
)
|
||||
|
||||
async def get_chat(self, chat_id: int) -> dict | None:
|
||||
async def execute() -> dict | None:
|
||||
async def get_chat(self, chat_id: int) -> Chat | None:
|
||||
async def execute() -> Chat | None:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.get_chat(chat_id)
|
||||
result = await self._post("getChat", {"chat_id": chat_id})
|
||||
return result if isinstance(result, dict) else None
|
||||
return self._decode_result(method="getChat", payload=result, model=Chat)
|
||||
|
||||
return await self.enqueue_op(
|
||||
key=self.unique_key("get_chat"),
|
||||
@@ -1203,14 +1237,18 @@ class TelegramClient:
|
||||
chat_id=chat_id,
|
||||
)
|
||||
|
||||
async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None:
|
||||
async def execute() -> dict | None:
|
||||
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
|
||||
async def execute() -> ChatMember | None:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.get_chat_member(chat_id, user_id)
|
||||
result = await self._post(
|
||||
"getChatMember", {"chat_id": chat_id, "user_id": user_id}
|
||||
)
|
||||
return result if isinstance(result, dict) else None
|
||||
return self._decode_result(
|
||||
method="getChatMember",
|
||||
payload=result,
|
||||
model=ChatMember,
|
||||
)
|
||||
|
||||
return await self.enqueue_op(
|
||||
key=self.unique_key("get_chat_member"),
|
||||
@@ -1220,14 +1258,18 @@ class TelegramClient:
|
||||
chat_id=chat_id,
|
||||
)
|
||||
|
||||
async def create_forum_topic(self, chat_id: int, name: str) -> dict | None:
|
||||
async def execute() -> dict | None:
|
||||
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
|
||||
async def execute() -> ForumTopic | None:
|
||||
if self._client_override is not None:
|
||||
return await self._client_override.create_forum_topic(chat_id, name)
|
||||
result = await self._post(
|
||||
"createForumTopic", {"chat_id": chat_id, "name": name}
|
||||
)
|
||||
return result if isinstance(result, dict) else None
|
||||
return self._decode_result(
|
||||
method="createForumTopic",
|
||||
payload=result,
|
||||
model=ForumTopic,
|
||||
)
|
||||
|
||||
return await self.enqueue_op(
|
||||
key=self.unique_key("create_forum_topic"),
|
||||
|
||||
@@ -289,11 +289,10 @@ async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool:
|
||||
if is_private:
|
||||
return True
|
||||
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
|
||||
if not isinstance(member, dict):
|
||||
if member is None:
|
||||
await reply(text="failed to verify file transfer permissions.")
|
||||
return False
|
||||
status = member.get("status")
|
||||
if status in {"creator", "administrator"}:
|
||||
if member.status in {"creator", "administrator"}:
|
||||
return True
|
||||
await reply(text="file transfer is restricted to group admins.")
|
||||
return False
|
||||
@@ -377,21 +376,14 @@ async def _save_document_payload(
|
||||
error="file is too large to upload.",
|
||||
)
|
||||
file_info = await cfg.bot.get_file(document.file_id)
|
||||
if not isinstance(file_info, dict):
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error="failed to fetch file metadata.",
|
||||
)
|
||||
file_path = file_info.get("file_path")
|
||||
if not isinstance(file_path, str) or not file_path:
|
||||
if file_info is None:
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error="failed to fetch file metadata.",
|
||||
)
|
||||
file_path = file_info.file_path
|
||||
name = default_upload_name(document.file_name, file_path)
|
||||
resolved_path = rel_path
|
||||
if resolved_path is None:
|
||||
@@ -972,10 +964,10 @@ async def _handle_topic_command(
|
||||
return
|
||||
title = _topic_title(runtime=cfg.runtime, context=context)
|
||||
created = await cfg.bot.create_forum_topic(msg.chat_id, title)
|
||||
thread_id = created.get("message_thread_id") if isinstance(created, dict) else None
|
||||
if isinstance(thread_id, bool) or not isinstance(thread_id, int):
|
||||
if created is None:
|
||||
await reply(text="failed to create topic.")
|
||||
return
|
||||
thread_id = created.message_thread_id
|
||||
await store.set_context(
|
||||
msg.chat_id,
|
||||
thread_id,
|
||||
|
||||
@@ -14,6 +14,7 @@ from ..directives import DirectiveError
|
||||
from ..logging import get_logger
|
||||
from ..model import EngineId, ResumeToken
|
||||
from ..scheduler import ThreadJob, ThreadScheduler
|
||||
from ..settings import TelegramTransportSettings
|
||||
from ..transport import MessageRef
|
||||
from ..context import RunContext
|
||||
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
|
||||
@@ -169,7 +170,7 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int |
|
||||
if drained:
|
||||
logger.info("startup.backlog.drained", count=drained)
|
||||
return offset
|
||||
offset = updates[-1]["update_id"] + 1
|
||||
offset = updates[-1].update_id + 1
|
||||
drained += len(updates)
|
||||
|
||||
|
||||
@@ -266,7 +267,7 @@ async def run_main_loop(
|
||||
watch_config: bool | None = None,
|
||||
default_engine_override: str | None = None,
|
||||
transport_id: str | None = None,
|
||||
transport_config: dict[str, object] | None = None,
|
||||
transport_config: TelegramTransportSettings | None = None,
|
||||
) -> None:
|
||||
from ..runner_bridge import RunningTasks
|
||||
|
||||
@@ -277,7 +278,7 @@ async def run_main_loop(
|
||||
}
|
||||
reserved_commands = _reserved_commands(cfg.runtime)
|
||||
transport_snapshot = (
|
||||
dict(transport_config) if transport_config is not None else None
|
||||
transport_config.model_dump() if transport_config is not None else None
|
||||
)
|
||||
topic_store: TopicStateStore | None = None
|
||||
media_groups: dict[tuple[int, str], _MediaGroupState] = {}
|
||||
@@ -476,6 +477,7 @@ async def run_main_loop(
|
||||
bot=cfg.bot,
|
||||
msg=msg,
|
||||
enabled=cfg.voice_transcription,
|
||||
max_bytes=cfg.voice_max_bytes,
|
||||
reply=reply,
|
||||
)
|
||||
if text is None:
|
||||
|
||||
@@ -33,6 +33,7 @@ from ..engines import list_backends
|
||||
from ..logging import suppress_logs
|
||||
from ..settings import HOME_CONFIG_PATH, load_settings, require_telegram
|
||||
from ..transports import SetupResult
|
||||
from .api_models import User
|
||||
from .client import TelegramClient, TelegramRetryAfter
|
||||
|
||||
__all__ = [
|
||||
@@ -131,7 +132,7 @@ def mask_token(token: str) -> str:
|
||||
return f"{token[:9]}...{token[-5:]}"
|
||||
|
||||
|
||||
async def get_bot_info(token: str) -> dict[str, Any] | None:
|
||||
async def get_bot_info(token: str) -> User | None:
|
||||
bot = TelegramClient(token)
|
||||
try:
|
||||
for _ in range(3):
|
||||
@@ -153,23 +154,19 @@ async def wait_for_chat(token: str) -> ChatInfo:
|
||||
offset=None, timeout_s=0, allowed_updates=allowed_updates
|
||||
)
|
||||
if drained:
|
||||
offset = drained[-1]["update_id"] + 1
|
||||
offset = drained[-1].update_id + 1
|
||||
while True:
|
||||
try:
|
||||
updates = await bot.get_updates(
|
||||
offset=offset, timeout_s=50, allowed_updates=allowed_updates
|
||||
)
|
||||
except TelegramRetryAfter as exc:
|
||||
await anyio.sleep(exc.retry_after)
|
||||
continue
|
||||
updates = await bot.get_updates(
|
||||
offset=offset, timeout_s=50, allowed_updates=allowed_updates
|
||||
)
|
||||
if updates is None:
|
||||
await anyio.sleep(1)
|
||||
continue
|
||||
if not updates:
|
||||
continue
|
||||
offset = updates[-1]["update_id"] + 1
|
||||
offset = updates[-1].update_id + 1
|
||||
update = updates[-1]
|
||||
msg = update.get("message")
|
||||
msg = update.message
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
sender = msg.get("from")
|
||||
@@ -298,7 +295,7 @@ def _confirm(message: str, *, default: bool = True) -> bool | None:
|
||||
return question.ask()
|
||||
|
||||
|
||||
def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None:
|
||||
def _prompt_token(console: Console) -> tuple[str, User] | None:
|
||||
while True:
|
||||
token = questionary.password("paste your bot token:").ask()
|
||||
if token is None:
|
||||
@@ -310,11 +307,10 @@ def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None:
|
||||
console.print(" validating...")
|
||||
info = anyio.run(get_bot_info, token)
|
||||
if info:
|
||||
username = info.get("username")
|
||||
if isinstance(username, str) and username:
|
||||
console.print(f" connected to @{username}")
|
||||
if info.username:
|
||||
console.print(f" connected to @{info.username}")
|
||||
else:
|
||||
name = info.get("first_name") or "your bot"
|
||||
name = info.first_name or "your bot"
|
||||
console.print(f" connected to {name}")
|
||||
return token, info
|
||||
console.print(" failed to connect, check the token and try again")
|
||||
@@ -342,7 +338,7 @@ def capture_chat_id(*, token: str | None = None) -> ChatInfo | None:
|
||||
return None
|
||||
token, info = token_info
|
||||
|
||||
bot_ref = f"@{info['username']}"
|
||||
bot_ref = f"@{info.username}"
|
||||
console.print("")
|
||||
console.print(f" send /start to {bot_ref} (works in groups too)")
|
||||
console.print(" waiting...")
|
||||
@@ -401,7 +397,7 @@ def interactive_setup(*, force: bool) -> bool:
|
||||
if token_info is None:
|
||||
return False
|
||||
token, info = token_info
|
||||
bot_ref = f"@{info['username']}"
|
||||
bot_ref = f"@{info.username}"
|
||||
|
||||
console.print("")
|
||||
console.print(f" send /start to {bot_ref} (works in groups too)")
|
||||
|
||||
@@ -4,9 +4,9 @@ import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import anyio
|
||||
import msgspec
|
||||
|
||||
from ..context import RunContext
|
||||
from ..logging import get_logger
|
||||
@@ -27,6 +27,26 @@ class TopicThreadSnapshot:
|
||||
topic_title: str | None
|
||||
|
||||
|
||||
class _ContextState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
project: str | None = None
|
||||
branch: str | None = None
|
||||
|
||||
|
||||
class _SessionState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
resume: str
|
||||
|
||||
|
||||
class _ThreadState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
context: _ContextState | None = None
|
||||
sessions: dict[str, _SessionState] = msgspec.field(default_factory=dict)
|
||||
topic_title: str | None = None
|
||||
|
||||
|
||||
class _TopicState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
version: int
|
||||
threads: dict[str, _ThreadState] = msgspec.field(default_factory=dict)
|
||||
|
||||
|
||||
def resolve_state_path(config_path: Path) -> Path:
|
||||
return config_path.with_name(STATE_FILENAME)
|
||||
|
||||
@@ -35,34 +55,31 @@ def _thread_key(chat_id: int, thread_id: int) -> str:
|
||||
return f"{chat_id}:{thread_id}"
|
||||
|
||||
|
||||
def _parse_context(raw: object) -> RunContext | None:
|
||||
if not isinstance(raw, dict):
|
||||
def _normalize_text(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
payload = cast(dict[str, object], raw)
|
||||
project = payload.get("project")
|
||||
branch = payload.get("branch")
|
||||
if project is not None and not isinstance(project, str):
|
||||
project = None
|
||||
if isinstance(project, str):
|
||||
project = project.strip() or None
|
||||
if branch is not None and not isinstance(branch, str):
|
||||
branch = None
|
||||
if isinstance(branch, str):
|
||||
branch = branch.strip() or None
|
||||
value = value.strip()
|
||||
return value or None
|
||||
|
||||
|
||||
def _context_from_state(state: _ContextState | None) -> RunContext | None:
|
||||
if state is None:
|
||||
return None
|
||||
project = _normalize_text(state.project)
|
||||
branch = _normalize_text(state.branch)
|
||||
if project is None and branch is None:
|
||||
return None
|
||||
return RunContext(project=project, branch=branch)
|
||||
|
||||
|
||||
def _dump_context(context: RunContext | None) -> dict[str, str] | None:
|
||||
if context is None or (context.project is None and context.branch is None):
|
||||
def _context_to_state(context: RunContext | None) -> _ContextState | None:
|
||||
if context is None:
|
||||
return None
|
||||
payload: dict[str, str] = {}
|
||||
if context.project is not None:
|
||||
payload["project"] = context.project
|
||||
if context.branch is not None:
|
||||
payload["branch"] = context.branch
|
||||
return payload or None
|
||||
project = _normalize_text(context.project)
|
||||
branch = _normalize_text(context.branch)
|
||||
if project is None and branch is None:
|
||||
return None
|
||||
return _ContextState(project=project, branch=branch)
|
||||
|
||||
|
||||
class TopicStateStore:
|
||||
@@ -71,10 +88,7 @@ class TopicStateStore:
|
||||
self._lock = anyio.Lock()
|
||||
self._loaded = False
|
||||
self._mtime_ns: int | None = None
|
||||
self._data: dict[str, Any] = {
|
||||
"version": STATE_VERSION,
|
||||
"threads": {},
|
||||
}
|
||||
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||
|
||||
async def get_thread(
|
||||
self, chat_id: int, thread_id: int
|
||||
@@ -92,7 +106,7 @@ class TopicStateStore:
|
||||
thread = self._get_thread_locked(chat_id, thread_id)
|
||||
if thread is None:
|
||||
return None
|
||||
return _parse_context(thread.get("context"))
|
||||
return _context_from_state(thread.context)
|
||||
|
||||
async def set_context(
|
||||
self,
|
||||
@@ -105,9 +119,9 @@ class TopicStateStore:
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
thread = self._ensure_thread_locked(chat_id, thread_id)
|
||||
thread["context"] = _dump_context(context)
|
||||
thread.context = _context_to_state(context)
|
||||
if topic_title is not None:
|
||||
thread["topic_title"] = topic_title
|
||||
thread.topic_title = topic_title
|
||||
self._save_locked()
|
||||
|
||||
async def clear_context(self, chat_id: int, thread_id: int) -> None:
|
||||
@@ -116,7 +130,7 @@ class TopicStateStore:
|
||||
thread = self._get_thread_locked(chat_id, thread_id)
|
||||
if thread is None:
|
||||
return
|
||||
thread.pop("context", None)
|
||||
thread.context = None
|
||||
self._save_locked()
|
||||
|
||||
async def get_session_resume(
|
||||
@@ -127,16 +141,10 @@ class TopicStateStore:
|
||||
thread = self._get_thread_locked(chat_id, thread_id)
|
||||
if thread is None:
|
||||
return None
|
||||
sessions = thread.get("sessions")
|
||||
if not isinstance(sessions, dict):
|
||||
entry = thread.sessions.get(engine)
|
||||
if entry is None or not entry.resume:
|
||||
return None
|
||||
entry = sessions.get(engine)
|
||||
if not isinstance(entry, dict):
|
||||
return None
|
||||
value = entry.get("resume")
|
||||
if not isinstance(value, str) or not value:
|
||||
return None
|
||||
return ResumeToken(engine=engine, value=value)
|
||||
return ResumeToken(engine=engine, value=entry.resume)
|
||||
|
||||
async def set_session_resume(
|
||||
self, chat_id: int, thread_id: int, token: ResumeToken
|
||||
@@ -144,13 +152,7 @@ class TopicStateStore:
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
thread = self._ensure_thread_locked(chat_id, thread_id)
|
||||
sessions = thread.get("sessions")
|
||||
if not isinstance(sessions, dict):
|
||||
sessions = {}
|
||||
thread["sessions"] = sessions
|
||||
sessions[token.engine] = {
|
||||
"resume": token.value,
|
||||
}
|
||||
thread.sessions[token.engine] = _SessionState(resume=token.value)
|
||||
self._save_locked()
|
||||
|
||||
async def clear_sessions(self, chat_id: int, thread_id: int) -> None:
|
||||
@@ -159,7 +161,7 @@ class TopicStateStore:
|
||||
thread = self._get_thread_locked(chat_id, thread_id)
|
||||
if thread is None:
|
||||
return
|
||||
thread.pop("sessions", None)
|
||||
thread.sessions = {}
|
||||
self._save_locked()
|
||||
|
||||
async def find_thread_for_context(
|
||||
@@ -167,47 +169,37 @@ class TopicStateStore:
|
||||
) -> int | None:
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
threads = self._data.get("threads")
|
||||
if not isinstance(threads, dict):
|
||||
return None
|
||||
for raw_key, payload in threads.items():
|
||||
if not isinstance(raw_key, str) or not isinstance(payload, dict):
|
||||
target_project = _normalize_text(context.project)
|
||||
target_branch = _normalize_text(context.branch)
|
||||
for raw_key, thread in self._state.threads.items():
|
||||
if not raw_key.startswith(f"{chat_id}:"):
|
||||
continue
|
||||
parsed = _parse_context(payload.get("context"))
|
||||
parsed = _context_from_state(thread.context)
|
||||
if parsed is None:
|
||||
continue
|
||||
if parsed.project != context.project or parsed.branch != context.branch:
|
||||
continue
|
||||
if not raw_key.startswith(f"{chat_id}:"):
|
||||
if parsed.project != target_project or parsed.branch != target_branch:
|
||||
continue
|
||||
try:
|
||||
_, thread_str = raw_key.split(":", 1)
|
||||
return int(thread_str)
|
||||
except (ValueError, TypeError):
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
def _snapshot_locked(
|
||||
self, thread: dict[str, Any], chat_id: int, thread_id: int
|
||||
self, thread: _ThreadState, chat_id: int, thread_id: int
|
||||
) -> TopicThreadSnapshot:
|
||||
sessions: dict[str, str] = {}
|
||||
raw_sessions = thread.get("sessions")
|
||||
if isinstance(raw_sessions, dict):
|
||||
for engine, entry in raw_sessions.items():
|
||||
if not isinstance(engine, str) or not isinstance(entry, dict):
|
||||
continue
|
||||
value = entry.get("resume")
|
||||
if isinstance(value, str) and value:
|
||||
sessions[engine] = value
|
||||
topic_title = thread.get("topic_title")
|
||||
if not isinstance(topic_title, str):
|
||||
topic_title = None
|
||||
sessions = {
|
||||
engine: entry.resume
|
||||
for engine, entry in thread.sessions.items()
|
||||
if entry.resume
|
||||
}
|
||||
return TopicThreadSnapshot(
|
||||
chat_id=chat_id,
|
||||
thread_id=thread_id,
|
||||
context=_parse_context(thread.get("context")),
|
||||
context=_context_from_state(thread.context),
|
||||
sessions=sessions,
|
||||
topic_title=topic_title,
|
||||
topic_title=thread.topic_title,
|
||||
)
|
||||
|
||||
def _stat_mtime_ns(self) -> int | None:
|
||||
@@ -226,10 +218,10 @@ class TopicStateStore:
|
||||
self._loaded = True
|
||||
self._mtime_ns = self._stat_mtime_ns()
|
||||
if self._mtime_ns is None:
|
||||
self._data = {"version": STATE_VERSION, "threads": {}}
|
||||
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||
return
|
||||
try:
|
||||
payload = json.loads(self._path.read_text(encoding="utf-8"))
|
||||
payload = msgspec.json.decode(self._path.read_bytes(), type=_TopicState)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"telegram.topic_state.load_failed",
|
||||
@@ -237,29 +229,22 @@ class TopicStateStore:
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
self._data = {"version": STATE_VERSION, "threads": {}}
|
||||
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||
return
|
||||
if not isinstance(payload, dict):
|
||||
self._data = {"version": STATE_VERSION, "threads": {}}
|
||||
return
|
||||
version = payload.get("version")
|
||||
if version != STATE_VERSION:
|
||||
if payload.version != STATE_VERSION:
|
||||
logger.warning(
|
||||
"telegram.topic_state.version_mismatch",
|
||||
path=str(self._path),
|
||||
version=version,
|
||||
version=payload.version,
|
||||
expected=STATE_VERSION,
|
||||
)
|
||||
self._data = {"version": STATE_VERSION, "threads": {}}
|
||||
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||
return
|
||||
threads = payload.get("threads")
|
||||
if not isinstance(threads, dict):
|
||||
threads = {}
|
||||
self._data = {"version": STATE_VERSION, "threads": threads}
|
||||
self._state = payload
|
||||
|
||||
def _save_locked(self) -> None:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {"version": STATE_VERSION, "threads": self._data.get("threads", {})}
|
||||
payload = msgspec.to_builtins(self._state)
|
||||
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
|
||||
with open(tmp_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(payload, handle, indent=2, sort_keys=True)
|
||||
@@ -267,22 +252,14 @@ class TopicStateStore:
|
||||
os.replace(tmp_path, self._path)
|
||||
self._mtime_ns = self._stat_mtime_ns()
|
||||
|
||||
def _get_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any] | None:
|
||||
threads = self._data.get("threads")
|
||||
if not isinstance(threads, dict):
|
||||
return None
|
||||
entry = threads.get(_thread_key(chat_id, thread_id))
|
||||
return entry if isinstance(entry, dict) else None
|
||||
def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None:
|
||||
return self._state.threads.get(_thread_key(chat_id, thread_id))
|
||||
|
||||
def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any]:
|
||||
threads = self._data.get("threads")
|
||||
if not isinstance(threads, dict):
|
||||
threads = {}
|
||||
self._data["threads"] = threads
|
||||
def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState:
|
||||
key = _thread_key(chat_id, thread_id)
|
||||
entry = threads.get(key)
|
||||
if isinstance(entry, dict):
|
||||
entry = self._state.threads.get(key)
|
||||
if entry is not None:
|
||||
return entry
|
||||
entry = {}
|
||||
threads[key] = entry
|
||||
entry = _ThreadState()
|
||||
self._state.threads[key] = entry
|
||||
return entry
|
||||
|
||||
@@ -185,9 +185,9 @@ async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None:
|
||||
if not cfg.topics.enabled:
|
||||
return
|
||||
me = await cfg.bot.get_me()
|
||||
bot_id = me.get("id") if isinstance(me, dict) else None
|
||||
if not isinstance(bot_id, int):
|
||||
if me is None:
|
||||
raise ConfigError("failed to fetch bot id for topics validation.")
|
||||
bot_id = me.id
|
||||
scope, chat_ids = _resolve_topics_scope(cfg)
|
||||
if scope == "projects" and not chat_ids:
|
||||
raise ConfigError(
|
||||
@@ -197,37 +197,34 @@ async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None:
|
||||
|
||||
for chat_id in chat_ids:
|
||||
chat = await cfg.bot.get_chat(chat_id)
|
||||
if not isinstance(chat, dict):
|
||||
if chat is None:
|
||||
raise ConfigError(
|
||||
f"failed to fetch chat info for topics validation ({chat_id})."
|
||||
)
|
||||
chat_type = chat.get("type")
|
||||
is_forum = chat.get("is_forum")
|
||||
if chat_type != "supergroup":
|
||||
if chat.type != "supergroup":
|
||||
raise ConfigError(
|
||||
"topics enabled but chat is not a supergroup "
|
||||
f"(chat_id={chat_id}); convert the group and enable topics."
|
||||
)
|
||||
if is_forum is not True:
|
||||
if chat.is_forum is not True:
|
||||
raise ConfigError(
|
||||
"topics enabled but chat does not have topics enabled "
|
||||
f"(chat_id={chat_id}); turn on topics in group settings."
|
||||
)
|
||||
member = await cfg.bot.get_chat_member(chat_id, bot_id)
|
||||
if not isinstance(member, dict):
|
||||
if member is None:
|
||||
raise ConfigError(
|
||||
"failed to fetch bot permissions "
|
||||
f"(chat_id={chat_id}); promote the bot to admin with manage topics."
|
||||
)
|
||||
status = member.get("status")
|
||||
if status == "creator":
|
||||
if member.status == "creator":
|
||||
continue
|
||||
if status != "administrator":
|
||||
if member.status != "administrator":
|
||||
raise ConfigError(
|
||||
"topics enabled but bot is not an admin "
|
||||
f"(chat_id={chat_id}); promote it and grant manage topics."
|
||||
)
|
||||
if member.get("can_manage_topics") is not True:
|
||||
if member.can_manage_topics is not True:
|
||||
raise ConfigError(
|
||||
"topics enabled but bot lacks manage topics permission "
|
||||
f"(chat_id={chat_id}); grant can_manage_topics."
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import io
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import cast
|
||||
|
||||
from ..logging import get_logger
|
||||
from openai import AsyncOpenAI, OpenAIError
|
||||
@@ -29,6 +28,7 @@ async def transcribe_voice(
|
||||
bot: BotClient,
|
||||
msg: TelegramIncomingMessage,
|
||||
enabled: bool,
|
||||
max_bytes: int | None = None,
|
||||
reply: Callable[..., Awaitable[None]],
|
||||
) -> str | None:
|
||||
voice = msg.voice
|
||||
@@ -37,9 +37,24 @@ async def transcribe_voice(
|
||||
if not enabled:
|
||||
await reply(text=VOICE_TRANSCRIPTION_DISABLED_HINT)
|
||||
return None
|
||||
file_info = cast(dict[str, object], await bot.get_file(voice.file_id))
|
||||
file_path = cast(str, file_info["file_path"])
|
||||
audio_bytes = cast(bytes, await bot.download_file(file_path))
|
||||
if (
|
||||
max_bytes is not None
|
||||
and voice.file_size is not None
|
||||
and voice.file_size > max_bytes
|
||||
):
|
||||
await reply(text="voice message is too large to transcribe.")
|
||||
return None
|
||||
file_info = await bot.get_file(voice.file_id)
|
||||
if file_info is None:
|
||||
await reply(text="failed to fetch voice file.")
|
||||
return None
|
||||
audio_bytes = await bot.download_file(file_info.file_path)
|
||||
if audio_bytes is None:
|
||||
await reply(text="failed to download voice file.")
|
||||
return None
|
||||
if max_bytes is not None and len(audio_bytes) > max_bytes:
|
||||
await reply(text="voice message is too large to transcribe.")
|
||||
return None
|
||||
audio_file = io.BytesIO(audio_bytes)
|
||||
audio_file.name = "voice.ogg"
|
||||
async with AsyncOpenAI(timeout=120) as client:
|
||||
|
||||
@@ -10,7 +10,7 @@ from .context import RunContext
|
||||
from .directives import format_context_line, parse_context_line, parse_directives
|
||||
from .model import EngineId, ResumeToken
|
||||
from .plugins import normalize_allowlist
|
||||
from .router import AutoRouter
|
||||
from .router import AutoRouter, EngineStatus
|
||||
from .runner import Runner
|
||||
from .worktrees import WorktreeError, resolve_run_cwd
|
||||
|
||||
@@ -108,11 +108,14 @@ class TransportRuntime:
|
||||
def available_engine_ids(self) -> tuple[EngineId, ...]:
|
||||
return tuple(entry.engine for entry in self._router.available_entries)
|
||||
|
||||
def missing_engine_ids(self) -> tuple[EngineId, ...]:
|
||||
def engine_ids_with_status(self, status: EngineStatus) -> tuple[EngineId, ...]:
|
||||
return tuple(
|
||||
entry.engine for entry in self._router.entries if not entry.available
|
||||
entry.engine for entry in self._router.entries if entry.status == status
|
||||
)
|
||||
|
||||
def missing_engine_ids(self) -> tuple[EngineId, ...]:
|
||||
return self.engine_ids_with_status("missing_cli")
|
||||
|
||||
def project_aliases(self) -> tuple[str, ...]:
|
||||
return tuple(project.alias for project in self._projects.projects.values())
|
||||
|
||||
|
||||
@@ -41,13 +41,13 @@ class TransportBackend(Protocol):
|
||||
def interactive_setup(self, *, force: bool) -> bool: ...
|
||||
|
||||
def lock_token(
|
||||
self, *, transport_config: dict[str, object], config_path: Path
|
||||
self, *, transport_config: object, config_path: Path
|
||||
) -> str | None: ...
|
||||
|
||||
def build_and_run(
|
||||
self,
|
||||
*,
|
||||
transport_config: dict[str, object],
|
||||
transport_config: object,
|
||||
config_path: Path,
|
||||
runtime: TransportRuntime,
|
||||
final_notify: bool,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from takopi.backends import EngineBackend
|
||||
from takopi.config import dump_toml
|
||||
from takopi.telegram import onboarding
|
||||
from takopi.backends import EngineBackend
|
||||
from takopi.telegram.api_models import User
|
||||
|
||||
|
||||
def test_mask_token_short() -> None:
|
||||
@@ -91,7 +92,7 @@ def test_interactive_setup_writes_config(monkeypatch, tmp_path) -> None:
|
||||
|
||||
def _fake_run(func, *args, **kwargs):
|
||||
if func is onboarding.get_bot_info:
|
||||
return {"username": "my_bot"}
|
||||
return User(id=1, username="my_bot")
|
||||
if func is onboarding.wait_for_chat:
|
||||
return onboarding.ChatInfo(
|
||||
chat_id=123,
|
||||
@@ -136,7 +137,7 @@ def test_interactive_setup_preserves_projects(monkeypatch, tmp_path) -> None:
|
||||
|
||||
def _fake_run(func, *args, **kwargs):
|
||||
if func is onboarding.get_bot_info:
|
||||
return {"username": "my_bot"}
|
||||
return User(id=1, username="my_bot")
|
||||
if func is onboarding.wait_for_chat:
|
||||
return onboarding.ChatInfo(
|
||||
chat_id=123,
|
||||
@@ -173,7 +174,7 @@ def test_interactive_setup_no_agents_aborts(monkeypatch, tmp_path) -> None:
|
||||
|
||||
def _fake_run(func, *args, **kwargs):
|
||||
if func is onboarding.get_bot_info:
|
||||
return {"username": "my_bot"}
|
||||
return User(id=1, username="my_bot")
|
||||
if func is onboarding.wait_for_chat:
|
||||
return onboarding.ChatInfo(
|
||||
chat_id=123,
|
||||
@@ -211,7 +212,7 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) -
|
||||
|
||||
def _fake_run(func, *args, **kwargs):
|
||||
if func is onboarding.get_bot_info:
|
||||
return {"username": "my_bot"}
|
||||
return User(id=1, username="my_bot")
|
||||
if func is onboarding.wait_for_chat:
|
||||
return onboarding.ChatInfo(
|
||||
chat_id=123,
|
||||
@@ -239,7 +240,7 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) -
|
||||
def test_capture_chat_id_with_token(monkeypatch) -> None:
|
||||
def _fake_run(func, *args, **kwargs):
|
||||
if func is onboarding.get_bot_info:
|
||||
return {"username": "my_bot"}
|
||||
return User(id=1, username="my_bot")
|
||||
if func is onboarding.wait_for_chat:
|
||||
return onboarding.ChatInfo(
|
||||
chat_id=456,
|
||||
@@ -261,7 +262,9 @@ def test_capture_chat_id_with_token(monkeypatch) -> None:
|
||||
|
||||
def test_capture_chat_id_prompts_for_token(monkeypatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
onboarding, "_prompt_token", lambda _console: ("token", {"username": "bot"})
|
||||
onboarding,
|
||||
"_prompt_token",
|
||||
lambda _console: ("token", User(id=1, username="bot")),
|
||||
)
|
||||
|
||||
def _fake_run(func, *args, **kwargs):
|
||||
|
||||
@@ -9,6 +9,11 @@ from takopi.config import ProjectsConfig
|
||||
from takopi.model import EngineId
|
||||
from takopi.router import AutoRouter, RunnerEntry
|
||||
from takopi.runners.mock import Return, ScriptRunner
|
||||
from takopi.settings import (
|
||||
TelegramFilesSettings,
|
||||
TelegramTopicsSettings,
|
||||
TelegramTransportSettings,
|
||||
)
|
||||
from takopi.telegram import backend as telegram_backend
|
||||
from takopi.transport_runtime import TransportRuntime
|
||||
|
||||
@@ -20,8 +25,13 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
|
||||
missing = ScriptRunner([Return(answer="ok")], engine=pi)
|
||||
router = AutoRouter(
|
||||
entries=[
|
||||
RunnerEntry(engine=codex, runner=runner, available=True),
|
||||
RunnerEntry(engine=pi, runner=missing, available=False, issue="missing"),
|
||||
RunnerEntry(engine=codex, runner=runner),
|
||||
RunnerEntry(
|
||||
engine=pi,
|
||||
runner=missing,
|
||||
status="missing_cli",
|
||||
issue="missing",
|
||||
),
|
||||
],
|
||||
default_engine=codex,
|
||||
)
|
||||
@@ -40,6 +50,44 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
|
||||
assert "projects: `none`" in message
|
||||
|
||||
|
||||
def test_build_startup_message_surfaces_unavailable_engine_reasons(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
codex = EngineId("codex")
|
||||
pi = EngineId("pi")
|
||||
claude = EngineId("claude")
|
||||
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
||||
bad_cfg = ScriptRunner([Return(answer="ok")], engine=pi)
|
||||
load_err = ScriptRunner([Return(answer="ok")], engine=claude)
|
||||
|
||||
router = AutoRouter(
|
||||
entries=[
|
||||
RunnerEntry(engine=codex, runner=runner),
|
||||
RunnerEntry(engine=pi, runner=bad_cfg, status="bad_config", issue="bad"),
|
||||
RunnerEntry(
|
||||
engine=claude,
|
||||
runner=load_err,
|
||||
status="load_error",
|
||||
issue="failed",
|
||||
),
|
||||
],
|
||||
default_engine=codex,
|
||||
)
|
||||
runtime = TransportRuntime(
|
||||
router=router,
|
||||
projects=ProjectsConfig(projects={}, default_project=None),
|
||||
watch_config=True,
|
||||
)
|
||||
|
||||
message = telegram_backend._build_startup_message(
|
||||
runtime, startup_pwd=str(tmp_path)
|
||||
)
|
||||
|
||||
assert "agents: `codex" in message
|
||||
assert "misconfigured: pi" in message
|
||||
assert "failed to load: claude" in message
|
||||
|
||||
|
||||
def test_telegram_backend_build_and_run_wires_config(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
@@ -55,7 +103,7 @@ def test_telegram_backend_build_and_run_wires_config(
|
||||
codex = EngineId("codex")
|
||||
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
||||
router = AutoRouter(
|
||||
entries=[RunnerEntry(engine=codex, runner=runner, available=True)],
|
||||
entries=[RunnerEntry(engine=codex, runner=runner)],
|
||||
default_engine=codex,
|
||||
)
|
||||
runtime = TransportRuntime(
|
||||
@@ -80,13 +128,14 @@ def test_telegram_backend_build_and_run_wires_config(
|
||||
monkeypatch.setattr(telegram_backend, "run_main_loop", fake_run_main_loop)
|
||||
monkeypatch.setattr(telegram_backend, "TelegramClient", _FakeClient)
|
||||
|
||||
transport_config = {
|
||||
"bot_token": "token",
|
||||
"chat_id": 321,
|
||||
"voice_transcription": True,
|
||||
"files": {"enabled": True, "allowed_user_ids": [1, 2]},
|
||||
"topics": {"enabled": True, "scope": "main"},
|
||||
}
|
||||
transport_config = TelegramTransportSettings(
|
||||
bot_token="token",
|
||||
chat_id=321,
|
||||
voice_transcription=True,
|
||||
voice_max_bytes=1234,
|
||||
files=TelegramFilesSettings(enabled=True, allowed_user_ids=[1, 2]),
|
||||
topics=TelegramTopicsSettings(enabled=True, scope="main"),
|
||||
)
|
||||
|
||||
telegram_backend.TelegramBackend().build_and_run(
|
||||
transport_config=transport_config,
|
||||
@@ -100,18 +149,19 @@ def test_telegram_backend_build_and_run_wires_config(
|
||||
kwargs = captured["kwargs"]
|
||||
assert cfg.chat_id == 321
|
||||
assert cfg.voice_transcription is True
|
||||
assert cfg.voice_max_bytes == 1234
|
||||
assert cfg.files.enabled is True
|
||||
assert cfg.files.allowed_user_ids == frozenset({1, 2})
|
||||
assert cfg.files.allowed_user_ids == [1, 2]
|
||||
assert cfg.topics.enabled is True
|
||||
assert cfg.bot.token == "token"
|
||||
assert kwargs["watch_config"] is True
|
||||
assert kwargs["transport_id"] == "telegram"
|
||||
|
||||
|
||||
def test_build_files_config_defaults() -> None:
|
||||
cfg = telegram_backend._build_files_config({})
|
||||
def test_telegram_files_settings_defaults() -> None:
|
||||
cfg = TelegramFilesSettings()
|
||||
|
||||
assert cfg.enabled is False
|
||||
assert cfg.auto_put is True
|
||||
assert cfg.uploads_dir == "incoming"
|
||||
assert cfg.allowed_user_ids == frozenset()
|
||||
assert cfg.allowed_user_ids == []
|
||||
|
||||
@@ -6,22 +6,30 @@ import anyio
|
||||
import pytest
|
||||
|
||||
from takopi import commands, plugins
|
||||
import takopi.telegram.bridge as bridge
|
||||
import takopi.telegram.loop as telegram_loop
|
||||
import takopi.telegram.commands as telegram_commands
|
||||
import takopi.telegram.topics as telegram_topics
|
||||
from takopi.directives import parse_directives
|
||||
from takopi.telegram.api_models import (
|
||||
Chat,
|
||||
ChatMember,
|
||||
File,
|
||||
ForumTopic,
|
||||
Message,
|
||||
Update,
|
||||
User,
|
||||
)
|
||||
from takopi.settings import TelegramFilesSettings, TelegramTopicsSettings
|
||||
from takopi.telegram.bridge import (
|
||||
TelegramBridgeConfig,
|
||||
TelegramFilesConfig,
|
||||
TelegramPresenter,
|
||||
TelegramTransport,
|
||||
build_bot_commands,
|
||||
handle_callback_cancel,
|
||||
handle_cancel,
|
||||
is_cancel_command,
|
||||
send_with_resume,
|
||||
run_main_loop,
|
||||
send_with_resume,
|
||||
)
|
||||
from takopi.telegram.client import BotClient
|
||||
from takopi.telegram.topic_state import TopicStateStore, resolve_state_path
|
||||
@@ -122,13 +130,13 @@ class _FakeBot(BotClient):
|
||||
offset: int | None,
|
||||
timeout_s: int = 50,
|
||||
allowed_updates: list[str] | None = None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
) -> list[Update] | None:
|
||||
_ = offset
|
||||
_ = timeout_s
|
||||
_ = allowed_updates
|
||||
return []
|
||||
|
||||
async def get_file(self, file_id: str) -> dict[str, Any] | None:
|
||||
async def get_file(self, file_id: str) -> File | None:
|
||||
_ = file_id
|
||||
return None
|
||||
|
||||
@@ -148,7 +156,7 @@ class _FakeBot(BotClient):
|
||||
reply_markup: dict | None = None,
|
||||
*,
|
||||
replace_message_id: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> Message:
|
||||
self.send_calls.append(
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
@@ -162,7 +170,7 @@ class _FakeBot(BotClient):
|
||||
"replace_message_id": replace_message_id,
|
||||
}
|
||||
)
|
||||
return {"message_id": 1}
|
||||
return Message(message_id=1)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
@@ -173,7 +181,7 @@ class _FakeBot(BotClient):
|
||||
message_thread_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
caption: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> Message:
|
||||
self.document_calls.append(
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
@@ -185,7 +193,7 @@ class _FakeBot(BotClient):
|
||||
"caption": caption,
|
||||
}
|
||||
)
|
||||
return {"message_id": 2}
|
||||
return Message(message_id=2)
|
||||
|
||||
async def edit_message_text(
|
||||
self,
|
||||
@@ -197,7 +205,7 @@ class _FakeBot(BotClient):
|
||||
reply_markup: dict | None = None,
|
||||
*,
|
||||
wait: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
) -> Message:
|
||||
self.edit_calls.append(
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
@@ -209,7 +217,7 @@ class _FakeBot(BotClient):
|
||||
"wait": wait,
|
||||
}
|
||||
)
|
||||
return {"message_id": message_id}
|
||||
return Message(message_id=message_id)
|
||||
|
||||
async def delete_message(self, chat_id: int, message_id: int) -> bool:
|
||||
self.delete_calls.append({"chat_id": chat_id, "message_id": message_id})
|
||||
@@ -231,26 +239,22 @@ class _FakeBot(BotClient):
|
||||
)
|
||||
return True
|
||||
|
||||
async def get_me(self) -> dict[str, Any] | None:
|
||||
return {"id": 1}
|
||||
async def get_me(self) -> User | None:
|
||||
return User(id=1, username="bot")
|
||||
|
||||
async def get_chat(self, chat_id: int) -> dict[str, Any] | None:
|
||||
async def get_chat(self, chat_id: int) -> Chat | None:
|
||||
_ = chat_id
|
||||
return {"id": chat_id, "type": "supergroup", "is_forum": True}
|
||||
return Chat(id=chat_id, type="supergroup", is_forum=True)
|
||||
|
||||
async def get_chat_member(
|
||||
self, chat_id: int, user_id: int
|
||||
) -> dict[str, Any] | None:
|
||||
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
|
||||
_ = chat_id
|
||||
_ = user_id
|
||||
return {"status": "administrator", "can_manage_topics": True}
|
||||
return ChatMember(status="administrator", can_manage_topics=True)
|
||||
|
||||
async def create_forum_topic(
|
||||
self, chat_id: int, name: str
|
||||
) -> dict[str, Any] | None:
|
||||
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
|
||||
_ = chat_id
|
||||
_ = name
|
||||
return {"message_thread_id": 1}
|
||||
return ForumTopic(message_thread_id=1)
|
||||
|
||||
async def edit_forum_topic(
|
||||
self, chat_id: int, message_thread_id: int, name: str
|
||||
@@ -539,10 +543,10 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
||||
offset: int | None,
|
||||
timeout_s: int = 50,
|
||||
allowed_updates: list[str] | None = None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
) -> list[Update] | None:
|
||||
return None
|
||||
|
||||
async def get_file(self, file_id: str) -> dict[str, Any] | None:
|
||||
async def get_file(self, file_id: str) -> File | None:
|
||||
_ = file_id
|
||||
return None
|
||||
|
||||
@@ -562,7 +566,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
||||
reply_markup: dict | None = None,
|
||||
*,
|
||||
replace_message_id: int | None = None,
|
||||
) -> dict | None:
|
||||
) -> Message | None:
|
||||
_ = reply_markup
|
||||
return None
|
||||
|
||||
@@ -575,7 +579,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
||||
message_thread_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
caption: str | None = None,
|
||||
) -> dict | None:
|
||||
) -> Message | None:
|
||||
_ = (
|
||||
chat_id,
|
||||
filename,
|
||||
@@ -597,7 +601,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
||||
reply_markup: dict | None = None,
|
||||
*,
|
||||
wait: bool = True,
|
||||
) -> dict | None:
|
||||
) -> Message | None:
|
||||
self.edit_calls.append(
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
@@ -611,7 +615,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
||||
)
|
||||
if not wait:
|
||||
return None
|
||||
return {"message_id": message_id}
|
||||
return Message(message_id=message_id)
|
||||
|
||||
async def delete_message(
|
||||
self,
|
||||
@@ -629,7 +633,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
async def get_me(self) -> dict[str, Any] | None:
|
||||
async def get_me(self) -> User | None:
|
||||
return None
|
||||
|
||||
async def close(self) -> None:
|
||||
@@ -778,9 +782,9 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
|
||||
payload = b"hello"
|
||||
|
||||
class _FileBot(_FakeBot):
|
||||
async def get_file(self, file_id: str) -> dict[str, Any] | None:
|
||||
async def get_file(self, file_id: str) -> File | None:
|
||||
_ = file_id
|
||||
return {"file_path": "files/hello.txt"}
|
||||
return File(file_path="files/hello.txt")
|
||||
|
||||
async def download_file(self, file_path: str) -> bytes | None:
|
||||
_ = file_path
|
||||
@@ -811,7 +815,7 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
|
||||
chat_id=123,
|
||||
startup_msg="",
|
||||
exec_cfg=exec_cfg,
|
||||
files=TelegramFilesConfig(enabled=True),
|
||||
files=TelegramFilesSettings(enabled=True),
|
||||
)
|
||||
msg = TelegramIncomingMessage(
|
||||
transport="telegram",
|
||||
@@ -876,9 +880,9 @@ async def test_handle_file_get_sends_document_for_allowed_user(
|
||||
chat_id=123,
|
||||
startup_msg="",
|
||||
exec_cfg=exec_cfg,
|
||||
files=TelegramFilesConfig(
|
||||
files=TelegramFilesSettings(
|
||||
enabled=True,
|
||||
allowed_user_ids=frozenset({42}),
|
||||
allowed_user_ids=[42],
|
||||
),
|
||||
)
|
||||
msg = TelegramIncomingMessage(
|
||||
@@ -1006,7 +1010,7 @@ def test_topic_title_projects_scope_includes_project() -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = replace(
|
||||
_make_cfg(transport),
|
||||
topics=bridge.TelegramTopicsConfig(
|
||||
topics=TelegramTopicsSettings(
|
||||
enabled=True,
|
||||
scope="projects",
|
||||
),
|
||||
@@ -1277,7 +1281,7 @@ async def test_run_main_loop_persists_topic_sessions_in_project_scope(
|
||||
chat_id=123,
|
||||
startup_msg="",
|
||||
exec_cfg=exec_cfg,
|
||||
topics=bridge.TelegramTopicsConfig(
|
||||
topics=TelegramTopicsSettings(
|
||||
enabled=True,
|
||||
scope="projects",
|
||||
),
|
||||
@@ -1363,11 +1367,11 @@ async def test_run_main_loop_batches_media_group_upload(
|
||||
}
|
||||
|
||||
class _MediaBot(_FakeBot):
|
||||
async def get_file(self, file_id: str) -> dict[str, Any] | None:
|
||||
async def get_file(self, file_id: str) -> File | None:
|
||||
file_path = file_map.get(file_id)
|
||||
if file_path is None:
|
||||
return None
|
||||
return {"file_path": file_path}
|
||||
return File(file_path=file_path)
|
||||
|
||||
async def download_file(self, file_path: str) -> bytes | None:
|
||||
return payloads.get(file_path)
|
||||
@@ -1397,7 +1401,7 @@ async def test_run_main_loop_batches_media_group_upload(
|
||||
chat_id=123,
|
||||
startup_msg="",
|
||||
exec_cfg=exec_cfg,
|
||||
files=TelegramFilesConfig(enabled=True, auto_put=True),
|
||||
files=TelegramFilesSettings(enabled=True, auto_put=True),
|
||||
)
|
||||
msg1 = TelegramIncomingMessage(
|
||||
transport="telegram",
|
||||
|
||||
@@ -57,3 +57,175 @@ async def test_no_token_in_logs_on_http_error(
|
||||
out = capsys.readouterr().out
|
||||
assert token not in out
|
||||
assert "bot[REDACTED]" in out
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_telegram_429_no_retry_post_form() -> None:
|
||||
calls: list[int] = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
calls.append(1)
|
||||
return httpx.Response(
|
||||
429,
|
||||
json={
|
||||
"ok": False,
|
||||
"description": "retry",
|
||||
"parameters": {"retry_after": 2},
|
||||
},
|
||||
request=request,
|
||||
)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||
with pytest.raises(TelegramRetryAfter) as exc:
|
||||
await tg._post_form(
|
||||
"sendDocument",
|
||||
{"chat_id": 1},
|
||||
files={"document": ("note.txt", b"hi")},
|
||||
)
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
assert exc.value.retry_after == 2
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_telegram_429_defaults_retry_after_on_bad_body() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(429, text="nope", request=request)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||
with pytest.raises(TelegramRetryAfter) as exc:
|
||||
await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
assert exc.value.retry_after == 5.0
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_telegram_ok_false_returns_none() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={"ok": False, "error_code": 400, "description": "bad"},
|
||||
request=request,
|
||||
)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||
result = await tg._post("getUpdates", {"timeout": 1})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_telegram_invalid_payload_returns_none() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, json=["not", "a", "dict"], request=request)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||
result = await tg._post("getUpdates", {"timeout": 1})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_telegram_decode_failure_returns_none() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={"ok": True, "result": {"username": "bot-only"}},
|
||||
request=request,
|
||||
)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||
result = await tg.get_me()
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_telegram_download_file_retries_on_429() -> None:
|
||||
calls: list[int] = []
|
||||
sleeps: list[float] = []
|
||||
|
||||
async def sleep(delay: float) -> None:
|
||||
sleeps.append(delay)
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
calls.append(1)
|
||||
if len(calls) == 1:
|
||||
return httpx.Response(
|
||||
429,
|
||||
json={"ok": False, "parameters": {"retry_after": 3}},
|
||||
request=request,
|
||||
)
|
||||
return httpx.Response(200, content=b"ok", request=request)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client, sleep=sleep)
|
||||
payload = await tg.download_file("path/to/file")
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
assert payload == b"ok"
|
||||
assert sleeps == [3.0]
|
||||
assert len(calls) == 2
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_telegram_download_file_429_defaults_retry_after_on_bad_body() -> None:
|
||||
sleeps: list[float] = []
|
||||
|
||||
async def sleep(delay: float) -> None:
|
||||
sleeps.append(delay)
|
||||
|
||||
calls: list[int] = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
calls.append(1)
|
||||
if len(calls) == 1:
|
||||
return httpx.Response(429, text="nope", request=request)
|
||||
return httpx.Response(200, content=b"ok", request=request)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client, sleep=sleep)
|
||||
payload = await tg.download_file("path")
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
assert payload == b"ok"
|
||||
assert sleeps == [5.0]
|
||||
assert len(calls) == 2
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from takopi.telegram.api_models import File, Message, Update, User
|
||||
from takopi.telegram.client import BotClient, TelegramClient, TelegramRetryAfter
|
||||
|
||||
|
||||
@@ -29,7 +30,7 @@ class _FakeBot(BotClient):
|
||||
reply_markup: dict | None = None,
|
||||
*,
|
||||
replace_message_id: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> Message | None:
|
||||
_ = reply_to_message_id
|
||||
_ = disable_notification
|
||||
_ = message_thread_id
|
||||
@@ -38,7 +39,7 @@ class _FakeBot(BotClient):
|
||||
_ = reply_markup
|
||||
_ = replace_message_id
|
||||
self.calls.append("send_message")
|
||||
return {"message_id": 1}
|
||||
return Message(message_id=1)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
@@ -49,7 +50,7 @@ class _FakeBot(BotClient):
|
||||
message_thread_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
caption: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> Message | None:
|
||||
_ = (
|
||||
chat_id,
|
||||
filename,
|
||||
@@ -60,7 +61,7 @@ class _FakeBot(BotClient):
|
||||
caption,
|
||||
)
|
||||
self.calls.append("send_document")
|
||||
return {"message_id": 1}
|
||||
return Message(message_id=1)
|
||||
|
||||
async def edit_message_text(
|
||||
self,
|
||||
@@ -72,7 +73,7 @@ class _FakeBot(BotClient):
|
||||
reply_markup: dict | None = None,
|
||||
*,
|
||||
wait: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
) -> Message | None:
|
||||
_ = chat_id
|
||||
_ = message_id
|
||||
_ = entities
|
||||
@@ -85,7 +86,7 @@ class _FakeBot(BotClient):
|
||||
self._edit_attempts += 1
|
||||
raise TelegramRetryAfter(self.retry_after)
|
||||
self._edit_attempts += 1
|
||||
return {"message_id": message_id}
|
||||
return Message(message_id=message_id)
|
||||
|
||||
async def delete_message(
|
||||
self,
|
||||
@@ -113,7 +114,7 @@ class _FakeBot(BotClient):
|
||||
offset: int | None,
|
||||
timeout_s: int = 50,
|
||||
allowed_updates: list[str] | None = None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
) -> list[Update] | None:
|
||||
_ = offset
|
||||
_ = timeout_s
|
||||
_ = allowed_updates
|
||||
@@ -123,7 +124,7 @@ class _FakeBot(BotClient):
|
||||
self._updates_attempts += 1
|
||||
return []
|
||||
|
||||
async def get_file(self, file_id: str) -> dict[str, Any] | None:
|
||||
async def get_file(self, file_id: str) -> File | None:
|
||||
_ = file_id
|
||||
return None
|
||||
|
||||
@@ -134,8 +135,8 @@ class _FakeBot(BotClient):
|
||||
async def close(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_me(self) -> dict[str, Any] | None:
|
||||
return {"id": 1}
|
||||
async def get_me(self) -> User | None:
|
||||
return User(id=1)
|
||||
|
||||
async def answer_callback_query(
|
||||
self,
|
||||
@@ -187,7 +188,7 @@ async def test_edits_coalesce_latest() -> None:
|
||||
reply_markup: dict | None = None,
|
||||
*,
|
||||
wait: bool = True,
|
||||
) -> dict:
|
||||
) -> Message | None:
|
||||
if self._block_first:
|
||||
self._block_first = False
|
||||
self.edit_started.set()
|
||||
@@ -305,7 +306,8 @@ async def test_retry_after_retries_once() -> None:
|
||||
text="retry",
|
||||
)
|
||||
|
||||
assert result == {"message_id": 1}
|
||||
assert result is not None
|
||||
assert result.message_id == 1
|
||||
assert bot._edit_attempts == 2
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from takopi.telegram.api_models import (
|
||||
Chat,
|
||||
ChatMember,
|
||||
File,
|
||||
ForumTopic,
|
||||
Message,
|
||||
Update,
|
||||
User,
|
||||
)
|
||||
from takopi.telegram.client import BotClient
|
||||
from takopi.telegram.types import TelegramIncomingMessage, TelegramVoice
|
||||
from takopi.telegram.voice import transcribe_voice
|
||||
|
||||
|
||||
class _Bot(BotClient):
|
||||
def __init__(self, *, file_info: File | None, audio: bytes | None) -> None:
|
||||
self._file_info = file_info
|
||||
self._audio = audio
|
||||
|
||||
async def close(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_updates(
|
||||
self,
|
||||
offset: int | None,
|
||||
timeout_s: int = 50,
|
||||
allowed_updates: list[str] | None = None,
|
||||
) -> list[Update] | None:
|
||||
_ = offset, timeout_s, allowed_updates
|
||||
return []
|
||||
|
||||
async def get_file(self, file_id: str) -> File | None:
|
||||
_ = file_id
|
||||
return self._file_info
|
||||
|
||||
async def download_file(self, file_path: str) -> bytes | None:
|
||||
_ = file_path
|
||||
return self._audio
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_to_message_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
message_thread_id: int | None = None,
|
||||
entities: list[dict] | None = None,
|
||||
parse_mode: str | None = None,
|
||||
reply_markup: dict | None = None,
|
||||
*,
|
||||
replace_message_id: int | None = None,
|
||||
) -> Message | None:
|
||||
_ = (
|
||||
chat_id,
|
||||
text,
|
||||
reply_to_message_id,
|
||||
disable_notification,
|
||||
message_thread_id,
|
||||
entities,
|
||||
parse_mode,
|
||||
reply_markup,
|
||||
replace_message_id,
|
||||
)
|
||||
raise AssertionError("send_message should not be called")
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: int,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
reply_to_message_id: int | None = None,
|
||||
message_thread_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
caption: str | None = None,
|
||||
) -> Message | None:
|
||||
_ = (
|
||||
chat_id,
|
||||
filename,
|
||||
content,
|
||||
reply_to_message_id,
|
||||
message_thread_id,
|
||||
disable_notification,
|
||||
caption,
|
||||
)
|
||||
raise AssertionError("send_document should not be called")
|
||||
|
||||
async def edit_message_text(
|
||||
self,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
text: str,
|
||||
entities: list[dict] | None = None,
|
||||
parse_mode: str | None = None,
|
||||
reply_markup: dict | None = None,
|
||||
*,
|
||||
wait: bool = True,
|
||||
) -> Message | None:
|
||||
_ = (
|
||||
chat_id,
|
||||
message_id,
|
||||
text,
|
||||
entities,
|
||||
parse_mode,
|
||||
reply_markup,
|
||||
wait,
|
||||
)
|
||||
raise AssertionError("edit_message_text should not be called")
|
||||
|
||||
async def delete_message(self, chat_id: int, message_id: int) -> bool:
|
||||
_ = chat_id, message_id
|
||||
raise AssertionError("delete_message should not be called")
|
||||
|
||||
async def set_my_commands(
|
||||
self,
|
||||
commands: list[dict],
|
||||
*,
|
||||
scope: dict | None = None,
|
||||
language_code: str | None = None,
|
||||
) -> bool:
|
||||
_ = commands, scope, language_code
|
||||
raise AssertionError("set_my_commands should not be called")
|
||||
|
||||
async def get_me(self) -> User | None:
|
||||
raise AssertionError("get_me should not be called")
|
||||
|
||||
async def answer_callback_query(
|
||||
self,
|
||||
callback_query_id: str,
|
||||
text: str | None = None,
|
||||
show_alert: bool | None = None,
|
||||
) -> bool:
|
||||
_ = callback_query_id, text, show_alert
|
||||
raise AssertionError("answer_callback_query should not be called")
|
||||
|
||||
async def get_chat(self, chat_id: int) -> Chat | None:
|
||||
_ = chat_id
|
||||
raise AssertionError("get_chat should not be called")
|
||||
|
||||
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
|
||||
_ = chat_id, user_id
|
||||
raise AssertionError("get_chat_member should not be called")
|
||||
|
||||
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
|
||||
_ = chat_id, name
|
||||
raise AssertionError("create_forum_topic should not be called")
|
||||
|
||||
async def edit_forum_topic(
|
||||
self, chat_id: int, message_thread_id: int, name: str
|
||||
) -> bool:
|
||||
_ = chat_id, message_thread_id, name
|
||||
raise AssertionError("edit_forum_topic should not be called")
|
||||
|
||||
|
||||
def _voice_message(*, file_size: int = 123) -> TelegramIncomingMessage:
|
||||
voice = TelegramVoice(
|
||||
file_id="voice-id",
|
||||
mime_type="audio/ogg",
|
||||
file_size=file_size,
|
||||
duration=1,
|
||||
raw={},
|
||||
)
|
||||
return TelegramIncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=1,
|
||||
message_id=1,
|
||||
text="",
|
||||
reply_to_message_id=None,
|
||||
reply_to_text=None,
|
||||
sender_id=1,
|
||||
voice=voice,
|
||||
raw={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_transcribe_voice_handles_missing_file() -> None:
|
||||
replies: list[str] = []
|
||||
|
||||
async def reply(**kwargs) -> None:
|
||||
replies.append(kwargs["text"])
|
||||
|
||||
bot = _Bot(file_info=None, audio=None)
|
||||
result = await transcribe_voice(
|
||||
bot=bot,
|
||||
msg=_voice_message(),
|
||||
enabled=True,
|
||||
reply=reply,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert replies[-1] == "failed to fetch voice file."
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_transcribe_voice_handles_missing_download() -> None:
|
||||
replies: list[str] = []
|
||||
|
||||
async def reply(**kwargs) -> None:
|
||||
replies.append(kwargs["text"])
|
||||
|
||||
bot = _Bot(file_info=File(file_path="voice.ogg"), audio=None)
|
||||
result = await transcribe_voice(
|
||||
bot=bot,
|
||||
msg=_voice_message(),
|
||||
enabled=True,
|
||||
reply=reply,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert replies[-1] == "failed to download voice file."
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_transcribe_voice_rejects_large_voice_without_downloading() -> None:
|
||||
replies: list[str] = []
|
||||
|
||||
async def reply(**kwargs) -> None:
|
||||
replies.append(kwargs["text"])
|
||||
|
||||
class _NoFetchBot(_Bot):
|
||||
async def get_file(self, file_id: str) -> File | None: # type: ignore[override]
|
||||
_ = file_id
|
||||
raise AssertionError("get_file should not be called")
|
||||
|
||||
async def download_file(self, file_path: str) -> bytes | None: # type: ignore[override]
|
||||
_ = file_path
|
||||
raise AssertionError("download_file should not be called")
|
||||
|
||||
bot = _NoFetchBot(file_info=None, audio=None)
|
||||
result = await transcribe_voice(
|
||||
bot=bot,
|
||||
msg=_voice_message(file_size=10_000),
|
||||
enabled=True,
|
||||
max_bytes=100,
|
||||
reply=reply,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert replies[-1] == "voice message is too large to transcribe."
|
||||
@@ -15,14 +15,14 @@ class DummyTransport:
|
||||
def interactive_setup(self, *, force: bool) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def lock_token(self, *, transport_config: dict[str, object], config_path):
|
||||
def lock_token(self, *, transport_config: object, config_path):
|
||||
_ = transport_config, config_path
|
||||
raise NotImplementedError
|
||||
|
||||
def build_and_run(
|
||||
self,
|
||||
*,
|
||||
transport_config: dict[str, object],
|
||||
transport_config: object,
|
||||
config_path,
|
||||
runtime,
|
||||
final_notify: bool,
|
||||
|
||||
Reference in New Issue
Block a user