refactor(telegram): boundary types (#90)

This commit is contained in:
banteg
2026-01-11 21:36:07 +04:00
committed by GitHub
parent c6c34ac17f
commit e8c478d786
23 changed files with 1116 additions and 581 deletions
+7 -8
View File
@@ -105,11 +105,7 @@ def _default_engine_for_setup(
if settings is None or config_path is None:
return "codex"
value = settings.default_engine
if not isinstance(value, str) or not value.strip():
raise ConfigError(
f"Invalid `default_engine` in {config_path}; expected a non-empty string."
)
return value.strip()
return value
def _config_path_display(path: Path) -> str:
@@ -235,9 +231,12 @@ def _run_auto_router(
default_engine_override=default_engine_override,
reserved=("cancel",),
)
transport_config = settings.transport_config(
settings.transport, config_path=config_path
)
if settings.transport == "telegram":
transport_config = settings.transports.telegram
else:
transport_config = settings.transport_config(
settings.transport, config_path=config_path
)
lock_token = transport_backend.lock_token(
transport_config=transport_config,
config_path=config_path,
+11 -2
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable
from typing import Iterable, Literal, TypeAlias
from .model import EngineId, ResumeToken
from .runner import Runner
@@ -17,13 +17,22 @@ class RunnerUnavailableError(RuntimeError):
self.issue = issue
EngineStatus: TypeAlias = Literal["ok", "missing_cli", "bad_config", "load_error"]
@dataclass(frozen=True, slots=True)
class RunnerEntry:
engine: EngineId
runner: Runner
available: bool = True
status: EngineStatus = "ok"
issue: str | None = None
@property
def available(self) -> bool:
# "bad_config" means we ignored user config and built the runner with defaults.
# The engine is still runnable, but a warning should be surfaced to the user.
return self.status in {"ok", "bad_config"}
class AutoRouter:
def __init__(
+13 -6
View File
@@ -9,7 +9,7 @@ from .backends import EngineBackend
from .config import ConfigError, ProjectsConfig
from .engines import get_backend, list_backend_ids
from .logging import get_logger
from .router import AutoRouter, RunnerEntry
from .router import AutoRouter, EngineStatus, RunnerEntry
from .settings import TakopiSettings
from .transport_runtime import TransportRuntime
@@ -83,6 +83,7 @@ def build_router(
for backend in backends:
engine_id = backend.id
issue: str | None = None
status: EngineStatus = "ok"
engine_cfg: dict
try:
engine_cfg = settings.engine_config(engine_id, config_path=config_path)
@@ -90,6 +91,7 @@ def build_router(
if engine_id == default_engine:
raise
issue = str(exc)
status = "bad_config"
engine_cfg = {}
try:
@@ -104,26 +106,31 @@ def build_router(
except Exception as fallback_exc:
warnings.append(f"{engine_id}: {issue or str(fallback_exc)}")
continue
status = "bad_config"
else:
status = "load_error"
warnings.append(f"{engine_id}: {issue}")
continue
cmd = backend.cli_cmd or backend.id
if shutil.which(cmd) is None:
issue = issue or f"{cmd} not found on PATH"
status = "missing_cli"
if issue:
issue = f"{issue}; {cmd} not found on PATH"
else:
issue = f"{cmd} not found on PATH"
if issue and engine_id == default_engine:
if status != "ok" and engine_id == default_engine:
raise ConfigError(f"Default engine {engine_id!r} unavailable: {issue}")
available = issue is None
if issue and engine_id != default_engine:
if status != "ok" and engine_id != default_engine:
warnings.append(f"{engine_id}: {issue}")
entries.append(
RunnerEntry(
engine=engine_id,
runner=runner,
available=available,
status=status,
issue=issue,
)
)
+11 -6
View File
@@ -83,6 +83,14 @@ class TelegramFilesSettings(BaseModel):
raise ValueError("files.uploads_dir must be a relative path")
return value
@property
def max_upload_bytes(self) -> int:
return 20 * 1024 * 1024
@property
def max_download_bytes(self) -> int:
return 50 * 1024 * 1024
class TelegramTransportSettings(BaseModel):
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
@@ -90,14 +98,13 @@ class TelegramTransportSettings(BaseModel):
bot_token: NonEmptyStr
chat_id: StrictInt
voice_transcription: bool = False
voice_max_bytes: StrictInt = 10 * 1024 * 1024
topics: TelegramTopicsSettings = Field(default_factory=TelegramTopicsSettings)
files: TelegramFilesSettings = Field(default_factory=TelegramFilesSettings)
class TransportsSettings(BaseModel):
telegram: TelegramTransportSettings = Field(
default_factory=TelegramTransportSettings
)
telegram: TelegramTransportSettings
model_config = ConfigDict(extra="allow")
@@ -132,7 +139,7 @@ class TakopiSettings(BaseSettings):
projects: dict[str, ProjectSettings] = Field(default_factory=dict)
transport: NonEmptyStr = "telegram"
transports: TransportsSettings = Field(default_factory=TransportsSettings)
transports: TransportsSettings
plugins: PluginsSettings = Field(default_factory=PluginsSettings)
@@ -310,8 +317,6 @@ def require_telegram(settings: TakopiSettings, config_path: Path) -> tuple[str,
"(telegram only for now)."
)
tg = settings.transports.telegram
if not tg.bot_token:
raise ConfigError(f"Missing bot token in {config_path}.")
return tg.bot_token, tg.chat_id
+53
View File
@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import Any
import msgspec
__all__ = [
"Chat",
"ChatMember",
"File",
"ForumTopic",
"Message",
"Update",
"User",
]
class User(msgspec.Struct, forbid_unknown_fields=False):
id: int
username: str | None = None
first_name: str | None = None
last_name: str | None = None
class Chat(msgspec.Struct, forbid_unknown_fields=False):
id: int
type: str
is_forum: bool | None = None
class ChatMember(msgspec.Struct, forbid_unknown_fields=False):
status: str
can_manage_topics: bool | None = None
class Message(msgspec.Struct, forbid_unknown_fields=False):
message_id: int
message_thread_id: int | None = None
text: str | None = None
class File(msgspec.Struct, forbid_unknown_fields=False):
file_path: str
class ForumTopic(msgspec.Struct, forbid_unknown_fields=False):
message_thread_id: int
class Update(msgspec.Struct, forbid_unknown_fields=False):
update_id: int
message: dict[str, Any] | None = None
callback_query: dict[str, Any] | None = None
+33 -50
View File
@@ -2,22 +2,19 @@ from __future__ import annotations
import os
from pathlib import Path
from typing import cast
import anyio
from ..backends import EngineBackend
from ..runner_bridge import ExecBridgeConfig
from ..logging import get_logger
from ..transports import SetupResult, TransportBackend
from ..runner_bridge import ExecBridgeConfig
from ..settings import TelegramTransportSettings
from ..transport_runtime import TransportRuntime
from ..transports import SetupResult, TransportBackend
from .bridge import (
TelegramBridgeConfig,
TelegramPresenter,
TelegramTransport,
TelegramFilesConfig,
TelegramTopicsConfig,
run_main_loop,
)
from .client import TelegramClient
@@ -26,6 +23,12 @@ from .onboarding import check_setup, interactive_setup
logger = get_logger(__name__)
def _expect_transport_settings(transport_config: object) -> TelegramTransportSettings:
if isinstance(transport_config, TelegramTransportSettings):
return transport_config
raise TypeError("transport_config must be TelegramTransportSettings")
def _build_startup_message(
runtime: TransportRuntime,
*,
@@ -33,9 +36,20 @@ def _build_startup_message(
) -> str:
available_engines = list(runtime.available_engine_ids())
missing_engines = list(runtime.missing_engine_ids())
misconfigured_engines = list(runtime.engine_ids_with_status("bad_config"))
failed_engines = list(runtime.engine_ids_with_status("load_error"))
engine_list = ", ".join(available_engines) if available_engines else "none"
notes: list[str] = []
if missing_engines:
engine_list = f"{engine_list} (not installed: {', '.join(missing_engines)})"
notes.append(f"not installed: {', '.join(missing_engines)}")
if misconfigured_engines:
notes.append(f"misconfigured: {', '.join(misconfigured_engines)}")
if failed_engines:
notes.append(f"failed to load: {', '.join(failed_engines)}")
if notes:
engine_list = f"{engine_list} ({'; '.join(notes)})"
project_aliases = sorted(
{alias for alias in runtime.project_aliases()}, key=str.lower
)
@@ -49,34 +63,6 @@ def _build_startup_message(
)
def _build_topics_config(transport_config: dict[str, object]) -> TelegramTopicsConfig:
raw = cast(dict[str, object], transport_config.get("topics", {}))
return TelegramTopicsConfig(
enabled=cast(bool, raw.get("enabled", False)),
scope=cast(str, raw.get("scope", "auto")),
)
def _build_files_config(transport_config: dict[str, object]) -> TelegramFilesConfig:
defaults = TelegramFilesConfig()
raw = cast(dict[str, object], transport_config.get("files", {}))
return TelegramFilesConfig(
enabled=cast(bool, raw.get("enabled", defaults.enabled)),
auto_put=cast(bool, raw.get("auto_put", defaults.auto_put)),
uploads_dir=cast(str, raw.get("uploads_dir", defaults.uploads_dir)),
max_upload_bytes=defaults.max_upload_bytes,
max_download_bytes=defaults.max_download_bytes,
allowed_user_ids=frozenset(
cast(
list[int], raw.get("allowed_user_ids", list(defaults.allowed_user_ids))
)
),
deny_globs=tuple(
cast(list[str], raw.get("deny_globs", list(defaults.deny_globs)))
),
)
class TelegramBackend(TransportBackend):
id = "telegram"
description = "Telegram bot"
@@ -92,23 +78,23 @@ class TelegramBackend(TransportBackend):
def interactive_setup(self, *, force: bool) -> bool:
return interactive_setup(force=force)
def lock_token(
self, *, transport_config: dict[str, object], config_path: Path
) -> str | None:
def lock_token(self, *, transport_config: object, config_path: Path) -> str | None:
_ = config_path
return cast(str, transport_config.get("bot_token"))
settings = _expect_transport_settings(transport_config)
return settings.bot_token
def build_and_run(
self,
*,
transport_config: dict[str, object],
transport_config: object,
config_path: Path,
runtime: TransportRuntime,
final_notify: bool,
default_engine_override: str | None,
) -> None:
token = cast(str, transport_config.get("bot_token"))
chat_id = cast(int, transport_config.get("chat_id"))
settings = _expect_transport_settings(transport_config)
token = settings.bot_token
chat_id = settings.chat_id
startup_msg = _build_startup_message(
runtime,
startup_pwd=os.getcwd(),
@@ -121,19 +107,16 @@ class TelegramBackend(TransportBackend):
presenter=presenter,
final_notify=final_notify,
)
topics = _build_topics_config(transport_config)
files = _build_files_config(transport_config)
cfg = TelegramBridgeConfig(
bot=bot,
runtime=runtime,
chat_id=chat_id,
startup_msg=startup_msg,
exec_cfg=exec_cfg,
voice_transcription=cast(
bool, transport_config.get("voice_transcription", False)
),
topics=topics,
files=files,
voice_transcription=settings.voice_transcription,
voice_max_bytes=int(settings.voice_max_bytes),
topics=settings.topics,
files=settings.files,
)
async def run_loop() -> None:
@@ -142,7 +125,7 @@ class TelegramBackend(TransportBackend):
watch_config=runtime.watch_config,
default_engine_override=default_engine_override,
transport_id=self.id,
transport_config=transport_config,
transport_config=settings,
)
anyio.run(run_loop)
+12 -31
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import cast
from ..logging import get_logger
@@ -12,6 +12,11 @@ from ..transport import MessageRef, RenderedMessage, SendOptions, Transport
from ..transport_runtime import TransportRuntime
from ..context import RunContext
from ..model import ResumeToken
from ..settings import (
TelegramFilesSettings,
TelegramTopicsSettings,
TelegramTransportSettings,
)
from .client import BotClient
from .render import prepare_telegram
from .types import TelegramCallbackQuery, TelegramIncomingMessage
@@ -20,8 +25,6 @@ logger = get_logger(__name__)
__all__ = [
"TelegramBridgeConfig",
"TelegramFilesConfig",
"TelegramTopicsConfig",
"TelegramPresenter",
"TelegramTransport",
"build_bot_commands",
@@ -85,29 +88,6 @@ def _is_cancelled_label(label: str) -> bool:
return stripped.lower() == "cancelled"
@dataclass(frozen=True)
class TelegramFilesConfig:
enabled: bool = False
auto_put: bool = True
uploads_dir: str = "incoming"
max_upload_bytes: int = 20 * 1024 * 1024
max_download_bytes: int = 50 * 1024 * 1024
allowed_user_ids: frozenset[int] = frozenset()
deny_globs: tuple[str, ...] = (
".git/**",
".env",
".envrc",
"**/*.pem",
"**/.ssh/**",
)
@dataclass(frozen=True)
class TelegramTopicsConfig:
enabled: bool = False
scope: str = "auto"
@dataclass(frozen=True)
class TelegramBridgeConfig:
bot: BotClient
@@ -116,9 +96,10 @@ class TelegramBridgeConfig:
startup_msg: str
exec_cfg: ExecBridgeConfig
voice_transcription: bool = False
files: TelegramFilesConfig = TelegramFilesConfig()
voice_max_bytes: int = 10 * 1024 * 1024
files: TelegramFilesSettings = field(default_factory=TelegramFilesSettings)
chat_ids: tuple[int, ...] | None = None
topics: TelegramTopicsConfig = TelegramTopicsConfig()
topics: TelegramTopicsSettings = field(default_factory=TelegramTopicsSettings)
class TelegramTransport:
@@ -166,7 +147,7 @@ class TelegramTransport:
)
if sent is None:
return None
message_id = cast(int, sent["message_id"])
message_id = sent.message_id
return MessageRef(
channel_id=chat_id,
message_id=message_id,
@@ -192,7 +173,7 @@ class TelegramTransport:
)
if edited is None:
return ref if not wait else None
message_id = cast(int, edited.get("message_id", message_id))
message_id = edited.message_id
return MessageRef(
channel_id=chat_id,
message_id=message_id,
@@ -287,7 +268,7 @@ async def run_main_loop(
watch_config: bool | None = None,
default_engine_override: str | None = None,
transport_id: str | None = None,
transport_config: dict[str, object] | None = None,
transport_config: TelegramTransportSettings | None = None,
) -> None:
from .loop import run_main_loop as _run_main_loop
+287 -245
View File
@@ -12,13 +12,16 @@ from typing import (
Iterable,
Protocol,
TYPE_CHECKING,
TypeVar,
)
import msgspec
import httpx
import anyio
from ..logging import get_logger
from .api_models import Chat, ChatMember, File, ForumTopic, Message, Update, User
from .types import (
TelegramCallbackQuery,
TelegramDocument,
@@ -29,6 +32,8 @@ from .types import (
logger = get_logger(__name__)
T = TypeVar("T")
SEND_PRIORITY = 0
DELETE_PRIORITY = 1
@@ -51,15 +56,20 @@ def is_group_chat_id(chat_id: int) -> bool:
def parse_incoming_update(
update: dict[str, Any],
update: Update | dict[str, Any],
*,
chat_id: int | None = None,
chat_ids: set[int] | None = None,
) -> TelegramIncomingUpdate | None:
msg = update.get("message")
if isinstance(update, Update):
msg = update.message
callback_query = update.callback_query
else:
msg = update.get("message")
callback_query = update.get("callback_query")
if isinstance(msg, dict):
return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids)
callback_query = update.get("callback_query")
if isinstance(callback_query, dict):
return _parse_callback_query(
callback_query,
@@ -303,7 +313,7 @@ async def poll_incoming(
if allowed is None and chat_id is not None:
allowed = {chat_id}
for upd in updates:
offset = upd["update_id"] + 1
offset = upd.update_id + 1
msg = parse_incoming_update(upd, chat_ids=allowed)
if msg is not None:
yield msg
@@ -317,9 +327,9 @@ class BotClient(Protocol):
offset: int | None,
timeout_s: int = 50,
allowed_updates: list[str] | None = None,
) -> list[dict] | None: ...
) -> list[Update] | None: ...
async def get_file(self, file_id: str) -> dict | None: ...
async def get_file(self, file_id: str) -> File | None: ...
async def download_file(self, file_path: str) -> bytes | None: ...
@@ -335,7 +345,7 @@ class BotClient(Protocol):
reply_markup: dict[str, Any] | None = None,
*,
replace_message_id: int | None = None,
) -> dict | None: ...
) -> Message | None: ...
async def send_document(
self,
@@ -346,7 +356,7 @@ class BotClient(Protocol):
message_thread_id: int | None = None,
disable_notification: bool | None = False,
caption: str | None = None,
) -> dict | None: ...
) -> Message | None: ...
async def edit_message_text(
self,
@@ -358,7 +368,7 @@ class BotClient(Protocol):
reply_markup: dict[str, Any] | None = None,
*,
wait: bool = True,
) -> dict | None: ...
) -> Message | None: ...
async def delete_message(
self,
@@ -374,7 +384,7 @@ class BotClient(Protocol):
language_code: str | None = None,
) -> bool: ...
async def get_me(self) -> dict | None: ...
async def get_me(self) -> User | None: ...
async def answer_callback_query(
self,
@@ -383,15 +393,17 @@ class BotClient(Protocol):
show_alert: bool | None = None,
) -> bool: ...
async def get_chat(self, chat_id: int) -> dict | None: ...
async def get_chat(self, chat_id: int) -> Chat | None: ...
async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None: ...
async def get_chat_member(
self, chat_id: int, user_id: int
) -> ChatMember | None: ...
async def create_forum_topic(
self,
chat_id: int,
name: str,
) -> dict | None: ...
) -> ForumTopic | None: ...
async def edit_forum_topic(
self,
@@ -672,71 +684,13 @@ class TelegramClient:
if self._owns_http_client and self._http_client is not None:
await self._http_client.aclose()
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
if self._http_client is None or self._base is None:
raise RuntimeError("TelegramClient is configured without an HTTP client.")
logger.debug("telegram.request", method=method, payload=json_data)
try:
resp = await self._http_client.post(
f"{self._base}/{method}", json=json_data
)
except httpx.HTTPError as e:
url = getattr(e.request, "url", None)
logger.error(
"telegram.network_error",
method=method,
url=str(url) if url is not None else None,
error=str(e),
error_type=e.__class__.__name__,
)
return None
try:
resp.raise_for_status()
except httpx.HTTPStatusError as e:
if resp.status_code == 429:
retry_after: float | None = None
try:
payload = resp.json()
except Exception:
payload = None
if isinstance(payload, dict):
retry_after = retry_after_from_payload(payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method=method,
status=resp.status_code,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after) from e
body = resp.text
logger.error(
"telegram.http_error",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(e),
body=body,
)
return None
try:
payload = resp.json()
except Exception as e:
body = resp.text
logger.error(
"telegram.bad_response",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(e),
error_type=e.__class__.__name__,
body=body,
)
return None
def _parse_telegram_envelope(
self,
*,
method: str,
resp: httpx.Response,
payload: Any,
) -> Any | None:
if not isinstance(payload, dict):
logger.error(
"telegram.invalid_payload",
@@ -768,169 +722,237 @@ class TelegramClient:
logger.debug("telegram.response", method=method, payload=payload)
return payload.get("result")
async def _request(
self,
method: str,
*,
json: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
files: dict[str, Any] | None = None,
) -> Any | None:
if self._http_client is None or self._base is None:
raise RuntimeError("TelegramClient is configured without an HTTP client.")
request_payload = json if json is not None else data
logger.debug("telegram.request", method=method, payload=request_payload)
try:
if json is not None:
resp = await self._http_client.post(f"{self._base}/{method}", json=json)
else:
resp = await self._http_client.post(
f"{self._base}/{method}", data=data, files=files
)
except httpx.HTTPError as exc:
url = getattr(exc.request, "url", None)
logger.error(
"telegram.network_error",
method=method,
url=str(url) if url is not None else None,
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
try:
resp.raise_for_status()
except httpx.HTTPStatusError as exc:
if resp.status_code == 429:
retry_after: float | None = None
try:
response_payload = resp.json()
except Exception:
response_payload = None
if isinstance(response_payload, dict):
retry_after = retry_after_from_payload(response_payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method=method,
status=resp.status_code,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after) from exc
body = resp.text
logger.error(
"telegram.http_error",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(exc),
body=body,
)
return None
try:
response_payload = resp.json()
except Exception as exc:
body = resp.text
logger.error(
"telegram.bad_response",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(exc),
error_type=exc.__class__.__name__,
body=body,
)
return None
return self._parse_telegram_envelope(
method=method,
resp=resp,
payload=response_payload,
)
def _decode_result(
self,
*,
method: str,
payload: Any,
model: type[T],
) -> T | None:
if payload is None:
return None
try:
return msgspec.convert(payload, type=model)
except Exception as exc:
logger.error(
"telegram.decode_error",
method=method,
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
async def _call_with_retry_after(
self,
fn: Callable[[], Awaitable[T]],
) -> T:
while True:
try:
return await fn()
except TelegramRetryAfter as exc:
await self._sleep(exc.retry_after)
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
return await self._request(method, json=json_data)
async def _post_form(
self,
method: str,
data: dict[str, Any],
files: dict[str, Any],
) -> Any | None:
if self._http_client is None or self._base is None:
raise RuntimeError("TelegramClient is configured without an HTTP client.")
logger.debug("telegram.request", method=method, payload=data)
try:
resp = await self._http_client.post(
f"{self._base}/{method}", data=data, files=files
)
except httpx.HTTPError as e:
url = getattr(e.request, "url", None)
logger.error(
"telegram.network_error",
method=method,
url=str(url) if url is not None else None,
error=str(e),
error_type=e.__class__.__name__,
)
return None
try:
resp.raise_for_status()
except httpx.HTTPStatusError as e:
if resp.status_code == 429:
retry_after: float | None = None
try:
payload = resp.json()
except Exception:
payload = None
if isinstance(payload, dict):
retry_after = retry_after_from_payload(payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method=method,
status=resp.status_code,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after) from e
body = resp.text
logger.error(
"telegram.http_error",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(e),
body=body,
)
return None
try:
payload = resp.json()
except Exception as e:
body = resp.text
logger.error(
"telegram.bad_response",
method=method,
status=resp.status_code,
error=str(e),
error_type=e.__class__.__name__,
body=body,
)
return None
if not isinstance(payload, dict):
logger.error(
"telegram.invalid_payload",
method=method,
url=str(resp.request.url),
payload=payload,
)
return None
if not payload.get("ok"):
if payload.get("error_code") == 429:
retry_after = retry_after_from_payload(payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method=method,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after)
logger.error(
"telegram.api_error",
method=method,
url=str(resp.request.url),
payload=payload,
)
return None
logger.debug("telegram.response", method=method, payload=payload)
return payload.get("result")
return await self._request(method, data=data, files=files)
async def get_updates(
self,
offset: int | None,
timeout_s: int = 50,
allowed_updates: list[str] | None = None,
) -> list[dict] | None:
while True:
try:
if self._client_override is not None:
return await self._client_override.get_updates(
offset=offset,
timeout_s=timeout_s,
allowed_updates=allowed_updates,
) -> list[Update] | None:
async def execute() -> list[Update] | None:
if self._client_override is not None:
raw = await self._client_override.get_updates(
offset=offset,
timeout_s=timeout_s,
allowed_updates=allowed_updates,
)
if raw is None:
return None
try:
return msgspec.convert(raw, type=list[Update])
except Exception as exc:
logger.error(
"telegram.decode_error",
method="getUpdates",
error=str(exc),
error_type=exc.__class__.__name__,
)
params: dict[str, Any] = {"timeout": timeout_s}
if offset is not None:
params["offset"] = offset
if allowed_updates is not None:
params["allowed_updates"] = allowed_updates
result = await self._post("getUpdates", params)
return result if isinstance(result, list) else None
except TelegramRetryAfter as exc:
await self._sleep(exc.retry_after)
return None
async def get_file(self, file_id: str) -> dict | None:
while True:
params: dict[str, Any] = {"timeout": timeout_s}
if offset is not None:
params["offset"] = offset
if allowed_updates is not None:
params["allowed_updates"] = allowed_updates
result = await self._post("getUpdates", params)
if result is None or not isinstance(result, list):
return None
try:
if self._client_override is not None:
return await self._client_override.get_file(file_id)
result = await self._post("getFile", {"file_id": file_id})
return result if isinstance(result, dict) else None
except TelegramRetryAfter as exc:
await self._sleep(exc.retry_after)
return msgspec.convert(result, type=list[Update])
except Exception as exc:
logger.error(
"telegram.decode_error",
method="getUpdates",
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
return await self._call_with_retry_after(execute)
async def get_file(self, file_id: str) -> File | None:
async def execute() -> File | None:
if self._client_override is not None:
return await self._client_override.get_file(file_id)
result = await self._post("getFile", {"file_id": file_id})
return self._decode_result(method="getFile", payload=result, model=File)
return await self._call_with_retry_after(execute)
async def download_file(self, file_path: str) -> bytes | None:
if self._client_override is not None:
return await self._client_override.download_file(file_path)
if self._http_client is None or self._file_base is None:
raise RuntimeError("TelegramClient is configured without an HTTP client.")
url = f"{self._file_base}/{file_path}"
try:
resp = await self._http_client.get(url)
except httpx.HTTPError as exc:
request_url = getattr(exc.request, "url", None)
logger.error(
"telegram.file_network_error",
url=str(request_url) if request_url is not None else None,
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
try:
resp.raise_for_status()
except httpx.HTTPStatusError as exc:
logger.error(
"telegram.file_http_error",
status=resp.status_code,
url=str(resp.request.url),
error=str(exc),
body=resp.text,
)
return None
return resp.content
async def execute() -> bytes | None:
if self._client_override is not None:
return await self._client_override.download_file(file_path)
if self._http_client is None or self._file_base is None:
raise RuntimeError(
"TelegramClient is configured without an HTTP client."
)
url = f"{self._file_base}/{file_path}"
try:
resp = await self._http_client.get(url)
except httpx.HTTPError as exc:
request_url = getattr(exc.request, "url", None)
logger.error(
"telegram.file_network_error",
url=str(request_url) if request_url is not None else None,
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
try:
resp.raise_for_status()
except httpx.HTTPStatusError as exc:
if resp.status_code == 429:
retry_after: float | None = None
try:
response_payload = resp.json()
except Exception:
response_payload = None
if isinstance(response_payload, dict):
retry_after = retry_after_from_payload(response_payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method="download_file",
status=resp.status_code,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after) from exc
logger.error(
"telegram.file_http_error",
status=resp.status_code,
url=str(resp.request.url),
error=str(exc),
body=resp.text,
)
return None
return resp.content
return await self._call_with_retry_after(execute)
async def send_message(
self,
@@ -944,8 +966,8 @@ class TelegramClient:
reply_markup: dict[str, Any] | None = None,
*,
replace_message_id: int | None = None,
) -> dict | None:
async def execute() -> dict | None:
) -> Message | None:
async def execute() -> Message | None:
if self._client_override is not None:
return await self._client_override.send_message(
chat_id=chat_id,
@@ -972,7 +994,11 @@ class TelegramClient:
if reply_markup is not None:
params["reply_markup"] = reply_markup
result = await self._post("sendMessage", params)
return result if isinstance(result, dict) else None
return self._decode_result(
method="sendMessage",
payload=result,
model=Message,
)
if replace_message_id is not None:
await self._outbox.drop_pending(key=("edit", chat_id, replace_message_id))
@@ -1000,8 +1026,8 @@ class TelegramClient:
message_thread_id: int | None = None,
disable_notification: bool | None = False,
caption: str | None = None,
) -> dict | None:
async def execute() -> dict | None:
) -> Message | None:
async def execute() -> Message | None:
if self._client_override is not None:
return await self._client_override.send_document(
chat_id=chat_id,
@@ -1026,7 +1052,11 @@ class TelegramClient:
params,
files={"document": (filename, content)},
)
return result if isinstance(result, dict) else None
return self._decode_result(
method="sendDocument",
payload=result,
model=Message,
)
return await self.enqueue_op(
key=self.unique_key("send_document"),
@@ -1046,8 +1076,8 @@ class TelegramClient:
reply_markup: dict[str, Any] | None = None,
*,
wait: bool = True,
) -> dict | None:
async def execute() -> dict | None:
) -> Message | None:
async def execute() -> Message | None:
if self._client_override is not None:
return await self._client_override.edit_message_text(
chat_id=chat_id,
@@ -1070,7 +1100,11 @@ class TelegramClient:
if reply_markup is not None:
params["reply_markup"] = reply_markup
result = await self._post("editMessageText", params)
return result if isinstance(result, dict) else None
return self._decode_result(
method="editMessageText",
payload=result,
model=Message,
)
return await self.enqueue_op(
key=("edit", chat_id, message_id),
@@ -1142,12 +1176,12 @@ class TelegramClient:
)
)
async def get_me(self) -> dict | None:
async def execute() -> dict | None:
async def get_me(self) -> User | None:
async def execute() -> User | None:
if self._client_override is not None:
return await self._client_override.get_me()
result = await self._post("getMe", {})
return result if isinstance(result, dict) else None
return self._decode_result(method="getMe", payload=result, model=User)
return await self.enqueue_op(
key=self.unique_key("get_me"),
@@ -1188,12 +1222,12 @@ class TelegramClient:
)
)
async def get_chat(self, chat_id: int) -> dict | None:
async def execute() -> dict | None:
async def get_chat(self, chat_id: int) -> Chat | None:
async def execute() -> Chat | None:
if self._client_override is not None:
return await self._client_override.get_chat(chat_id)
result = await self._post("getChat", {"chat_id": chat_id})
return result if isinstance(result, dict) else None
return self._decode_result(method="getChat", payload=result, model=Chat)
return await self.enqueue_op(
key=self.unique_key("get_chat"),
@@ -1203,14 +1237,18 @@ class TelegramClient:
chat_id=chat_id,
)
async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None:
async def execute() -> dict | None:
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
async def execute() -> ChatMember | None:
if self._client_override is not None:
return await self._client_override.get_chat_member(chat_id, user_id)
result = await self._post(
"getChatMember", {"chat_id": chat_id, "user_id": user_id}
)
return result if isinstance(result, dict) else None
return self._decode_result(
method="getChatMember",
payload=result,
model=ChatMember,
)
return await self.enqueue_op(
key=self.unique_key("get_chat_member"),
@@ -1220,14 +1258,18 @@ class TelegramClient:
chat_id=chat_id,
)
async def create_forum_topic(self, chat_id: int, name: str) -> dict | None:
async def execute() -> dict | None:
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
async def execute() -> ForumTopic | None:
if self._client_override is not None:
return await self._client_override.create_forum_topic(chat_id, name)
result = await self._post(
"createForumTopic", {"chat_id": chat_id, "name": name}
)
return result if isinstance(result, dict) else None
return self._decode_result(
method="createForumTopic",
payload=result,
model=ForumTopic,
)
return await self.enqueue_op(
key=self.unique_key("create_forum_topic"),
+6 -14
View File
@@ -289,11 +289,10 @@ async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool:
if is_private:
return True
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
if not isinstance(member, dict):
if member is None:
await reply(text="failed to verify file transfer permissions.")
return False
status = member.get("status")
if status in {"creator", "administrator"}:
if member.status in {"creator", "administrator"}:
return True
await reply(text="file transfer is restricted to group admins.")
return False
@@ -377,21 +376,14 @@ async def _save_document_payload(
error="file is too large to upload.",
)
file_info = await cfg.bot.get_file(document.file_id)
if not isinstance(file_info, dict):
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error="failed to fetch file metadata.",
)
file_path = file_info.get("file_path")
if not isinstance(file_path, str) or not file_path:
if file_info is None:
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error="failed to fetch file metadata.",
)
file_path = file_info.file_path
name = default_upload_name(document.file_name, file_path)
resolved_path = rel_path
if resolved_path is None:
@@ -972,10 +964,10 @@ async def _handle_topic_command(
return
title = _topic_title(runtime=cfg.runtime, context=context)
created = await cfg.bot.create_forum_topic(msg.chat_id, title)
thread_id = created.get("message_thread_id") if isinstance(created, dict) else None
if isinstance(thread_id, bool) or not isinstance(thread_id, int):
if created is None:
await reply(text="failed to create topic.")
return
thread_id = created.message_thread_id
await store.set_context(
msg.chat_id,
thread_id,
+5 -3
View File
@@ -14,6 +14,7 @@ from ..directives import DirectiveError
from ..logging import get_logger
from ..model import EngineId, ResumeToken
from ..scheduler import ThreadJob, ThreadScheduler
from ..settings import TelegramTransportSettings
from ..transport import MessageRef
from ..context import RunContext
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
@@ -169,7 +170,7 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int |
if drained:
logger.info("startup.backlog.drained", count=drained)
return offset
offset = updates[-1]["update_id"] + 1
offset = updates[-1].update_id + 1
drained += len(updates)
@@ -266,7 +267,7 @@ async def run_main_loop(
watch_config: bool | None = None,
default_engine_override: str | None = None,
transport_id: str | None = None,
transport_config: dict[str, object] | None = None,
transport_config: TelegramTransportSettings | None = None,
) -> None:
from ..runner_bridge import RunningTasks
@@ -277,7 +278,7 @@ async def run_main_loop(
}
reserved_commands = _reserved_commands(cfg.runtime)
transport_snapshot = (
dict(transport_config) if transport_config is not None else None
transport_config.model_dump() if transport_config is not None else None
)
topic_store: TopicStateStore | None = None
media_groups: dict[tuple[int, str], _MediaGroupState] = {}
@@ -476,6 +477,7 @@ async def run_main_loop(
bot=cfg.bot,
msg=msg,
enabled=cfg.voice_transcription,
max_bytes=cfg.voice_max_bytes,
reply=reply,
)
if text is None:
+14 -18
View File
@@ -33,6 +33,7 @@ from ..engines import list_backends
from ..logging import suppress_logs
from ..settings import HOME_CONFIG_PATH, load_settings, require_telegram
from ..transports import SetupResult
from .api_models import User
from .client import TelegramClient, TelegramRetryAfter
__all__ = [
@@ -131,7 +132,7 @@ def mask_token(token: str) -> str:
return f"{token[:9]}...{token[-5:]}"
async def get_bot_info(token: str) -> dict[str, Any] | None:
async def get_bot_info(token: str) -> User | None:
bot = TelegramClient(token)
try:
for _ in range(3):
@@ -153,23 +154,19 @@ async def wait_for_chat(token: str) -> ChatInfo:
offset=None, timeout_s=0, allowed_updates=allowed_updates
)
if drained:
offset = drained[-1]["update_id"] + 1
offset = drained[-1].update_id + 1
while True:
try:
updates = await bot.get_updates(
offset=offset, timeout_s=50, allowed_updates=allowed_updates
)
except TelegramRetryAfter as exc:
await anyio.sleep(exc.retry_after)
continue
updates = await bot.get_updates(
offset=offset, timeout_s=50, allowed_updates=allowed_updates
)
if updates is None:
await anyio.sleep(1)
continue
if not updates:
continue
offset = updates[-1]["update_id"] + 1
offset = updates[-1].update_id + 1
update = updates[-1]
msg = update.get("message")
msg = update.message
if not isinstance(msg, dict):
continue
sender = msg.get("from")
@@ -298,7 +295,7 @@ def _confirm(message: str, *, default: bool = True) -> bool | None:
return question.ask()
def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None:
def _prompt_token(console: Console) -> tuple[str, User] | None:
while True:
token = questionary.password("paste your bot token:").ask()
if token is None:
@@ -310,11 +307,10 @@ def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None:
console.print(" validating...")
info = anyio.run(get_bot_info, token)
if info:
username = info.get("username")
if isinstance(username, str) and username:
console.print(f" connected to @{username}")
if info.username:
console.print(f" connected to @{info.username}")
else:
name = info.get("first_name") or "your bot"
name = info.first_name or "your bot"
console.print(f" connected to {name}")
return token, info
console.print(" failed to connect, check the token and try again")
@@ -342,7 +338,7 @@ def capture_chat_id(*, token: str | None = None) -> ChatInfo | None:
return None
token, info = token_info
bot_ref = f"@{info['username']}"
bot_ref = f"@{info.username}"
console.print("")
console.print(f" send /start to {bot_ref} (works in groups too)")
console.print(" waiting...")
@@ -401,7 +397,7 @@ def interactive_setup(*, force: bool) -> bool:
if token_info is None:
return False
token, info = token_info
bot_ref = f"@{info['username']}"
bot_ref = f"@{info.username}"
console.print("")
console.print(f" send /start to {bot_ref} (works in groups too)")
+79 -102
View File
@@ -4,9 +4,9 @@ import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast
import anyio
import msgspec
from ..context import RunContext
from ..logging import get_logger
@@ -27,6 +27,26 @@ class TopicThreadSnapshot:
topic_title: str | None
class _ContextState(msgspec.Struct, forbid_unknown_fields=False):
project: str | None = None
branch: str | None = None
class _SessionState(msgspec.Struct, forbid_unknown_fields=False):
resume: str
class _ThreadState(msgspec.Struct, forbid_unknown_fields=False):
context: _ContextState | None = None
sessions: dict[str, _SessionState] = msgspec.field(default_factory=dict)
topic_title: str | None = None
class _TopicState(msgspec.Struct, forbid_unknown_fields=False):
version: int
threads: dict[str, _ThreadState] = msgspec.field(default_factory=dict)
def resolve_state_path(config_path: Path) -> Path:
return config_path.with_name(STATE_FILENAME)
@@ -35,34 +55,31 @@ def _thread_key(chat_id: int, thread_id: int) -> str:
return f"{chat_id}:{thread_id}"
def _parse_context(raw: object) -> RunContext | None:
if not isinstance(raw, dict):
def _normalize_text(value: str | None) -> str | None:
if value is None:
return None
payload = cast(dict[str, object], raw)
project = payload.get("project")
branch = payload.get("branch")
if project is not None and not isinstance(project, str):
project = None
if isinstance(project, str):
project = project.strip() or None
if branch is not None and not isinstance(branch, str):
branch = None
if isinstance(branch, str):
branch = branch.strip() or None
value = value.strip()
return value or None
def _context_from_state(state: _ContextState | None) -> RunContext | None:
if state is None:
return None
project = _normalize_text(state.project)
branch = _normalize_text(state.branch)
if project is None and branch is None:
return None
return RunContext(project=project, branch=branch)
def _dump_context(context: RunContext | None) -> dict[str, str] | None:
if context is None or (context.project is None and context.branch is None):
def _context_to_state(context: RunContext | None) -> _ContextState | None:
if context is None:
return None
payload: dict[str, str] = {}
if context.project is not None:
payload["project"] = context.project
if context.branch is not None:
payload["branch"] = context.branch
return payload or None
project = _normalize_text(context.project)
branch = _normalize_text(context.branch)
if project is None and branch is None:
return None
return _ContextState(project=project, branch=branch)
class TopicStateStore:
@@ -71,10 +88,7 @@ class TopicStateStore:
self._lock = anyio.Lock()
self._loaded = False
self._mtime_ns: int | None = None
self._data: dict[str, Any] = {
"version": STATE_VERSION,
"threads": {},
}
self._state = _TopicState(version=STATE_VERSION, threads={})
async def get_thread(
self, chat_id: int, thread_id: int
@@ -92,7 +106,7 @@ class TopicStateStore:
thread = self._get_thread_locked(chat_id, thread_id)
if thread is None:
return None
return _parse_context(thread.get("context"))
return _context_from_state(thread.context)
async def set_context(
self,
@@ -105,9 +119,9 @@ class TopicStateStore:
async with self._lock:
self._reload_locked_if_needed()
thread = self._ensure_thread_locked(chat_id, thread_id)
thread["context"] = _dump_context(context)
thread.context = _context_to_state(context)
if topic_title is not None:
thread["topic_title"] = topic_title
thread.topic_title = topic_title
self._save_locked()
async def clear_context(self, chat_id: int, thread_id: int) -> None:
@@ -116,7 +130,7 @@ class TopicStateStore:
thread = self._get_thread_locked(chat_id, thread_id)
if thread is None:
return
thread.pop("context", None)
thread.context = None
self._save_locked()
async def get_session_resume(
@@ -127,16 +141,10 @@ class TopicStateStore:
thread = self._get_thread_locked(chat_id, thread_id)
if thread is None:
return None
sessions = thread.get("sessions")
if not isinstance(sessions, dict):
entry = thread.sessions.get(engine)
if entry is None or not entry.resume:
return None
entry = sessions.get(engine)
if not isinstance(entry, dict):
return None
value = entry.get("resume")
if not isinstance(value, str) or not value:
return None
return ResumeToken(engine=engine, value=value)
return ResumeToken(engine=engine, value=entry.resume)
async def set_session_resume(
self, chat_id: int, thread_id: int, token: ResumeToken
@@ -144,13 +152,7 @@ class TopicStateStore:
async with self._lock:
self._reload_locked_if_needed()
thread = self._ensure_thread_locked(chat_id, thread_id)
sessions = thread.get("sessions")
if not isinstance(sessions, dict):
sessions = {}
thread["sessions"] = sessions
sessions[token.engine] = {
"resume": token.value,
}
thread.sessions[token.engine] = _SessionState(resume=token.value)
self._save_locked()
async def clear_sessions(self, chat_id: int, thread_id: int) -> None:
@@ -159,7 +161,7 @@ class TopicStateStore:
thread = self._get_thread_locked(chat_id, thread_id)
if thread is None:
return
thread.pop("sessions", None)
thread.sessions = {}
self._save_locked()
async def find_thread_for_context(
@@ -167,47 +169,37 @@ class TopicStateStore:
) -> int | None:
async with self._lock:
self._reload_locked_if_needed()
threads = self._data.get("threads")
if not isinstance(threads, dict):
return None
for raw_key, payload in threads.items():
if not isinstance(raw_key, str) or not isinstance(payload, dict):
target_project = _normalize_text(context.project)
target_branch = _normalize_text(context.branch)
for raw_key, thread in self._state.threads.items():
if not raw_key.startswith(f"{chat_id}:"):
continue
parsed = _parse_context(payload.get("context"))
parsed = _context_from_state(thread.context)
if parsed is None:
continue
if parsed.project != context.project or parsed.branch != context.branch:
continue
if not raw_key.startswith(f"{chat_id}:"):
if parsed.project != target_project or parsed.branch != target_branch:
continue
try:
_, thread_str = raw_key.split(":", 1)
return int(thread_str)
except (ValueError, TypeError):
except ValueError:
continue
return None
def _snapshot_locked(
self, thread: dict[str, Any], chat_id: int, thread_id: int
self, thread: _ThreadState, chat_id: int, thread_id: int
) -> TopicThreadSnapshot:
sessions: dict[str, str] = {}
raw_sessions = thread.get("sessions")
if isinstance(raw_sessions, dict):
for engine, entry in raw_sessions.items():
if not isinstance(engine, str) or not isinstance(entry, dict):
continue
value = entry.get("resume")
if isinstance(value, str) and value:
sessions[engine] = value
topic_title = thread.get("topic_title")
if not isinstance(topic_title, str):
topic_title = None
sessions = {
engine: entry.resume
for engine, entry in thread.sessions.items()
if entry.resume
}
return TopicThreadSnapshot(
chat_id=chat_id,
thread_id=thread_id,
context=_parse_context(thread.get("context")),
context=_context_from_state(thread.context),
sessions=sessions,
topic_title=topic_title,
topic_title=thread.topic_title,
)
def _stat_mtime_ns(self) -> int | None:
@@ -226,10 +218,10 @@ class TopicStateStore:
self._loaded = True
self._mtime_ns = self._stat_mtime_ns()
if self._mtime_ns is None:
self._data = {"version": STATE_VERSION, "threads": {}}
self._state = _TopicState(version=STATE_VERSION, threads={})
return
try:
payload = json.loads(self._path.read_text(encoding="utf-8"))
payload = msgspec.json.decode(self._path.read_bytes(), type=_TopicState)
except Exception as exc:
logger.warning(
"telegram.topic_state.load_failed",
@@ -237,29 +229,22 @@ class TopicStateStore:
error=str(exc),
error_type=exc.__class__.__name__,
)
self._data = {"version": STATE_VERSION, "threads": {}}
self._state = _TopicState(version=STATE_VERSION, threads={})
return
if not isinstance(payload, dict):
self._data = {"version": STATE_VERSION, "threads": {}}
return
version = payload.get("version")
if version != STATE_VERSION:
if payload.version != STATE_VERSION:
logger.warning(
"telegram.topic_state.version_mismatch",
path=str(self._path),
version=version,
version=payload.version,
expected=STATE_VERSION,
)
self._data = {"version": STATE_VERSION, "threads": {}}
self._state = _TopicState(version=STATE_VERSION, threads={})
return
threads = payload.get("threads")
if not isinstance(threads, dict):
threads = {}
self._data = {"version": STATE_VERSION, "threads": threads}
self._state = payload
def _save_locked(self) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True)
payload = {"version": STATE_VERSION, "threads": self._data.get("threads", {})}
payload = msgspec.to_builtins(self._state)
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2, sort_keys=True)
@@ -267,22 +252,14 @@ class TopicStateStore:
os.replace(tmp_path, self._path)
self._mtime_ns = self._stat_mtime_ns()
def _get_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any] | None:
threads = self._data.get("threads")
if not isinstance(threads, dict):
return None
entry = threads.get(_thread_key(chat_id, thread_id))
return entry if isinstance(entry, dict) else None
def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None:
return self._state.threads.get(_thread_key(chat_id, thread_id))
def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any]:
threads = self._data.get("threads")
if not isinstance(threads, dict):
threads = {}
self._data["threads"] = threads
def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState:
key = _thread_key(chat_id, thread_id)
entry = threads.get(key)
if isinstance(entry, dict):
entry = self._state.threads.get(key)
if entry is not None:
return entry
entry = {}
threads[key] = entry
entry = _ThreadState()
self._state.threads[key] = entry
return entry
+9 -12
View File
@@ -185,9 +185,9 @@ async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None:
if not cfg.topics.enabled:
return
me = await cfg.bot.get_me()
bot_id = me.get("id") if isinstance(me, dict) else None
if not isinstance(bot_id, int):
if me is None:
raise ConfigError("failed to fetch bot id for topics validation.")
bot_id = me.id
scope, chat_ids = _resolve_topics_scope(cfg)
if scope == "projects" and not chat_ids:
raise ConfigError(
@@ -197,37 +197,34 @@ async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None:
for chat_id in chat_ids:
chat = await cfg.bot.get_chat(chat_id)
if not isinstance(chat, dict):
if chat is None:
raise ConfigError(
f"failed to fetch chat info for topics validation ({chat_id})."
)
chat_type = chat.get("type")
is_forum = chat.get("is_forum")
if chat_type != "supergroup":
if chat.type != "supergroup":
raise ConfigError(
"topics enabled but chat is not a supergroup "
f"(chat_id={chat_id}); convert the group and enable topics."
)
if is_forum is not True:
if chat.is_forum is not True:
raise ConfigError(
"topics enabled but chat does not have topics enabled "
f"(chat_id={chat_id}); turn on topics in group settings."
)
member = await cfg.bot.get_chat_member(chat_id, bot_id)
if not isinstance(member, dict):
if member is None:
raise ConfigError(
"failed to fetch bot permissions "
f"(chat_id={chat_id}); promote the bot to admin with manage topics."
)
status = member.get("status")
if status == "creator":
if member.status == "creator":
continue
if status != "administrator":
if member.status != "administrator":
raise ConfigError(
"topics enabled but bot is not an admin "
f"(chat_id={chat_id}); promote it and grant manage topics."
)
if member.get("can_manage_topics") is not True:
if member.can_manage_topics is not True:
raise ConfigError(
"topics enabled but bot lacks manage topics permission "
f"(chat_id={chat_id}); grant can_manage_topics."
+19 -4
View File
@@ -2,7 +2,6 @@ from __future__ import annotations
import io
from collections.abc import Awaitable, Callable
from typing import cast
from ..logging import get_logger
from openai import AsyncOpenAI, OpenAIError
@@ -29,6 +28,7 @@ async def transcribe_voice(
bot: BotClient,
msg: TelegramIncomingMessage,
enabled: bool,
max_bytes: int | None = None,
reply: Callable[..., Awaitable[None]],
) -> str | None:
voice = msg.voice
@@ -37,9 +37,24 @@ async def transcribe_voice(
if not enabled:
await reply(text=VOICE_TRANSCRIPTION_DISABLED_HINT)
return None
file_info = cast(dict[str, object], await bot.get_file(voice.file_id))
file_path = cast(str, file_info["file_path"])
audio_bytes = cast(bytes, await bot.download_file(file_path))
if (
max_bytes is not None
and voice.file_size is not None
and voice.file_size > max_bytes
):
await reply(text="voice message is too large to transcribe.")
return None
file_info = await bot.get_file(voice.file_id)
if file_info is None:
await reply(text="failed to fetch voice file.")
return None
audio_bytes = await bot.download_file(file_info.file_path)
if audio_bytes is None:
await reply(text="failed to download voice file.")
return None
if max_bytes is not None and len(audio_bytes) > max_bytes:
await reply(text="voice message is too large to transcribe.")
return None
audio_file = io.BytesIO(audio_bytes)
audio_file.name = "voice.ogg"
async with AsyncOpenAI(timeout=120) as client:
+6 -3
View File
@@ -10,7 +10,7 @@ from .context import RunContext
from .directives import format_context_line, parse_context_line, parse_directives
from .model import EngineId, ResumeToken
from .plugins import normalize_allowlist
from .router import AutoRouter
from .router import AutoRouter, EngineStatus
from .runner import Runner
from .worktrees import WorktreeError, resolve_run_cwd
@@ -108,11 +108,14 @@ class TransportRuntime:
def available_engine_ids(self) -> tuple[EngineId, ...]:
return tuple(entry.engine for entry in self._router.available_entries)
def missing_engine_ids(self) -> tuple[EngineId, ...]:
def engine_ids_with_status(self, status: EngineStatus) -> tuple[EngineId, ...]:
return tuple(
entry.engine for entry in self._router.entries if not entry.available
entry.engine for entry in self._router.entries if entry.status == status
)
def missing_engine_ids(self) -> tuple[EngineId, ...]:
return self.engine_ids_with_status("missing_cli")
def project_aliases(self) -> tuple[str, ...]:
return tuple(project.alias for project in self._projects.projects.values())
+2 -2
View File
@@ -41,13 +41,13 @@ class TransportBackend(Protocol):
def interactive_setup(self, *, force: bool) -> bool: ...
def lock_token(
self, *, transport_config: dict[str, object], config_path: Path
self, *, transport_config: object, config_path: Path
) -> str | None: ...
def build_and_run(
self,
*,
transport_config: dict[str, object],
transport_config: object,
config_path: Path,
runtime: TransportRuntime,
final_notify: bool,
+10 -7
View File
@@ -1,8 +1,9 @@
from __future__ import annotations
from takopi.backends import EngineBackend
from takopi.config import dump_toml
from takopi.telegram import onboarding
from takopi.backends import EngineBackend
from takopi.telegram.api_models import User
def test_mask_token_short() -> None:
@@ -91,7 +92,7 @@ def test_interactive_setup_writes_config(monkeypatch, tmp_path) -> None:
def _fake_run(func, *args, **kwargs):
if func is onboarding.get_bot_info:
return {"username": "my_bot"}
return User(id=1, username="my_bot")
if func is onboarding.wait_for_chat:
return onboarding.ChatInfo(
chat_id=123,
@@ -136,7 +137,7 @@ def test_interactive_setup_preserves_projects(monkeypatch, tmp_path) -> None:
def _fake_run(func, *args, **kwargs):
if func is onboarding.get_bot_info:
return {"username": "my_bot"}
return User(id=1, username="my_bot")
if func is onboarding.wait_for_chat:
return onboarding.ChatInfo(
chat_id=123,
@@ -173,7 +174,7 @@ def test_interactive_setup_no_agents_aborts(monkeypatch, tmp_path) -> None:
def _fake_run(func, *args, **kwargs):
if func is onboarding.get_bot_info:
return {"username": "my_bot"}
return User(id=1, username="my_bot")
if func is onboarding.wait_for_chat:
return onboarding.ChatInfo(
chat_id=123,
@@ -211,7 +212,7 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) -
def _fake_run(func, *args, **kwargs):
if func is onboarding.get_bot_info:
return {"username": "my_bot"}
return User(id=1, username="my_bot")
if func is onboarding.wait_for_chat:
return onboarding.ChatInfo(
chat_id=123,
@@ -239,7 +240,7 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) -
def test_capture_chat_id_with_token(monkeypatch) -> None:
def _fake_run(func, *args, **kwargs):
if func is onboarding.get_bot_info:
return {"username": "my_bot"}
return User(id=1, username="my_bot")
if func is onboarding.wait_for_chat:
return onboarding.ChatInfo(
chat_id=456,
@@ -261,7 +262,9 @@ def test_capture_chat_id_with_token(monkeypatch) -> None:
def test_capture_chat_id_prompts_for_token(monkeypatch) -> None:
monkeypatch.setattr(
onboarding, "_prompt_token", lambda _console: ("token", {"username": "bot"})
onboarding,
"_prompt_token",
lambda _console: ("token", User(id=1, username="bot")),
)
def _fake_run(func, *args, **kwargs):
+64 -14
View File
@@ -9,6 +9,11 @@ from takopi.config import ProjectsConfig
from takopi.model import EngineId
from takopi.router import AutoRouter, RunnerEntry
from takopi.runners.mock import Return, ScriptRunner
from takopi.settings import (
TelegramFilesSettings,
TelegramTopicsSettings,
TelegramTransportSettings,
)
from takopi.telegram import backend as telegram_backend
from takopi.transport_runtime import TransportRuntime
@@ -20,8 +25,13 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
missing = ScriptRunner([Return(answer="ok")], engine=pi)
router = AutoRouter(
entries=[
RunnerEntry(engine=codex, runner=runner, available=True),
RunnerEntry(engine=pi, runner=missing, available=False, issue="missing"),
RunnerEntry(engine=codex, runner=runner),
RunnerEntry(
engine=pi,
runner=missing,
status="missing_cli",
issue="missing",
),
],
default_engine=codex,
)
@@ -40,6 +50,44 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
assert "projects: `none`" in message
def test_build_startup_message_surfaces_unavailable_engine_reasons(
tmp_path: Path,
) -> None:
codex = EngineId("codex")
pi = EngineId("pi")
claude = EngineId("claude")
runner = ScriptRunner([Return(answer="ok")], engine=codex)
bad_cfg = ScriptRunner([Return(answer="ok")], engine=pi)
load_err = ScriptRunner([Return(answer="ok")], engine=claude)
router = AutoRouter(
entries=[
RunnerEntry(engine=codex, runner=runner),
RunnerEntry(engine=pi, runner=bad_cfg, status="bad_config", issue="bad"),
RunnerEntry(
engine=claude,
runner=load_err,
status="load_error",
issue="failed",
),
],
default_engine=codex,
)
runtime = TransportRuntime(
router=router,
projects=ProjectsConfig(projects={}, default_project=None),
watch_config=True,
)
message = telegram_backend._build_startup_message(
runtime, startup_pwd=str(tmp_path)
)
assert "agents: `codex" in message
assert "misconfigured: pi" in message
assert "failed to load: claude" in message
def test_telegram_backend_build_and_run_wires_config(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
@@ -55,7 +103,7 @@ def test_telegram_backend_build_and_run_wires_config(
codex = EngineId("codex")
runner = ScriptRunner([Return(answer="ok")], engine=codex)
router = AutoRouter(
entries=[RunnerEntry(engine=codex, runner=runner, available=True)],
entries=[RunnerEntry(engine=codex, runner=runner)],
default_engine=codex,
)
runtime = TransportRuntime(
@@ -80,13 +128,14 @@ def test_telegram_backend_build_and_run_wires_config(
monkeypatch.setattr(telegram_backend, "run_main_loop", fake_run_main_loop)
monkeypatch.setattr(telegram_backend, "TelegramClient", _FakeClient)
transport_config = {
"bot_token": "token",
"chat_id": 321,
"voice_transcription": True,
"files": {"enabled": True, "allowed_user_ids": [1, 2]},
"topics": {"enabled": True, "scope": "main"},
}
transport_config = TelegramTransportSettings(
bot_token="token",
chat_id=321,
voice_transcription=True,
voice_max_bytes=1234,
files=TelegramFilesSettings(enabled=True, allowed_user_ids=[1, 2]),
topics=TelegramTopicsSettings(enabled=True, scope="main"),
)
telegram_backend.TelegramBackend().build_and_run(
transport_config=transport_config,
@@ -100,18 +149,19 @@ def test_telegram_backend_build_and_run_wires_config(
kwargs = captured["kwargs"]
assert cfg.chat_id == 321
assert cfg.voice_transcription is True
assert cfg.voice_max_bytes == 1234
assert cfg.files.enabled is True
assert cfg.files.allowed_user_ids == frozenset({1, 2})
assert cfg.files.allowed_user_ids == [1, 2]
assert cfg.topics.enabled is True
assert cfg.bot.token == "token"
assert kwargs["watch_config"] is True
assert kwargs["transport_id"] == "telegram"
def test_build_files_config_defaults() -> None:
cfg = telegram_backend._build_files_config({})
def test_telegram_files_settings_defaults() -> None:
cfg = TelegramFilesSettings()
assert cfg.enabled is False
assert cfg.auto_put is True
assert cfg.uploads_dir == "incoming"
assert cfg.allowed_user_ids == frozenset()
assert cfg.allowed_user_ids == []
+44 -40
View File
@@ -6,22 +6,30 @@ import anyio
import pytest
from takopi import commands, plugins
import takopi.telegram.bridge as bridge
import takopi.telegram.loop as telegram_loop
import takopi.telegram.commands as telegram_commands
import takopi.telegram.topics as telegram_topics
from takopi.directives import parse_directives
from takopi.telegram.api_models import (
Chat,
ChatMember,
File,
ForumTopic,
Message,
Update,
User,
)
from takopi.settings import TelegramFilesSettings, TelegramTopicsSettings
from takopi.telegram.bridge import (
TelegramBridgeConfig,
TelegramFilesConfig,
TelegramPresenter,
TelegramTransport,
build_bot_commands,
handle_callback_cancel,
handle_cancel,
is_cancel_command,
send_with_resume,
run_main_loop,
send_with_resume,
)
from takopi.telegram.client import BotClient
from takopi.telegram.topic_state import TopicStateStore, resolve_state_path
@@ -122,13 +130,13 @@ class _FakeBot(BotClient):
offset: int | None,
timeout_s: int = 50,
allowed_updates: list[str] | None = None,
) -> list[dict[str, Any]] | None:
) -> list[Update] | None:
_ = offset
_ = timeout_s
_ = allowed_updates
return []
async def get_file(self, file_id: str) -> dict[str, Any] | None:
async def get_file(self, file_id: str) -> File | None:
_ = file_id
return None
@@ -148,7 +156,7 @@ class _FakeBot(BotClient):
reply_markup: dict | None = None,
*,
replace_message_id: int | None = None,
) -> dict[str, Any]:
) -> Message:
self.send_calls.append(
{
"chat_id": chat_id,
@@ -162,7 +170,7 @@ class _FakeBot(BotClient):
"replace_message_id": replace_message_id,
}
)
return {"message_id": 1}
return Message(message_id=1)
async def send_document(
self,
@@ -173,7 +181,7 @@ class _FakeBot(BotClient):
message_thread_id: int | None = None,
disable_notification: bool | None = False,
caption: str | None = None,
) -> dict[str, Any]:
) -> Message:
self.document_calls.append(
{
"chat_id": chat_id,
@@ -185,7 +193,7 @@ class _FakeBot(BotClient):
"caption": caption,
}
)
return {"message_id": 2}
return Message(message_id=2)
async def edit_message_text(
self,
@@ -197,7 +205,7 @@ class _FakeBot(BotClient):
reply_markup: dict | None = None,
*,
wait: bool = True,
) -> dict[str, Any]:
) -> Message:
self.edit_calls.append(
{
"chat_id": chat_id,
@@ -209,7 +217,7 @@ class _FakeBot(BotClient):
"wait": wait,
}
)
return {"message_id": message_id}
return Message(message_id=message_id)
async def delete_message(self, chat_id: int, message_id: int) -> bool:
self.delete_calls.append({"chat_id": chat_id, "message_id": message_id})
@@ -231,26 +239,22 @@ class _FakeBot(BotClient):
)
return True
async def get_me(self) -> dict[str, Any] | None:
return {"id": 1}
async def get_me(self) -> User | None:
return User(id=1, username="bot")
async def get_chat(self, chat_id: int) -> dict[str, Any] | None:
async def get_chat(self, chat_id: int) -> Chat | None:
_ = chat_id
return {"id": chat_id, "type": "supergroup", "is_forum": True}
return Chat(id=chat_id, type="supergroup", is_forum=True)
async def get_chat_member(
self, chat_id: int, user_id: int
) -> dict[str, Any] | None:
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
_ = chat_id
_ = user_id
return {"status": "administrator", "can_manage_topics": True}
return ChatMember(status="administrator", can_manage_topics=True)
async def create_forum_topic(
self, chat_id: int, name: str
) -> dict[str, Any] | None:
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
_ = chat_id
_ = name
return {"message_thread_id": 1}
return ForumTopic(message_thread_id=1)
async def edit_forum_topic(
self, chat_id: int, message_thread_id: int, name: str
@@ -539,10 +543,10 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
offset: int | None,
timeout_s: int = 50,
allowed_updates: list[str] | None = None,
) -> list[dict[str, Any]] | None:
) -> list[Update] | None:
return None
async def get_file(self, file_id: str) -> dict[str, Any] | None:
async def get_file(self, file_id: str) -> File | None:
_ = file_id
return None
@@ -562,7 +566,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
reply_markup: dict | None = None,
*,
replace_message_id: int | None = None,
) -> dict | None:
) -> Message | None:
_ = reply_markup
return None
@@ -575,7 +579,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
message_thread_id: int | None = None,
disable_notification: bool | None = False,
caption: str | None = None,
) -> dict | None:
) -> Message | None:
_ = (
chat_id,
filename,
@@ -597,7 +601,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
reply_markup: dict | None = None,
*,
wait: bool = True,
) -> dict | None:
) -> Message | None:
self.edit_calls.append(
{
"chat_id": chat_id,
@@ -611,7 +615,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
)
if not wait:
return None
return {"message_id": message_id}
return Message(message_id=message_id)
async def delete_message(
self,
@@ -629,7 +633,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
) -> bool:
return False
async def get_me(self) -> dict[str, Any] | None:
async def get_me(self) -> User | None:
return None
async def close(self) -> None:
@@ -778,9 +782,9 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
payload = b"hello"
class _FileBot(_FakeBot):
async def get_file(self, file_id: str) -> dict[str, Any] | None:
async def get_file(self, file_id: str) -> File | None:
_ = file_id
return {"file_path": "files/hello.txt"}
return File(file_path="files/hello.txt")
async def download_file(self, file_path: str) -> bytes | None:
_ = file_path
@@ -811,7 +815,7 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
chat_id=123,
startup_msg="",
exec_cfg=exec_cfg,
files=TelegramFilesConfig(enabled=True),
files=TelegramFilesSettings(enabled=True),
)
msg = TelegramIncomingMessage(
transport="telegram",
@@ -876,9 +880,9 @@ async def test_handle_file_get_sends_document_for_allowed_user(
chat_id=123,
startup_msg="",
exec_cfg=exec_cfg,
files=TelegramFilesConfig(
files=TelegramFilesSettings(
enabled=True,
allowed_user_ids=frozenset({42}),
allowed_user_ids=[42],
),
)
msg = TelegramIncomingMessage(
@@ -1006,7 +1010,7 @@ def test_topic_title_projects_scope_includes_project() -> None:
transport = _FakeTransport()
cfg = replace(
_make_cfg(transport),
topics=bridge.TelegramTopicsConfig(
topics=TelegramTopicsSettings(
enabled=True,
scope="projects",
),
@@ -1277,7 +1281,7 @@ async def test_run_main_loop_persists_topic_sessions_in_project_scope(
chat_id=123,
startup_msg="",
exec_cfg=exec_cfg,
topics=bridge.TelegramTopicsConfig(
topics=TelegramTopicsSettings(
enabled=True,
scope="projects",
),
@@ -1363,11 +1367,11 @@ async def test_run_main_loop_batches_media_group_upload(
}
class _MediaBot(_FakeBot):
async def get_file(self, file_id: str) -> dict[str, Any] | None:
async def get_file(self, file_id: str) -> File | None:
file_path = file_map.get(file_id)
if file_path is None:
return None
return {"file_path": file_path}
return File(file_path=file_path)
async def download_file(self, file_path: str) -> bytes | None:
return payloads.get(file_path)
@@ -1397,7 +1401,7 @@ async def test_run_main_loop_batches_media_group_upload(
chat_id=123,
startup_msg="",
exec_cfg=exec_cfg,
files=TelegramFilesConfig(enabled=True, auto_put=True),
files=TelegramFilesSettings(enabled=True, auto_put=True),
)
msg1 = TelegramIncomingMessage(
transport="telegram",
+172
View File
@@ -57,3 +57,175 @@ async def test_no_token_in_logs_on_http_error(
out = capsys.readouterr().out
assert token not in out
assert "bot[REDACTED]" in out
@pytest.mark.anyio
async def test_telegram_429_no_retry_post_form() -> None:
calls: list[int] = []
def handler(request: httpx.Request) -> httpx.Response:
calls.append(1)
return httpx.Response(
429,
json={
"ok": False,
"description": "retry",
"parameters": {"retry_after": 2},
},
request=request,
)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
with pytest.raises(TelegramRetryAfter) as exc:
await tg._post_form(
"sendDocument",
{"chat_id": 1},
files={"document": ("note.txt", b"hi")},
)
finally:
await client.aclose()
assert exc.value.retry_after == 2
assert len(calls) == 1
@pytest.mark.anyio
async def test_telegram_429_defaults_retry_after_on_bad_body() -> None:
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(429, text="nope", request=request)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
with pytest.raises(TelegramRetryAfter) as exc:
await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
finally:
await client.aclose()
assert exc.value.retry_after == 5.0
@pytest.mark.anyio
async def test_telegram_ok_false_returns_none() -> None:
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(
200,
json={"ok": False, "error_code": 400, "description": "bad"},
request=request,
)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
result = await tg._post("getUpdates", {"timeout": 1})
finally:
await client.aclose()
assert result is None
@pytest.mark.anyio
async def test_telegram_invalid_payload_returns_none() -> None:
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json=["not", "a", "dict"], request=request)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
result = await tg._post("getUpdates", {"timeout": 1})
finally:
await client.aclose()
assert result is None
@pytest.mark.anyio
async def test_telegram_decode_failure_returns_none() -> None:
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(
200,
json={"ok": True, "result": {"username": "bot-only"}},
request=request,
)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
result = await tg.get_me()
finally:
await client.aclose()
assert result is None
@pytest.mark.anyio
async def test_telegram_download_file_retries_on_429() -> None:
calls: list[int] = []
sleeps: list[float] = []
async def sleep(delay: float) -> None:
sleeps.append(delay)
def handler(request: httpx.Request) -> httpx.Response:
calls.append(1)
if len(calls) == 1:
return httpx.Response(
429,
json={"ok": False, "parameters": {"retry_after": 3}},
request=request,
)
return httpx.Response(200, content=b"ok", request=request)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client, sleep=sleep)
payload = await tg.download_file("path/to/file")
finally:
await client.aclose()
assert payload == b"ok"
assert sleeps == [3.0]
assert len(calls) == 2
@pytest.mark.anyio
async def test_telegram_download_file_429_defaults_retry_after_on_bad_body() -> None:
sleeps: list[float] = []
async def sleep(delay: float) -> None:
sleeps.append(delay)
calls: list[int] = []
def handler(request: httpx.Request) -> httpx.Response:
calls.append(1)
if len(calls) == 1:
return httpx.Response(429, text="nope", request=request)
return httpx.Response(200, content=b"ok", request=request)
transport = httpx.MockTransport(handler)
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client, sleep=sleep)
payload = await tg.download_file("path")
finally:
await client.aclose()
assert payload == b"ok"
assert sleeps == [5.0]
assert len(calls) == 2
+14 -12
View File
@@ -3,6 +3,7 @@ from typing import Any
import anyio
import pytest
from takopi.telegram.api_models import File, Message, Update, User
from takopi.telegram.client import BotClient, TelegramClient, TelegramRetryAfter
@@ -29,7 +30,7 @@ class _FakeBot(BotClient):
reply_markup: dict | None = None,
*,
replace_message_id: int | None = None,
) -> dict[str, Any]:
) -> Message | None:
_ = reply_to_message_id
_ = disable_notification
_ = message_thread_id
@@ -38,7 +39,7 @@ class _FakeBot(BotClient):
_ = reply_markup
_ = replace_message_id
self.calls.append("send_message")
return {"message_id": 1}
return Message(message_id=1)
async def send_document(
self,
@@ -49,7 +50,7 @@ class _FakeBot(BotClient):
message_thread_id: int | None = None,
disable_notification: bool | None = False,
caption: str | None = None,
) -> dict[str, Any]:
) -> Message | None:
_ = (
chat_id,
filename,
@@ -60,7 +61,7 @@ class _FakeBot(BotClient):
caption,
)
self.calls.append("send_document")
return {"message_id": 1}
return Message(message_id=1)
async def edit_message_text(
self,
@@ -72,7 +73,7 @@ class _FakeBot(BotClient):
reply_markup: dict | None = None,
*,
wait: bool = True,
) -> dict[str, Any]:
) -> Message | None:
_ = chat_id
_ = message_id
_ = entities
@@ -85,7 +86,7 @@ class _FakeBot(BotClient):
self._edit_attempts += 1
raise TelegramRetryAfter(self.retry_after)
self._edit_attempts += 1
return {"message_id": message_id}
return Message(message_id=message_id)
async def delete_message(
self,
@@ -113,7 +114,7 @@ class _FakeBot(BotClient):
offset: int | None,
timeout_s: int = 50,
allowed_updates: list[str] | None = None,
) -> list[dict[str, Any]] | None:
) -> list[Update] | None:
_ = offset
_ = timeout_s
_ = allowed_updates
@@ -123,7 +124,7 @@ class _FakeBot(BotClient):
self._updates_attempts += 1
return []
async def get_file(self, file_id: str) -> dict[str, Any] | None:
async def get_file(self, file_id: str) -> File | None:
_ = file_id
return None
@@ -134,8 +135,8 @@ class _FakeBot(BotClient):
async def close(self) -> None:
return None
async def get_me(self) -> dict[str, Any] | None:
return {"id": 1}
async def get_me(self) -> User | None:
return User(id=1)
async def answer_callback_query(
self,
@@ -187,7 +188,7 @@ async def test_edits_coalesce_latest() -> None:
reply_markup: dict | None = None,
*,
wait: bool = True,
) -> dict:
) -> Message | None:
if self._block_first:
self._block_first = False
self.edit_started.set()
@@ -305,7 +306,8 @@ async def test_retry_after_retries_once() -> None:
text="retry",
)
assert result == {"message_id": 1}
assert result is not None
assert result.message_id == 1
assert bot._edit_attempts == 2
+243
View File
@@ -0,0 +1,243 @@
from __future__ import annotations
import pytest
from takopi.telegram.api_models import (
Chat,
ChatMember,
File,
ForumTopic,
Message,
Update,
User,
)
from takopi.telegram.client import BotClient
from takopi.telegram.types import TelegramIncomingMessage, TelegramVoice
from takopi.telegram.voice import transcribe_voice
class _Bot(BotClient):
def __init__(self, *, file_info: File | None, audio: bytes | None) -> None:
self._file_info = file_info
self._audio = audio
async def close(self) -> None:
return None
async def get_updates(
self,
offset: int | None,
timeout_s: int = 50,
allowed_updates: list[str] | None = None,
) -> list[Update] | None:
_ = offset, timeout_s, allowed_updates
return []
async def get_file(self, file_id: str) -> File | None:
_ = file_id
return self._file_info
async def download_file(self, file_path: str) -> bytes | None:
_ = file_path
return self._audio
async def send_message(
self,
chat_id: int,
text: str,
reply_to_message_id: int | None = None,
disable_notification: bool | None = False,
message_thread_id: int | None = None,
entities: list[dict] | None = None,
parse_mode: str | None = None,
reply_markup: dict | None = None,
*,
replace_message_id: int | None = None,
) -> Message | None:
_ = (
chat_id,
text,
reply_to_message_id,
disable_notification,
message_thread_id,
entities,
parse_mode,
reply_markup,
replace_message_id,
)
raise AssertionError("send_message should not be called")
async def send_document(
self,
chat_id: int,
filename: str,
content: bytes,
reply_to_message_id: int | None = None,
message_thread_id: int | None = None,
disable_notification: bool | None = False,
caption: str | None = None,
) -> Message | None:
_ = (
chat_id,
filename,
content,
reply_to_message_id,
message_thread_id,
disable_notification,
caption,
)
raise AssertionError("send_document should not be called")
async def edit_message_text(
self,
chat_id: int,
message_id: int,
text: str,
entities: list[dict] | None = None,
parse_mode: str | None = None,
reply_markup: dict | None = None,
*,
wait: bool = True,
) -> Message | None:
_ = (
chat_id,
message_id,
text,
entities,
parse_mode,
reply_markup,
wait,
)
raise AssertionError("edit_message_text should not be called")
async def delete_message(self, chat_id: int, message_id: int) -> bool:
_ = chat_id, message_id
raise AssertionError("delete_message should not be called")
async def set_my_commands(
self,
commands: list[dict],
*,
scope: dict | None = None,
language_code: str | None = None,
) -> bool:
_ = commands, scope, language_code
raise AssertionError("set_my_commands should not be called")
async def get_me(self) -> User | None:
raise AssertionError("get_me should not be called")
async def answer_callback_query(
self,
callback_query_id: str,
text: str | None = None,
show_alert: bool | None = None,
) -> bool:
_ = callback_query_id, text, show_alert
raise AssertionError("answer_callback_query should not be called")
async def get_chat(self, chat_id: int) -> Chat | None:
_ = chat_id
raise AssertionError("get_chat should not be called")
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
_ = chat_id, user_id
raise AssertionError("get_chat_member should not be called")
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
_ = chat_id, name
raise AssertionError("create_forum_topic should not be called")
async def edit_forum_topic(
self, chat_id: int, message_thread_id: int, name: str
) -> bool:
_ = chat_id, message_thread_id, name
raise AssertionError("edit_forum_topic should not be called")
def _voice_message(*, file_size: int = 123) -> TelegramIncomingMessage:
voice = TelegramVoice(
file_id="voice-id",
mime_type="audio/ogg",
file_size=file_size,
duration=1,
raw={},
)
return TelegramIncomingMessage(
transport="telegram",
chat_id=1,
message_id=1,
text="",
reply_to_message_id=None,
reply_to_text=None,
sender_id=1,
voice=voice,
raw={},
)
@pytest.mark.anyio
async def test_transcribe_voice_handles_missing_file() -> None:
replies: list[str] = []
async def reply(**kwargs) -> None:
replies.append(kwargs["text"])
bot = _Bot(file_info=None, audio=None)
result = await transcribe_voice(
bot=bot,
msg=_voice_message(),
enabled=True,
reply=reply,
)
assert result is None
assert replies[-1] == "failed to fetch voice file."
@pytest.mark.anyio
async def test_transcribe_voice_handles_missing_download() -> None:
replies: list[str] = []
async def reply(**kwargs) -> None:
replies.append(kwargs["text"])
bot = _Bot(file_info=File(file_path="voice.ogg"), audio=None)
result = await transcribe_voice(
bot=bot,
msg=_voice_message(),
enabled=True,
reply=reply,
)
assert result is None
assert replies[-1] == "failed to download voice file."
@pytest.mark.anyio
async def test_transcribe_voice_rejects_large_voice_without_downloading() -> None:
replies: list[str] = []
async def reply(**kwargs) -> None:
replies.append(kwargs["text"])
class _NoFetchBot(_Bot):
async def get_file(self, file_id: str) -> File | None: # type: ignore[override]
_ = file_id
raise AssertionError("get_file should not be called")
async def download_file(self, file_path: str) -> bytes | None: # type: ignore[override]
_ = file_path
raise AssertionError("download_file should not be called")
bot = _NoFetchBot(file_info=None, audio=None)
result = await transcribe_voice(
bot=bot,
msg=_voice_message(file_size=10_000),
enabled=True,
max_bytes=100,
reply=reply,
)
assert result is None
assert replies[-1] == "voice message is too large to transcribe."
+2 -2
View File
@@ -15,14 +15,14 @@ class DummyTransport:
def interactive_setup(self, *, force: bool) -> bool:
raise NotImplementedError
def lock_token(self, *, transport_config: dict[str, object], config_path):
def lock_token(self, *, transport_config: object, config_path):
_ = transport_config, config_path
raise NotImplementedError
def build_and_run(
self,
*,
transport_config: dict[str, object],
transport_config: object,
config_path,
runtime,
final_notify: bool,