refactor(telegram): boundary types (#90)
This commit is contained in:
+4
-5
@@ -105,11 +105,7 @@ def _default_engine_for_setup(
|
|||||||
if settings is None or config_path is None:
|
if settings is None or config_path is None:
|
||||||
return "codex"
|
return "codex"
|
||||||
value = settings.default_engine
|
value = settings.default_engine
|
||||||
if not isinstance(value, str) or not value.strip():
|
return value
|
||||||
raise ConfigError(
|
|
||||||
f"Invalid `default_engine` in {config_path}; expected a non-empty string."
|
|
||||||
)
|
|
||||||
return value.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _config_path_display(path: Path) -> str:
|
def _config_path_display(path: Path) -> str:
|
||||||
@@ -235,6 +231,9 @@ def _run_auto_router(
|
|||||||
default_engine_override=default_engine_override,
|
default_engine_override=default_engine_override,
|
||||||
reserved=("cancel",),
|
reserved=("cancel",),
|
||||||
)
|
)
|
||||||
|
if settings.transport == "telegram":
|
||||||
|
transport_config = settings.transports.telegram
|
||||||
|
else:
|
||||||
transport_config = settings.transport_config(
|
transport_config = settings.transport_config(
|
||||||
settings.transport, config_path=config_path
|
settings.transport, config_path=config_path
|
||||||
)
|
)
|
||||||
|
|||||||
+11
-2
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterable
|
from typing import Iterable, Literal, TypeAlias
|
||||||
|
|
||||||
from .model import EngineId, ResumeToken
|
from .model import EngineId, ResumeToken
|
||||||
from .runner import Runner
|
from .runner import Runner
|
||||||
@@ -17,13 +17,22 @@ class RunnerUnavailableError(RuntimeError):
|
|||||||
self.issue = issue
|
self.issue = issue
|
||||||
|
|
||||||
|
|
||||||
|
EngineStatus: TypeAlias = Literal["ok", "missing_cli", "bad_config", "load_error"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class RunnerEntry:
|
class RunnerEntry:
|
||||||
engine: EngineId
|
engine: EngineId
|
||||||
runner: Runner
|
runner: Runner
|
||||||
available: bool = True
|
status: EngineStatus = "ok"
|
||||||
issue: str | None = None
|
issue: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available(self) -> bool:
|
||||||
|
# "bad_config" means we ignored user config and built the runner with defaults.
|
||||||
|
# The engine is still runnable, but a warning should be surfaced to the user.
|
||||||
|
return self.status in {"ok", "bad_config"}
|
||||||
|
|
||||||
|
|
||||||
class AutoRouter:
|
class AutoRouter:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from .backends import EngineBackend
|
|||||||
from .config import ConfigError, ProjectsConfig
|
from .config import ConfigError, ProjectsConfig
|
||||||
from .engines import get_backend, list_backend_ids
|
from .engines import get_backend, list_backend_ids
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
from .router import AutoRouter, RunnerEntry
|
from .router import AutoRouter, EngineStatus, RunnerEntry
|
||||||
from .settings import TakopiSettings
|
from .settings import TakopiSettings
|
||||||
from .transport_runtime import TransportRuntime
|
from .transport_runtime import TransportRuntime
|
||||||
|
|
||||||
@@ -83,6 +83,7 @@ def build_router(
|
|||||||
for backend in backends:
|
for backend in backends:
|
||||||
engine_id = backend.id
|
engine_id = backend.id
|
||||||
issue: str | None = None
|
issue: str | None = None
|
||||||
|
status: EngineStatus = "ok"
|
||||||
engine_cfg: dict
|
engine_cfg: dict
|
||||||
try:
|
try:
|
||||||
engine_cfg = settings.engine_config(engine_id, config_path=config_path)
|
engine_cfg = settings.engine_config(engine_id, config_path=config_path)
|
||||||
@@ -90,6 +91,7 @@ def build_router(
|
|||||||
if engine_id == default_engine:
|
if engine_id == default_engine:
|
||||||
raise
|
raise
|
||||||
issue = str(exc)
|
issue = str(exc)
|
||||||
|
status = "bad_config"
|
||||||
engine_cfg = {}
|
engine_cfg = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -104,26 +106,31 @@ def build_router(
|
|||||||
except Exception as fallback_exc:
|
except Exception as fallback_exc:
|
||||||
warnings.append(f"{engine_id}: {issue or str(fallback_exc)}")
|
warnings.append(f"{engine_id}: {issue or str(fallback_exc)}")
|
||||||
continue
|
continue
|
||||||
|
status = "bad_config"
|
||||||
else:
|
else:
|
||||||
|
status = "load_error"
|
||||||
warnings.append(f"{engine_id}: {issue}")
|
warnings.append(f"{engine_id}: {issue}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cmd = backend.cli_cmd or backend.id
|
cmd = backend.cli_cmd or backend.id
|
||||||
if shutil.which(cmd) is None:
|
if shutil.which(cmd) is None:
|
||||||
issue = issue or f"{cmd} not found on PATH"
|
status = "missing_cli"
|
||||||
|
if issue:
|
||||||
|
issue = f"{issue}; {cmd} not found on PATH"
|
||||||
|
else:
|
||||||
|
issue = f"{cmd} not found on PATH"
|
||||||
|
|
||||||
if issue and engine_id == default_engine:
|
if status != "ok" and engine_id == default_engine:
|
||||||
raise ConfigError(f"Default engine {engine_id!r} unavailable: {issue}")
|
raise ConfigError(f"Default engine {engine_id!r} unavailable: {issue}")
|
||||||
|
|
||||||
available = issue is None
|
if status != "ok" and engine_id != default_engine:
|
||||||
if issue and engine_id != default_engine:
|
|
||||||
warnings.append(f"{engine_id}: {issue}")
|
warnings.append(f"{engine_id}: {issue}")
|
||||||
|
|
||||||
entries.append(
|
entries.append(
|
||||||
RunnerEntry(
|
RunnerEntry(
|
||||||
engine=engine_id,
|
engine=engine_id,
|
||||||
runner=runner,
|
runner=runner,
|
||||||
available=available,
|
status=status,
|
||||||
issue=issue,
|
issue=issue,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
+11
-6
@@ -83,6 +83,14 @@ class TelegramFilesSettings(BaseModel):
|
|||||||
raise ValueError("files.uploads_dir must be a relative path")
|
raise ValueError("files.uploads_dir must be a relative path")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_upload_bytes(self) -> int:
|
||||||
|
return 20 * 1024 * 1024
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_download_bytes(self) -> int:
|
||||||
|
return 50 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
class TelegramTransportSettings(BaseModel):
|
class TelegramTransportSettings(BaseModel):
|
||||||
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
|
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
|
||||||
@@ -90,14 +98,13 @@ class TelegramTransportSettings(BaseModel):
|
|||||||
bot_token: NonEmptyStr
|
bot_token: NonEmptyStr
|
||||||
chat_id: StrictInt
|
chat_id: StrictInt
|
||||||
voice_transcription: bool = False
|
voice_transcription: bool = False
|
||||||
|
voice_max_bytes: StrictInt = 10 * 1024 * 1024
|
||||||
topics: TelegramTopicsSettings = Field(default_factory=TelegramTopicsSettings)
|
topics: TelegramTopicsSettings = Field(default_factory=TelegramTopicsSettings)
|
||||||
files: TelegramFilesSettings = Field(default_factory=TelegramFilesSettings)
|
files: TelegramFilesSettings = Field(default_factory=TelegramFilesSettings)
|
||||||
|
|
||||||
|
|
||||||
class TransportsSettings(BaseModel):
|
class TransportsSettings(BaseModel):
|
||||||
telegram: TelegramTransportSettings = Field(
|
telegram: TelegramTransportSettings
|
||||||
default_factory=TelegramTransportSettings
|
|
||||||
)
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
@@ -132,7 +139,7 @@ class TakopiSettings(BaseSettings):
|
|||||||
projects: dict[str, ProjectSettings] = Field(default_factory=dict)
|
projects: dict[str, ProjectSettings] = Field(default_factory=dict)
|
||||||
|
|
||||||
transport: NonEmptyStr = "telegram"
|
transport: NonEmptyStr = "telegram"
|
||||||
transports: TransportsSettings = Field(default_factory=TransportsSettings)
|
transports: TransportsSettings
|
||||||
|
|
||||||
plugins: PluginsSettings = Field(default_factory=PluginsSettings)
|
plugins: PluginsSettings = Field(default_factory=PluginsSettings)
|
||||||
|
|
||||||
@@ -310,8 +317,6 @@ def require_telegram(settings: TakopiSettings, config_path: Path) -> tuple[str,
|
|||||||
"(telegram only for now)."
|
"(telegram only for now)."
|
||||||
)
|
)
|
||||||
tg = settings.transports.telegram
|
tg = settings.transports.telegram
|
||||||
if not tg.bot_token:
|
|
||||||
raise ConfigError(f"Missing bot token in {config_path}.")
|
|
||||||
return tg.bot_token, tg.chat_id
|
return tg.bot_token, tg.chat_id
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Chat",
|
||||||
|
"ChatMember",
|
||||||
|
"File",
|
||||||
|
"ForumTopic",
|
||||||
|
"Message",
|
||||||
|
"Update",
|
||||||
|
"User",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class User(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
id: int
|
||||||
|
username: str | None = None
|
||||||
|
first_name: str | None = None
|
||||||
|
last_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Chat(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
id: int
|
||||||
|
type: str
|
||||||
|
is_forum: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMember(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
status: str
|
||||||
|
can_manage_topics: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Message(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
message_id: int
|
||||||
|
message_thread_id: int | None = None
|
||||||
|
text: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class File(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
file_path: str
|
||||||
|
|
||||||
|
|
||||||
|
class ForumTopic(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
message_thread_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class Update(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
update_id: int
|
||||||
|
message: dict[str, Any] | None = None
|
||||||
|
callback_query: dict[str, Any] | None = None
|
||||||
@@ -2,22 +2,19 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
from ..backends import EngineBackend
|
from ..backends import EngineBackend
|
||||||
from ..runner_bridge import ExecBridgeConfig
|
|
||||||
from ..logging import get_logger
|
from ..logging import get_logger
|
||||||
|
from ..runner_bridge import ExecBridgeConfig
|
||||||
from ..transports import SetupResult, TransportBackend
|
from ..settings import TelegramTransportSettings
|
||||||
from ..transport_runtime import TransportRuntime
|
from ..transport_runtime import TransportRuntime
|
||||||
|
from ..transports import SetupResult, TransportBackend
|
||||||
from .bridge import (
|
from .bridge import (
|
||||||
TelegramBridgeConfig,
|
TelegramBridgeConfig,
|
||||||
TelegramPresenter,
|
TelegramPresenter,
|
||||||
TelegramTransport,
|
TelegramTransport,
|
||||||
TelegramFilesConfig,
|
|
||||||
TelegramTopicsConfig,
|
|
||||||
run_main_loop,
|
run_main_loop,
|
||||||
)
|
)
|
||||||
from .client import TelegramClient
|
from .client import TelegramClient
|
||||||
@@ -26,6 +23,12 @@ from .onboarding import check_setup, interactive_setup
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _expect_transport_settings(transport_config: object) -> TelegramTransportSettings:
|
||||||
|
if isinstance(transport_config, TelegramTransportSettings):
|
||||||
|
return transport_config
|
||||||
|
raise TypeError("transport_config must be TelegramTransportSettings")
|
||||||
|
|
||||||
|
|
||||||
def _build_startup_message(
|
def _build_startup_message(
|
||||||
runtime: TransportRuntime,
|
runtime: TransportRuntime,
|
||||||
*,
|
*,
|
||||||
@@ -33,9 +36,20 @@ def _build_startup_message(
|
|||||||
) -> str:
|
) -> str:
|
||||||
available_engines = list(runtime.available_engine_ids())
|
available_engines = list(runtime.available_engine_ids())
|
||||||
missing_engines = list(runtime.missing_engine_ids())
|
missing_engines = list(runtime.missing_engine_ids())
|
||||||
|
misconfigured_engines = list(runtime.engine_ids_with_status("bad_config"))
|
||||||
|
failed_engines = list(runtime.engine_ids_with_status("load_error"))
|
||||||
|
|
||||||
engine_list = ", ".join(available_engines) if available_engines else "none"
|
engine_list = ", ".join(available_engines) if available_engines else "none"
|
||||||
|
|
||||||
|
notes: list[str] = []
|
||||||
if missing_engines:
|
if missing_engines:
|
||||||
engine_list = f"{engine_list} (not installed: {', '.join(missing_engines)})"
|
notes.append(f"not installed: {', '.join(missing_engines)}")
|
||||||
|
if misconfigured_engines:
|
||||||
|
notes.append(f"misconfigured: {', '.join(misconfigured_engines)}")
|
||||||
|
if failed_engines:
|
||||||
|
notes.append(f"failed to load: {', '.join(failed_engines)}")
|
||||||
|
if notes:
|
||||||
|
engine_list = f"{engine_list} ({'; '.join(notes)})"
|
||||||
project_aliases = sorted(
|
project_aliases = sorted(
|
||||||
{alias for alias in runtime.project_aliases()}, key=str.lower
|
{alias for alias in runtime.project_aliases()}, key=str.lower
|
||||||
)
|
)
|
||||||
@@ -49,34 +63,6 @@ def _build_startup_message(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_topics_config(transport_config: dict[str, object]) -> TelegramTopicsConfig:
|
|
||||||
raw = cast(dict[str, object], transport_config.get("topics", {}))
|
|
||||||
return TelegramTopicsConfig(
|
|
||||||
enabled=cast(bool, raw.get("enabled", False)),
|
|
||||||
scope=cast(str, raw.get("scope", "auto")),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_files_config(transport_config: dict[str, object]) -> TelegramFilesConfig:
|
|
||||||
defaults = TelegramFilesConfig()
|
|
||||||
raw = cast(dict[str, object], transport_config.get("files", {}))
|
|
||||||
return TelegramFilesConfig(
|
|
||||||
enabled=cast(bool, raw.get("enabled", defaults.enabled)),
|
|
||||||
auto_put=cast(bool, raw.get("auto_put", defaults.auto_put)),
|
|
||||||
uploads_dir=cast(str, raw.get("uploads_dir", defaults.uploads_dir)),
|
|
||||||
max_upload_bytes=defaults.max_upload_bytes,
|
|
||||||
max_download_bytes=defaults.max_download_bytes,
|
|
||||||
allowed_user_ids=frozenset(
|
|
||||||
cast(
|
|
||||||
list[int], raw.get("allowed_user_ids", list(defaults.allowed_user_ids))
|
|
||||||
)
|
|
||||||
),
|
|
||||||
deny_globs=tuple(
|
|
||||||
cast(list[str], raw.get("deny_globs", list(defaults.deny_globs)))
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TelegramBackend(TransportBackend):
|
class TelegramBackend(TransportBackend):
|
||||||
id = "telegram"
|
id = "telegram"
|
||||||
description = "Telegram bot"
|
description = "Telegram bot"
|
||||||
@@ -92,23 +78,23 @@ class TelegramBackend(TransportBackend):
|
|||||||
def interactive_setup(self, *, force: bool) -> bool:
|
def interactive_setup(self, *, force: bool) -> bool:
|
||||||
return interactive_setup(force=force)
|
return interactive_setup(force=force)
|
||||||
|
|
||||||
def lock_token(
|
def lock_token(self, *, transport_config: object, config_path: Path) -> str | None:
|
||||||
self, *, transport_config: dict[str, object], config_path: Path
|
|
||||||
) -> str | None:
|
|
||||||
_ = config_path
|
_ = config_path
|
||||||
return cast(str, transport_config.get("bot_token"))
|
settings = _expect_transport_settings(transport_config)
|
||||||
|
return settings.bot_token
|
||||||
|
|
||||||
def build_and_run(
|
def build_and_run(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
transport_config: dict[str, object],
|
transport_config: object,
|
||||||
config_path: Path,
|
config_path: Path,
|
||||||
runtime: TransportRuntime,
|
runtime: TransportRuntime,
|
||||||
final_notify: bool,
|
final_notify: bool,
|
||||||
default_engine_override: str | None,
|
default_engine_override: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
token = cast(str, transport_config.get("bot_token"))
|
settings = _expect_transport_settings(transport_config)
|
||||||
chat_id = cast(int, transport_config.get("chat_id"))
|
token = settings.bot_token
|
||||||
|
chat_id = settings.chat_id
|
||||||
startup_msg = _build_startup_message(
|
startup_msg = _build_startup_message(
|
||||||
runtime,
|
runtime,
|
||||||
startup_pwd=os.getcwd(),
|
startup_pwd=os.getcwd(),
|
||||||
@@ -121,19 +107,16 @@ class TelegramBackend(TransportBackend):
|
|||||||
presenter=presenter,
|
presenter=presenter,
|
||||||
final_notify=final_notify,
|
final_notify=final_notify,
|
||||||
)
|
)
|
||||||
topics = _build_topics_config(transport_config)
|
|
||||||
files = _build_files_config(transport_config)
|
|
||||||
cfg = TelegramBridgeConfig(
|
cfg = TelegramBridgeConfig(
|
||||||
bot=bot,
|
bot=bot,
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
startup_msg=startup_msg,
|
startup_msg=startup_msg,
|
||||||
exec_cfg=exec_cfg,
|
exec_cfg=exec_cfg,
|
||||||
voice_transcription=cast(
|
voice_transcription=settings.voice_transcription,
|
||||||
bool, transport_config.get("voice_transcription", False)
|
voice_max_bytes=int(settings.voice_max_bytes),
|
||||||
),
|
topics=settings.topics,
|
||||||
topics=topics,
|
files=settings.files,
|
||||||
files=files,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_loop() -> None:
|
async def run_loop() -> None:
|
||||||
@@ -142,7 +125,7 @@ class TelegramBackend(TransportBackend):
|
|||||||
watch_config=runtime.watch_config,
|
watch_config=runtime.watch_config,
|
||||||
default_engine_override=default_engine_override,
|
default_engine_override=default_engine_override,
|
||||||
transport_id=self.id,
|
transport_id=self.id,
|
||||||
transport_config=transport_config,
|
transport_config=settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
anyio.run(run_loop)
|
anyio.run(run_loop)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from ..logging import get_logger
|
from ..logging import get_logger
|
||||||
@@ -12,6 +12,11 @@ from ..transport import MessageRef, RenderedMessage, SendOptions, Transport
|
|||||||
from ..transport_runtime import TransportRuntime
|
from ..transport_runtime import TransportRuntime
|
||||||
from ..context import RunContext
|
from ..context import RunContext
|
||||||
from ..model import ResumeToken
|
from ..model import ResumeToken
|
||||||
|
from ..settings import (
|
||||||
|
TelegramFilesSettings,
|
||||||
|
TelegramTopicsSettings,
|
||||||
|
TelegramTransportSettings,
|
||||||
|
)
|
||||||
from .client import BotClient
|
from .client import BotClient
|
||||||
from .render import prepare_telegram
|
from .render import prepare_telegram
|
||||||
from .types import TelegramCallbackQuery, TelegramIncomingMessage
|
from .types import TelegramCallbackQuery, TelegramIncomingMessage
|
||||||
@@ -20,8 +25,6 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TelegramBridgeConfig",
|
"TelegramBridgeConfig",
|
||||||
"TelegramFilesConfig",
|
|
||||||
"TelegramTopicsConfig",
|
|
||||||
"TelegramPresenter",
|
"TelegramPresenter",
|
||||||
"TelegramTransport",
|
"TelegramTransport",
|
||||||
"build_bot_commands",
|
"build_bot_commands",
|
||||||
@@ -85,29 +88,6 @@ def _is_cancelled_label(label: str) -> bool:
|
|||||||
return stripped.lower() == "cancelled"
|
return stripped.lower() == "cancelled"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class TelegramFilesConfig:
|
|
||||||
enabled: bool = False
|
|
||||||
auto_put: bool = True
|
|
||||||
uploads_dir: str = "incoming"
|
|
||||||
max_upload_bytes: int = 20 * 1024 * 1024
|
|
||||||
max_download_bytes: int = 50 * 1024 * 1024
|
|
||||||
allowed_user_ids: frozenset[int] = frozenset()
|
|
||||||
deny_globs: tuple[str, ...] = (
|
|
||||||
".git/**",
|
|
||||||
".env",
|
|
||||||
".envrc",
|
|
||||||
"**/*.pem",
|
|
||||||
"**/.ssh/**",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class TelegramTopicsConfig:
|
|
||||||
enabled: bool = False
|
|
||||||
scope: str = "auto"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TelegramBridgeConfig:
|
class TelegramBridgeConfig:
|
||||||
bot: BotClient
|
bot: BotClient
|
||||||
@@ -116,9 +96,10 @@ class TelegramBridgeConfig:
|
|||||||
startup_msg: str
|
startup_msg: str
|
||||||
exec_cfg: ExecBridgeConfig
|
exec_cfg: ExecBridgeConfig
|
||||||
voice_transcription: bool = False
|
voice_transcription: bool = False
|
||||||
files: TelegramFilesConfig = TelegramFilesConfig()
|
voice_max_bytes: int = 10 * 1024 * 1024
|
||||||
|
files: TelegramFilesSettings = field(default_factory=TelegramFilesSettings)
|
||||||
chat_ids: tuple[int, ...] | None = None
|
chat_ids: tuple[int, ...] | None = None
|
||||||
topics: TelegramTopicsConfig = TelegramTopicsConfig()
|
topics: TelegramTopicsSettings = field(default_factory=TelegramTopicsSettings)
|
||||||
|
|
||||||
|
|
||||||
class TelegramTransport:
|
class TelegramTransport:
|
||||||
@@ -166,7 +147,7 @@ class TelegramTransport:
|
|||||||
)
|
)
|
||||||
if sent is None:
|
if sent is None:
|
||||||
return None
|
return None
|
||||||
message_id = cast(int, sent["message_id"])
|
message_id = sent.message_id
|
||||||
return MessageRef(
|
return MessageRef(
|
||||||
channel_id=chat_id,
|
channel_id=chat_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
@@ -192,7 +173,7 @@ class TelegramTransport:
|
|||||||
)
|
)
|
||||||
if edited is None:
|
if edited is None:
|
||||||
return ref if not wait else None
|
return ref if not wait else None
|
||||||
message_id = cast(int, edited.get("message_id", message_id))
|
message_id = edited.message_id
|
||||||
return MessageRef(
|
return MessageRef(
|
||||||
channel_id=chat_id,
|
channel_id=chat_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
@@ -287,7 +268,7 @@ async def run_main_loop(
|
|||||||
watch_config: bool | None = None,
|
watch_config: bool | None = None,
|
||||||
default_engine_override: str | None = None,
|
default_engine_override: str | None = None,
|
||||||
transport_id: str | None = None,
|
transport_id: str | None = None,
|
||||||
transport_config: dict[str, object] | None = None,
|
transport_config: TelegramTransportSettings | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
from .loop import run_main_loop as _run_main_loop
|
from .loop import run_main_loop as _run_main_loop
|
||||||
|
|
||||||
|
|||||||
+247
-205
@@ -12,13 +12,16 @@ from typing import (
|
|||||||
Iterable,
|
Iterable,
|
||||||
Protocol,
|
Protocol,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
TypeVar,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import msgspec
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
from ..logging import get_logger
|
from ..logging import get_logger
|
||||||
|
from .api_models import Chat, ChatMember, File, ForumTopic, Message, Update, User
|
||||||
from .types import (
|
from .types import (
|
||||||
TelegramCallbackQuery,
|
TelegramCallbackQuery,
|
||||||
TelegramDocument,
|
TelegramDocument,
|
||||||
@@ -29,6 +32,8 @@ from .types import (
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
SEND_PRIORITY = 0
|
SEND_PRIORITY = 0
|
||||||
DELETE_PRIORITY = 1
|
DELETE_PRIORITY = 1
|
||||||
@@ -51,15 +56,20 @@ def is_group_chat_id(chat_id: int) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def parse_incoming_update(
|
def parse_incoming_update(
|
||||||
update: dict[str, Any],
|
update: Update | dict[str, Any],
|
||||||
*,
|
*,
|
||||||
chat_id: int | None = None,
|
chat_id: int | None = None,
|
||||||
chat_ids: set[int] | None = None,
|
chat_ids: set[int] | None = None,
|
||||||
) -> TelegramIncomingUpdate | None:
|
) -> TelegramIncomingUpdate | None:
|
||||||
|
if isinstance(update, Update):
|
||||||
|
msg = update.message
|
||||||
|
callback_query = update.callback_query
|
||||||
|
else:
|
||||||
msg = update.get("message")
|
msg = update.get("message")
|
||||||
|
callback_query = update.get("callback_query")
|
||||||
|
|
||||||
if isinstance(msg, dict):
|
if isinstance(msg, dict):
|
||||||
return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids)
|
return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids)
|
||||||
callback_query = update.get("callback_query")
|
|
||||||
if isinstance(callback_query, dict):
|
if isinstance(callback_query, dict):
|
||||||
return _parse_callback_query(
|
return _parse_callback_query(
|
||||||
callback_query,
|
callback_query,
|
||||||
@@ -303,7 +313,7 @@ async def poll_incoming(
|
|||||||
if allowed is None and chat_id is not None:
|
if allowed is None and chat_id is not None:
|
||||||
allowed = {chat_id}
|
allowed = {chat_id}
|
||||||
for upd in updates:
|
for upd in updates:
|
||||||
offset = upd["update_id"] + 1
|
offset = upd.update_id + 1
|
||||||
msg = parse_incoming_update(upd, chat_ids=allowed)
|
msg = parse_incoming_update(upd, chat_ids=allowed)
|
||||||
if msg is not None:
|
if msg is not None:
|
||||||
yield msg
|
yield msg
|
||||||
@@ -317,9 +327,9 @@ class BotClient(Protocol):
|
|||||||
offset: int | None,
|
offset: int | None,
|
||||||
timeout_s: int = 50,
|
timeout_s: int = 50,
|
||||||
allowed_updates: list[str] | None = None,
|
allowed_updates: list[str] | None = None,
|
||||||
) -> list[dict] | None: ...
|
) -> list[Update] | None: ...
|
||||||
|
|
||||||
async def get_file(self, file_id: str) -> dict | None: ...
|
async def get_file(self, file_id: str) -> File | None: ...
|
||||||
|
|
||||||
async def download_file(self, file_path: str) -> bytes | None: ...
|
async def download_file(self, file_path: str) -> bytes | None: ...
|
||||||
|
|
||||||
@@ -335,7 +345,7 @@ class BotClient(Protocol):
|
|||||||
reply_markup: dict[str, Any] | None = None,
|
reply_markup: dict[str, Any] | None = None,
|
||||||
*,
|
*,
|
||||||
replace_message_id: int | None = None,
|
replace_message_id: int | None = None,
|
||||||
) -> dict | None: ...
|
) -> Message | None: ...
|
||||||
|
|
||||||
async def send_document(
|
async def send_document(
|
||||||
self,
|
self,
|
||||||
@@ -346,7 +356,7 @@ class BotClient(Protocol):
|
|||||||
message_thread_id: int | None = None,
|
message_thread_id: int | None = None,
|
||||||
disable_notification: bool | None = False,
|
disable_notification: bool | None = False,
|
||||||
caption: str | None = None,
|
caption: str | None = None,
|
||||||
) -> dict | None: ...
|
) -> Message | None: ...
|
||||||
|
|
||||||
async def edit_message_text(
|
async def edit_message_text(
|
||||||
self,
|
self,
|
||||||
@@ -358,7 +368,7 @@ class BotClient(Protocol):
|
|||||||
reply_markup: dict[str, Any] | None = None,
|
reply_markup: dict[str, Any] | None = None,
|
||||||
*,
|
*,
|
||||||
wait: bool = True,
|
wait: bool = True,
|
||||||
) -> dict | None: ...
|
) -> Message | None: ...
|
||||||
|
|
||||||
async def delete_message(
|
async def delete_message(
|
||||||
self,
|
self,
|
||||||
@@ -374,7 +384,7 @@ class BotClient(Protocol):
|
|||||||
language_code: str | None = None,
|
language_code: str | None = None,
|
||||||
) -> bool: ...
|
) -> bool: ...
|
||||||
|
|
||||||
async def get_me(self) -> dict | None: ...
|
async def get_me(self) -> User | None: ...
|
||||||
|
|
||||||
async def answer_callback_query(
|
async def answer_callback_query(
|
||||||
self,
|
self,
|
||||||
@@ -383,15 +393,17 @@ class BotClient(Protocol):
|
|||||||
show_alert: bool | None = None,
|
show_alert: bool | None = None,
|
||||||
) -> bool: ...
|
) -> bool: ...
|
||||||
|
|
||||||
async def get_chat(self, chat_id: int) -> dict | None: ...
|
async def get_chat(self, chat_id: int) -> Chat | None: ...
|
||||||
|
|
||||||
async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None: ...
|
async def get_chat_member(
|
||||||
|
self, chat_id: int, user_id: int
|
||||||
|
) -> ChatMember | None: ...
|
||||||
|
|
||||||
async def create_forum_topic(
|
async def create_forum_topic(
|
||||||
self,
|
self,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
name: str,
|
name: str,
|
||||||
) -> dict | None: ...
|
) -> ForumTopic | None: ...
|
||||||
|
|
||||||
async def edit_forum_topic(
|
async def edit_forum_topic(
|
||||||
self,
|
self,
|
||||||
@@ -672,71 +684,13 @@ class TelegramClient:
|
|||||||
if self._owns_http_client and self._http_client is not None:
|
if self._owns_http_client and self._http_client is not None:
|
||||||
await self._http_client.aclose()
|
await self._http_client.aclose()
|
||||||
|
|
||||||
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
|
def _parse_telegram_envelope(
|
||||||
if self._http_client is None or self._base is None:
|
self,
|
||||||
raise RuntimeError("TelegramClient is configured without an HTTP client.")
|
*,
|
||||||
logger.debug("telegram.request", method=method, payload=json_data)
|
method: str,
|
||||||
try:
|
resp: httpx.Response,
|
||||||
resp = await self._http_client.post(
|
payload: Any,
|
||||||
f"{self._base}/{method}", json=json_data
|
) -> Any | None:
|
||||||
)
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
url = getattr(e.request, "url", None)
|
|
||||||
logger.error(
|
|
||||||
"telegram.network_error",
|
|
||||||
method=method,
|
|
||||||
url=str(url) if url is not None else None,
|
|
||||||
error=str(e),
|
|
||||||
error_type=e.__class__.__name__,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
resp.raise_for_status()
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
if resp.status_code == 429:
|
|
||||||
retry_after: float | None = None
|
|
||||||
try:
|
|
||||||
payload = resp.json()
|
|
||||||
except Exception:
|
|
||||||
payload = None
|
|
||||||
if isinstance(payload, dict):
|
|
||||||
retry_after = retry_after_from_payload(payload)
|
|
||||||
retry_after = 5.0 if retry_after is None else retry_after
|
|
||||||
logger.warning(
|
|
||||||
"telegram.rate_limited",
|
|
||||||
method=method,
|
|
||||||
status=resp.status_code,
|
|
||||||
url=str(resp.request.url),
|
|
||||||
retry_after=retry_after,
|
|
||||||
)
|
|
||||||
raise TelegramRetryAfter(retry_after) from e
|
|
||||||
body = resp.text
|
|
||||||
logger.error(
|
|
||||||
"telegram.http_error",
|
|
||||||
method=method,
|
|
||||||
status=resp.status_code,
|
|
||||||
url=str(resp.request.url),
|
|
||||||
error=str(e),
|
|
||||||
body=body,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload = resp.json()
|
|
||||||
except Exception as e:
|
|
||||||
body = resp.text
|
|
||||||
logger.error(
|
|
||||||
"telegram.bad_response",
|
|
||||||
method=method,
|
|
||||||
status=resp.status_code,
|
|
||||||
url=str(resp.request.url),
|
|
||||||
error=str(e),
|
|
||||||
error_type=e.__class__.__name__,
|
|
||||||
body=body,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not isinstance(payload, dict):
|
if not isinstance(payload, dict):
|
||||||
logger.error(
|
logger.error(
|
||||||
"telegram.invalid_payload",
|
"telegram.invalid_payload",
|
||||||
@@ -768,145 +722,193 @@ class TelegramClient:
|
|||||||
logger.debug("telegram.response", method=method, payload=payload)
|
logger.debug("telegram.response", method=method, payload=payload)
|
||||||
return payload.get("result")
|
return payload.get("result")
|
||||||
|
|
||||||
|
async def _request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
*,
|
||||||
|
json: dict[str, Any] | None = None,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
files: dict[str, Any] | None = None,
|
||||||
|
) -> Any | None:
|
||||||
|
if self._http_client is None or self._base is None:
|
||||||
|
raise RuntimeError("TelegramClient is configured without an HTTP client.")
|
||||||
|
request_payload = json if json is not None else data
|
||||||
|
logger.debug("telegram.request", method=method, payload=request_payload)
|
||||||
|
try:
|
||||||
|
if json is not None:
|
||||||
|
resp = await self._http_client.post(f"{self._base}/{method}", json=json)
|
||||||
|
else:
|
||||||
|
resp = await self._http_client.post(
|
||||||
|
f"{self._base}/{method}", data=data, files=files
|
||||||
|
)
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
url = getattr(exc.request, "url", None)
|
||||||
|
logger.error(
|
||||||
|
"telegram.network_error",
|
||||||
|
method=method,
|
||||||
|
url=str(url) if url is not None else None,
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if resp.status_code == 429:
|
||||||
|
retry_after: float | None = None
|
||||||
|
try:
|
||||||
|
response_payload = resp.json()
|
||||||
|
except Exception:
|
||||||
|
response_payload = None
|
||||||
|
if isinstance(response_payload, dict):
|
||||||
|
retry_after = retry_after_from_payload(response_payload)
|
||||||
|
retry_after = 5.0 if retry_after is None else retry_after
|
||||||
|
logger.warning(
|
||||||
|
"telegram.rate_limited",
|
||||||
|
method=method,
|
||||||
|
status=resp.status_code,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
retry_after=retry_after,
|
||||||
|
)
|
||||||
|
raise TelegramRetryAfter(retry_after) from exc
|
||||||
|
body = resp.text
|
||||||
|
logger.error(
|
||||||
|
"telegram.http_error",
|
||||||
|
method=method,
|
||||||
|
status=resp.status_code,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
error=str(exc),
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
response_payload = resp.json()
|
||||||
|
except Exception as exc:
|
||||||
|
body = resp.text
|
||||||
|
logger.error(
|
||||||
|
"telegram.bad_response",
|
||||||
|
method=method,
|
||||||
|
status=resp.status_code,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._parse_telegram_envelope(
|
||||||
|
method=method,
|
||||||
|
resp=resp,
|
||||||
|
payload=response_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _decode_result(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
method: str,
|
||||||
|
payload: Any,
|
||||||
|
model: type[T],
|
||||||
|
) -> T | None:
|
||||||
|
if payload is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return msgspec.convert(payload, type=model)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"telegram.decode_error",
|
||||||
|
method=method,
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _call_with_retry_after(
|
||||||
|
self,
|
||||||
|
fn: Callable[[], Awaitable[T]],
|
||||||
|
) -> T:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return await fn()
|
||||||
|
except TelegramRetryAfter as exc:
|
||||||
|
await self._sleep(exc.retry_after)
|
||||||
|
|
||||||
|
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
|
||||||
|
return await self._request(method, json=json_data)
|
||||||
|
|
||||||
async def _post_form(
|
async def _post_form(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
data: dict[str, Any],
|
data: dict[str, Any],
|
||||||
files: dict[str, Any],
|
files: dict[str, Any],
|
||||||
) -> Any | None:
|
) -> Any | None:
|
||||||
if self._http_client is None or self._base is None:
|
return await self._request(method, data=data, files=files)
|
||||||
raise RuntimeError("TelegramClient is configured without an HTTP client.")
|
|
||||||
logger.debug("telegram.request", method=method, payload=data)
|
|
||||||
try:
|
|
||||||
resp = await self._http_client.post(
|
|
||||||
f"{self._base}/{method}", data=data, files=files
|
|
||||||
)
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
url = getattr(e.request, "url", None)
|
|
||||||
logger.error(
|
|
||||||
"telegram.network_error",
|
|
||||||
method=method,
|
|
||||||
url=str(url) if url is not None else None,
|
|
||||||
error=str(e),
|
|
||||||
error_type=e.__class__.__name__,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
resp.raise_for_status()
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
if resp.status_code == 429:
|
|
||||||
retry_after: float | None = None
|
|
||||||
try:
|
|
||||||
payload = resp.json()
|
|
||||||
except Exception:
|
|
||||||
payload = None
|
|
||||||
if isinstance(payload, dict):
|
|
||||||
retry_after = retry_after_from_payload(payload)
|
|
||||||
retry_after = 5.0 if retry_after is None else retry_after
|
|
||||||
logger.warning(
|
|
||||||
"telegram.rate_limited",
|
|
||||||
method=method,
|
|
||||||
status=resp.status_code,
|
|
||||||
url=str(resp.request.url),
|
|
||||||
retry_after=retry_after,
|
|
||||||
)
|
|
||||||
raise TelegramRetryAfter(retry_after) from e
|
|
||||||
body = resp.text
|
|
||||||
logger.error(
|
|
||||||
"telegram.http_error",
|
|
||||||
method=method,
|
|
||||||
status=resp.status_code,
|
|
||||||
url=str(resp.request.url),
|
|
||||||
error=str(e),
|
|
||||||
body=body,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload = resp.json()
|
|
||||||
except Exception as e:
|
|
||||||
body = resp.text
|
|
||||||
logger.error(
|
|
||||||
"telegram.bad_response",
|
|
||||||
method=method,
|
|
||||||
status=resp.status_code,
|
|
||||||
error=str(e),
|
|
||||||
error_type=e.__class__.__name__,
|
|
||||||
body=body,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not isinstance(payload, dict):
|
|
||||||
logger.error(
|
|
||||||
"telegram.invalid_payload",
|
|
||||||
method=method,
|
|
||||||
url=str(resp.request.url),
|
|
||||||
payload=payload,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not payload.get("ok"):
|
|
||||||
if payload.get("error_code") == 429:
|
|
||||||
retry_after = retry_after_from_payload(payload)
|
|
||||||
retry_after = 5.0 if retry_after is None else retry_after
|
|
||||||
logger.warning(
|
|
||||||
"telegram.rate_limited",
|
|
||||||
method=method,
|
|
||||||
url=str(resp.request.url),
|
|
||||||
retry_after=retry_after,
|
|
||||||
)
|
|
||||||
raise TelegramRetryAfter(retry_after)
|
|
||||||
logger.error(
|
|
||||||
"telegram.api_error",
|
|
||||||
method=method,
|
|
||||||
url=str(resp.request.url),
|
|
||||||
payload=payload,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
logger.debug("telegram.response", method=method, payload=payload)
|
|
||||||
return payload.get("result")
|
|
||||||
|
|
||||||
async def get_updates(
|
async def get_updates(
|
||||||
self,
|
self,
|
||||||
offset: int | None,
|
offset: int | None,
|
||||||
timeout_s: int = 50,
|
timeout_s: int = 50,
|
||||||
allowed_updates: list[str] | None = None,
|
allowed_updates: list[str] | None = None,
|
||||||
) -> list[dict] | None:
|
) -> list[Update] | None:
|
||||||
while True:
|
async def execute() -> list[Update] | None:
|
||||||
try:
|
|
||||||
if self._client_override is not None:
|
if self._client_override is not None:
|
||||||
return await self._client_override.get_updates(
|
raw = await self._client_override.get_updates(
|
||||||
offset=offset,
|
offset=offset,
|
||||||
timeout_s=timeout_s,
|
timeout_s=timeout_s,
|
||||||
allowed_updates=allowed_updates,
|
allowed_updates=allowed_updates,
|
||||||
)
|
)
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return msgspec.convert(raw, type=list[Update])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"telegram.decode_error",
|
||||||
|
method="getUpdates",
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
params: dict[str, Any] = {"timeout": timeout_s}
|
params: dict[str, Any] = {"timeout": timeout_s}
|
||||||
if offset is not None:
|
if offset is not None:
|
||||||
params["offset"] = offset
|
params["offset"] = offset
|
||||||
if allowed_updates is not None:
|
if allowed_updates is not None:
|
||||||
params["allowed_updates"] = allowed_updates
|
params["allowed_updates"] = allowed_updates
|
||||||
result = await self._post("getUpdates", params)
|
result = await self._post("getUpdates", params)
|
||||||
return result if isinstance(result, list) else None
|
if result is None or not isinstance(result, list):
|
||||||
except TelegramRetryAfter as exc:
|
return None
|
||||||
await self._sleep(exc.retry_after)
|
|
||||||
|
|
||||||
async def get_file(self, file_id: str) -> dict | None:
|
|
||||||
while True:
|
|
||||||
try:
|
try:
|
||||||
|
return msgspec.convert(result, type=list[Update])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"telegram.decode_error",
|
||||||
|
method="getUpdates",
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await self._call_with_retry_after(execute)
|
||||||
|
|
||||||
|
async def get_file(self, file_id: str) -> File | None:
|
||||||
|
async def execute() -> File | None:
|
||||||
if self._client_override is not None:
|
if self._client_override is not None:
|
||||||
return await self._client_override.get_file(file_id)
|
return await self._client_override.get_file(file_id)
|
||||||
result = await self._post("getFile", {"file_id": file_id})
|
result = await self._post("getFile", {"file_id": file_id})
|
||||||
return result if isinstance(result, dict) else None
|
return self._decode_result(method="getFile", payload=result, model=File)
|
||||||
except TelegramRetryAfter as exc:
|
|
||||||
await self._sleep(exc.retry_after)
|
return await self._call_with_retry_after(execute)
|
||||||
|
|
||||||
async def download_file(self, file_path: str) -> bytes | None:
|
async def download_file(self, file_path: str) -> bytes | None:
|
||||||
|
async def execute() -> bytes | None:
|
||||||
if self._client_override is not None:
|
if self._client_override is not None:
|
||||||
return await self._client_override.download_file(file_path)
|
return await self._client_override.download_file(file_path)
|
||||||
if self._http_client is None or self._file_base is None:
|
if self._http_client is None or self._file_base is None:
|
||||||
raise RuntimeError("TelegramClient is configured without an HTTP client.")
|
raise RuntimeError(
|
||||||
|
"TelegramClient is configured without an HTTP client."
|
||||||
|
)
|
||||||
url = f"{self._file_base}/{file_path}"
|
url = f"{self._file_base}/{file_path}"
|
||||||
try:
|
try:
|
||||||
resp = await self._http_client.get(url)
|
resp = await self._http_client.get(url)
|
||||||
@@ -922,6 +924,24 @@ class TelegramClient:
|
|||||||
try:
|
try:
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
except httpx.HTTPStatusError as exc:
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if resp.status_code == 429:
|
||||||
|
retry_after: float | None = None
|
||||||
|
try:
|
||||||
|
response_payload = resp.json()
|
||||||
|
except Exception:
|
||||||
|
response_payload = None
|
||||||
|
if isinstance(response_payload, dict):
|
||||||
|
retry_after = retry_after_from_payload(response_payload)
|
||||||
|
retry_after = 5.0 if retry_after is None else retry_after
|
||||||
|
logger.warning(
|
||||||
|
"telegram.rate_limited",
|
||||||
|
method="download_file",
|
||||||
|
status=resp.status_code,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
retry_after=retry_after,
|
||||||
|
)
|
||||||
|
raise TelegramRetryAfter(retry_after) from exc
|
||||||
|
|
||||||
logger.error(
|
logger.error(
|
||||||
"telegram.file_http_error",
|
"telegram.file_http_error",
|
||||||
status=resp.status_code,
|
status=resp.status_code,
|
||||||
@@ -932,6 +952,8 @@ class TelegramClient:
|
|||||||
return None
|
return None
|
||||||
return resp.content
|
return resp.content
|
||||||
|
|
||||||
|
return await self._call_with_retry_after(execute)
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(
|
||||||
self,
|
self,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
@@ -944,8 +966,8 @@ class TelegramClient:
|
|||||||
reply_markup: dict[str, Any] | None = None,
|
reply_markup: dict[str, Any] | None = None,
|
||||||
*,
|
*,
|
||||||
replace_message_id: int | None = None,
|
replace_message_id: int | None = None,
|
||||||
) -> dict | None:
|
) -> Message | None:
|
||||||
async def execute() -> dict | None:
|
async def execute() -> Message | None:
|
||||||
if self._client_override is not None:
|
if self._client_override is not None:
|
||||||
return await self._client_override.send_message(
|
return await self._client_override.send_message(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
@@ -972,7 +994,11 @@ class TelegramClient:
|
|||||||
if reply_markup is not None:
|
if reply_markup is not None:
|
||||||
params["reply_markup"] = reply_markup
|
params["reply_markup"] = reply_markup
|
||||||
result = await self._post("sendMessage", params)
|
result = await self._post("sendMessage", params)
|
||||||
return result if isinstance(result, dict) else None
|
return self._decode_result(
|
||||||
|
method="sendMessage",
|
||||||
|
payload=result,
|
||||||
|
model=Message,
|
||||||
|
)
|
||||||
|
|
||||||
if replace_message_id is not None:
|
if replace_message_id is not None:
|
||||||
await self._outbox.drop_pending(key=("edit", chat_id, replace_message_id))
|
await self._outbox.drop_pending(key=("edit", chat_id, replace_message_id))
|
||||||
@@ -1000,8 +1026,8 @@ class TelegramClient:
|
|||||||
message_thread_id: int | None = None,
|
message_thread_id: int | None = None,
|
||||||
disable_notification: bool | None = False,
|
disable_notification: bool | None = False,
|
||||||
caption: str | None = None,
|
caption: str | None = None,
|
||||||
) -> dict | None:
|
) -> Message | None:
|
||||||
async def execute() -> dict | None:
|
async def execute() -> Message | None:
|
||||||
if self._client_override is not None:
|
if self._client_override is not None:
|
||||||
return await self._client_override.send_document(
|
return await self._client_override.send_document(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
@@ -1026,7 +1052,11 @@ class TelegramClient:
|
|||||||
params,
|
params,
|
||||||
files={"document": (filename, content)},
|
files={"document": (filename, content)},
|
||||||
)
|
)
|
||||||
return result if isinstance(result, dict) else None
|
return self._decode_result(
|
||||||
|
method="sendDocument",
|
||||||
|
payload=result,
|
||||||
|
model=Message,
|
||||||
|
)
|
||||||
|
|
||||||
return await self.enqueue_op(
|
return await self.enqueue_op(
|
||||||
key=self.unique_key("send_document"),
|
key=self.unique_key("send_document"),
|
||||||
@@ -1046,8 +1076,8 @@ class TelegramClient:
|
|||||||
reply_markup: dict[str, Any] | None = None,
|
reply_markup: dict[str, Any] | None = None,
|
||||||
*,
|
*,
|
||||||
wait: bool = True,
|
wait: bool = True,
|
||||||
) -> dict | None:
|
) -> Message | None:
|
||||||
async def execute() -> dict | None:
|
async def execute() -> Message | None:
|
||||||
if self._client_override is not None:
|
if self._client_override is not None:
|
||||||
return await self._client_override.edit_message_text(
|
return await self._client_override.edit_message_text(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
@@ -1070,7 +1100,11 @@ class TelegramClient:
|
|||||||
if reply_markup is not None:
|
if reply_markup is not None:
|
||||||
params["reply_markup"] = reply_markup
|
params["reply_markup"] = reply_markup
|
||||||
result = await self._post("editMessageText", params)
|
result = await self._post("editMessageText", params)
|
||||||
return result if isinstance(result, dict) else None
|
return self._decode_result(
|
||||||
|
method="editMessageText",
|
||||||
|
payload=result,
|
||||||
|
model=Message,
|
||||||
|
)
|
||||||
|
|
||||||
return await self.enqueue_op(
|
return await self.enqueue_op(
|
||||||
key=("edit", chat_id, message_id),
|
key=("edit", chat_id, message_id),
|
||||||
@@ -1142,12 +1176,12 @@ class TelegramClient:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_me(self) -> dict | None:
|
async def get_me(self) -> User | None:
|
||||||
async def execute() -> dict | None:
|
async def execute() -> User | None:
|
||||||
if self._client_override is not None:
|
if self._client_override is not None:
|
||||||
return await self._client_override.get_me()
|
return await self._client_override.get_me()
|
||||||
result = await self._post("getMe", {})
|
result = await self._post("getMe", {})
|
||||||
return result if isinstance(result, dict) else None
|
return self._decode_result(method="getMe", payload=result, model=User)
|
||||||
|
|
||||||
return await self.enqueue_op(
|
return await self.enqueue_op(
|
||||||
key=self.unique_key("get_me"),
|
key=self.unique_key("get_me"),
|
||||||
@@ -1188,12 +1222,12 @@ class TelegramClient:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_chat(self, chat_id: int) -> dict | None:
|
async def get_chat(self, chat_id: int) -> Chat | None:
|
||||||
async def execute() -> dict | None:
|
async def execute() -> Chat | None:
|
||||||
if self._client_override is not None:
|
if self._client_override is not None:
|
||||||
return await self._client_override.get_chat(chat_id)
|
return await self._client_override.get_chat(chat_id)
|
||||||
result = await self._post("getChat", {"chat_id": chat_id})
|
result = await self._post("getChat", {"chat_id": chat_id})
|
||||||
return result if isinstance(result, dict) else None
|
return self._decode_result(method="getChat", payload=result, model=Chat)
|
||||||
|
|
||||||
return await self.enqueue_op(
|
return await self.enqueue_op(
|
||||||
key=self.unique_key("get_chat"),
|
key=self.unique_key("get_chat"),
|
||||||
@@ -1203,14 +1237,18 @@ class TelegramClient:
|
|||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None:
|
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
|
||||||
async def execute() -> dict | None:
|
async def execute() -> ChatMember | None:
|
||||||
if self._client_override is not None:
|
if self._client_override is not None:
|
||||||
return await self._client_override.get_chat_member(chat_id, user_id)
|
return await self._client_override.get_chat_member(chat_id, user_id)
|
||||||
result = await self._post(
|
result = await self._post(
|
||||||
"getChatMember", {"chat_id": chat_id, "user_id": user_id}
|
"getChatMember", {"chat_id": chat_id, "user_id": user_id}
|
||||||
)
|
)
|
||||||
return result if isinstance(result, dict) else None
|
return self._decode_result(
|
||||||
|
method="getChatMember",
|
||||||
|
payload=result,
|
||||||
|
model=ChatMember,
|
||||||
|
)
|
||||||
|
|
||||||
return await self.enqueue_op(
|
return await self.enqueue_op(
|
||||||
key=self.unique_key("get_chat_member"),
|
key=self.unique_key("get_chat_member"),
|
||||||
@@ -1220,14 +1258,18 @@ class TelegramClient:
|
|||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_forum_topic(self, chat_id: int, name: str) -> dict | None:
|
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
|
||||||
async def execute() -> dict | None:
|
async def execute() -> ForumTopic | None:
|
||||||
if self._client_override is not None:
|
if self._client_override is not None:
|
||||||
return await self._client_override.create_forum_topic(chat_id, name)
|
return await self._client_override.create_forum_topic(chat_id, name)
|
||||||
result = await self._post(
|
result = await self._post(
|
||||||
"createForumTopic", {"chat_id": chat_id, "name": name}
|
"createForumTopic", {"chat_id": chat_id, "name": name}
|
||||||
)
|
)
|
||||||
return result if isinstance(result, dict) else None
|
return self._decode_result(
|
||||||
|
method="createForumTopic",
|
||||||
|
payload=result,
|
||||||
|
model=ForumTopic,
|
||||||
|
)
|
||||||
|
|
||||||
return await self.enqueue_op(
|
return await self.enqueue_op(
|
||||||
key=self.unique_key("create_forum_topic"),
|
key=self.unique_key("create_forum_topic"),
|
||||||
|
|||||||
@@ -289,11 +289,10 @@ async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool:
|
|||||||
if is_private:
|
if is_private:
|
||||||
return True
|
return True
|
||||||
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
|
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
|
||||||
if not isinstance(member, dict):
|
if member is None:
|
||||||
await reply(text="failed to verify file transfer permissions.")
|
await reply(text="failed to verify file transfer permissions.")
|
||||||
return False
|
return False
|
||||||
status = member.get("status")
|
if member.status in {"creator", "administrator"}:
|
||||||
if status in {"creator", "administrator"}:
|
|
||||||
return True
|
return True
|
||||||
await reply(text="file transfer is restricted to group admins.")
|
await reply(text="file transfer is restricted to group admins.")
|
||||||
return False
|
return False
|
||||||
@@ -377,21 +376,14 @@ async def _save_document_payload(
|
|||||||
error="file is too large to upload.",
|
error="file is too large to upload.",
|
||||||
)
|
)
|
||||||
file_info = await cfg.bot.get_file(document.file_id)
|
file_info = await cfg.bot.get_file(document.file_id)
|
||||||
if not isinstance(file_info, dict):
|
if file_info is None:
|
||||||
return _FilePutResult(
|
|
||||||
name=name,
|
|
||||||
rel_path=None,
|
|
||||||
size=None,
|
|
||||||
error="failed to fetch file metadata.",
|
|
||||||
)
|
|
||||||
file_path = file_info.get("file_path")
|
|
||||||
if not isinstance(file_path, str) or not file_path:
|
|
||||||
return _FilePutResult(
|
return _FilePutResult(
|
||||||
name=name,
|
name=name,
|
||||||
rel_path=None,
|
rel_path=None,
|
||||||
size=None,
|
size=None,
|
||||||
error="failed to fetch file metadata.",
|
error="failed to fetch file metadata.",
|
||||||
)
|
)
|
||||||
|
file_path = file_info.file_path
|
||||||
name = default_upload_name(document.file_name, file_path)
|
name = default_upload_name(document.file_name, file_path)
|
||||||
resolved_path = rel_path
|
resolved_path = rel_path
|
||||||
if resolved_path is None:
|
if resolved_path is None:
|
||||||
@@ -972,10 +964,10 @@ async def _handle_topic_command(
|
|||||||
return
|
return
|
||||||
title = _topic_title(runtime=cfg.runtime, context=context)
|
title = _topic_title(runtime=cfg.runtime, context=context)
|
||||||
created = await cfg.bot.create_forum_topic(msg.chat_id, title)
|
created = await cfg.bot.create_forum_topic(msg.chat_id, title)
|
||||||
thread_id = created.get("message_thread_id") if isinstance(created, dict) else None
|
if created is None:
|
||||||
if isinstance(thread_id, bool) or not isinstance(thread_id, int):
|
|
||||||
await reply(text="failed to create topic.")
|
await reply(text="failed to create topic.")
|
||||||
return
|
return
|
||||||
|
thread_id = created.message_thread_id
|
||||||
await store.set_context(
|
await store.set_context(
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
thread_id,
|
thread_id,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from ..directives import DirectiveError
|
|||||||
from ..logging import get_logger
|
from ..logging import get_logger
|
||||||
from ..model import EngineId, ResumeToken
|
from ..model import EngineId, ResumeToken
|
||||||
from ..scheduler import ThreadJob, ThreadScheduler
|
from ..scheduler import ThreadJob, ThreadScheduler
|
||||||
|
from ..settings import TelegramTransportSettings
|
||||||
from ..transport import MessageRef
|
from ..transport import MessageRef
|
||||||
from ..context import RunContext
|
from ..context import RunContext
|
||||||
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
|
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
|
||||||
@@ -169,7 +170,7 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int |
|
|||||||
if drained:
|
if drained:
|
||||||
logger.info("startup.backlog.drained", count=drained)
|
logger.info("startup.backlog.drained", count=drained)
|
||||||
return offset
|
return offset
|
||||||
offset = updates[-1]["update_id"] + 1
|
offset = updates[-1].update_id + 1
|
||||||
drained += len(updates)
|
drained += len(updates)
|
||||||
|
|
||||||
|
|
||||||
@@ -266,7 +267,7 @@ async def run_main_loop(
|
|||||||
watch_config: bool | None = None,
|
watch_config: bool | None = None,
|
||||||
default_engine_override: str | None = None,
|
default_engine_override: str | None = None,
|
||||||
transport_id: str | None = None,
|
transport_id: str | None = None,
|
||||||
transport_config: dict[str, object] | None = None,
|
transport_config: TelegramTransportSettings | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
from ..runner_bridge import RunningTasks
|
from ..runner_bridge import RunningTasks
|
||||||
|
|
||||||
@@ -277,7 +278,7 @@ async def run_main_loop(
|
|||||||
}
|
}
|
||||||
reserved_commands = _reserved_commands(cfg.runtime)
|
reserved_commands = _reserved_commands(cfg.runtime)
|
||||||
transport_snapshot = (
|
transport_snapshot = (
|
||||||
dict(transport_config) if transport_config is not None else None
|
transport_config.model_dump() if transport_config is not None else None
|
||||||
)
|
)
|
||||||
topic_store: TopicStateStore | None = None
|
topic_store: TopicStateStore | None = None
|
||||||
media_groups: dict[tuple[int, str], _MediaGroupState] = {}
|
media_groups: dict[tuple[int, str], _MediaGroupState] = {}
|
||||||
@@ -476,6 +477,7 @@ async def run_main_loop(
|
|||||||
bot=cfg.bot,
|
bot=cfg.bot,
|
||||||
msg=msg,
|
msg=msg,
|
||||||
enabled=cfg.voice_transcription,
|
enabled=cfg.voice_transcription,
|
||||||
|
max_bytes=cfg.voice_max_bytes,
|
||||||
reply=reply,
|
reply=reply,
|
||||||
)
|
)
|
||||||
if text is None:
|
if text is None:
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from ..engines import list_backends
|
|||||||
from ..logging import suppress_logs
|
from ..logging import suppress_logs
|
||||||
from ..settings import HOME_CONFIG_PATH, load_settings, require_telegram
|
from ..settings import HOME_CONFIG_PATH, load_settings, require_telegram
|
||||||
from ..transports import SetupResult
|
from ..transports import SetupResult
|
||||||
|
from .api_models import User
|
||||||
from .client import TelegramClient, TelegramRetryAfter
|
from .client import TelegramClient, TelegramRetryAfter
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -131,7 +132,7 @@ def mask_token(token: str) -> str:
|
|||||||
return f"{token[:9]}...{token[-5:]}"
|
return f"{token[:9]}...{token[-5:]}"
|
||||||
|
|
||||||
|
|
||||||
async def get_bot_info(token: str) -> dict[str, Any] | None:
|
async def get_bot_info(token: str) -> User | None:
|
||||||
bot = TelegramClient(token)
|
bot = TelegramClient(token)
|
||||||
try:
|
try:
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
@@ -153,23 +154,19 @@ async def wait_for_chat(token: str) -> ChatInfo:
|
|||||||
offset=None, timeout_s=0, allowed_updates=allowed_updates
|
offset=None, timeout_s=0, allowed_updates=allowed_updates
|
||||||
)
|
)
|
||||||
if drained:
|
if drained:
|
||||||
offset = drained[-1]["update_id"] + 1
|
offset = drained[-1].update_id + 1
|
||||||
while True:
|
while True:
|
||||||
try:
|
|
||||||
updates = await bot.get_updates(
|
updates = await bot.get_updates(
|
||||||
offset=offset, timeout_s=50, allowed_updates=allowed_updates
|
offset=offset, timeout_s=50, allowed_updates=allowed_updates
|
||||||
)
|
)
|
||||||
except TelegramRetryAfter as exc:
|
|
||||||
await anyio.sleep(exc.retry_after)
|
|
||||||
continue
|
|
||||||
if updates is None:
|
if updates is None:
|
||||||
await anyio.sleep(1)
|
await anyio.sleep(1)
|
||||||
continue
|
continue
|
||||||
if not updates:
|
if not updates:
|
||||||
continue
|
continue
|
||||||
offset = updates[-1]["update_id"] + 1
|
offset = updates[-1].update_id + 1
|
||||||
update = updates[-1]
|
update = updates[-1]
|
||||||
msg = update.get("message")
|
msg = update.message
|
||||||
if not isinstance(msg, dict):
|
if not isinstance(msg, dict):
|
||||||
continue
|
continue
|
||||||
sender = msg.get("from")
|
sender = msg.get("from")
|
||||||
@@ -298,7 +295,7 @@ def _confirm(message: str, *, default: bool = True) -> bool | None:
|
|||||||
return question.ask()
|
return question.ask()
|
||||||
|
|
||||||
|
|
||||||
def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None:
|
def _prompt_token(console: Console) -> tuple[str, User] | None:
|
||||||
while True:
|
while True:
|
||||||
token = questionary.password("paste your bot token:").ask()
|
token = questionary.password("paste your bot token:").ask()
|
||||||
if token is None:
|
if token is None:
|
||||||
@@ -310,11 +307,10 @@ def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None:
|
|||||||
console.print(" validating...")
|
console.print(" validating...")
|
||||||
info = anyio.run(get_bot_info, token)
|
info = anyio.run(get_bot_info, token)
|
||||||
if info:
|
if info:
|
||||||
username = info.get("username")
|
if info.username:
|
||||||
if isinstance(username, str) and username:
|
console.print(f" connected to @{info.username}")
|
||||||
console.print(f" connected to @{username}")
|
|
||||||
else:
|
else:
|
||||||
name = info.get("first_name") or "your bot"
|
name = info.first_name or "your bot"
|
||||||
console.print(f" connected to {name}")
|
console.print(f" connected to {name}")
|
||||||
return token, info
|
return token, info
|
||||||
console.print(" failed to connect, check the token and try again")
|
console.print(" failed to connect, check the token and try again")
|
||||||
@@ -342,7 +338,7 @@ def capture_chat_id(*, token: str | None = None) -> ChatInfo | None:
|
|||||||
return None
|
return None
|
||||||
token, info = token_info
|
token, info = token_info
|
||||||
|
|
||||||
bot_ref = f"@{info['username']}"
|
bot_ref = f"@{info.username}"
|
||||||
console.print("")
|
console.print("")
|
||||||
console.print(f" send /start to {bot_ref} (works in groups too)")
|
console.print(f" send /start to {bot_ref} (works in groups too)")
|
||||||
console.print(" waiting...")
|
console.print(" waiting...")
|
||||||
@@ -401,7 +397,7 @@ def interactive_setup(*, force: bool) -> bool:
|
|||||||
if token_info is None:
|
if token_info is None:
|
||||||
return False
|
return False
|
||||||
token, info = token_info
|
token, info = token_info
|
||||||
bot_ref = f"@{info['username']}"
|
bot_ref = f"@{info.username}"
|
||||||
|
|
||||||
console.print("")
|
console.print("")
|
||||||
console.print(f" send /start to {bot_ref} (works in groups too)")
|
console.print(f" send /start to {bot_ref} (works in groups too)")
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ import json
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
|
import msgspec
|
||||||
|
|
||||||
from ..context import RunContext
|
from ..context import RunContext
|
||||||
from ..logging import get_logger
|
from ..logging import get_logger
|
||||||
@@ -27,6 +27,26 @@ class TopicThreadSnapshot:
|
|||||||
topic_title: str | None
|
topic_title: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class _ContextState(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
project: str | None = None
|
||||||
|
branch: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _SessionState(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
resume: str
|
||||||
|
|
||||||
|
|
||||||
|
class _ThreadState(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
context: _ContextState | None = None
|
||||||
|
sessions: dict[str, _SessionState] = msgspec.field(default_factory=dict)
|
||||||
|
topic_title: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _TopicState(msgspec.Struct, forbid_unknown_fields=False):
|
||||||
|
version: int
|
||||||
|
threads: dict[str, _ThreadState] = msgspec.field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
def resolve_state_path(config_path: Path) -> Path:
|
def resolve_state_path(config_path: Path) -> Path:
|
||||||
return config_path.with_name(STATE_FILENAME)
|
return config_path.with_name(STATE_FILENAME)
|
||||||
|
|
||||||
@@ -35,34 +55,31 @@ def _thread_key(chat_id: int, thread_id: int) -> str:
|
|||||||
return f"{chat_id}:{thread_id}"
|
return f"{chat_id}:{thread_id}"
|
||||||
|
|
||||||
|
|
||||||
def _parse_context(raw: object) -> RunContext | None:
|
def _normalize_text(value: str | None) -> str | None:
|
||||||
if not isinstance(raw, dict):
|
if value is None:
|
||||||
return None
|
return None
|
||||||
payload = cast(dict[str, object], raw)
|
value = value.strip()
|
||||||
project = payload.get("project")
|
return value or None
|
||||||
branch = payload.get("branch")
|
|
||||||
if project is not None and not isinstance(project, str):
|
|
||||||
project = None
|
def _context_from_state(state: _ContextState | None) -> RunContext | None:
|
||||||
if isinstance(project, str):
|
if state is None:
|
||||||
project = project.strip() or None
|
return None
|
||||||
if branch is not None and not isinstance(branch, str):
|
project = _normalize_text(state.project)
|
||||||
branch = None
|
branch = _normalize_text(state.branch)
|
||||||
if isinstance(branch, str):
|
|
||||||
branch = branch.strip() or None
|
|
||||||
if project is None and branch is None:
|
if project is None and branch is None:
|
||||||
return None
|
return None
|
||||||
return RunContext(project=project, branch=branch)
|
return RunContext(project=project, branch=branch)
|
||||||
|
|
||||||
|
|
||||||
def _dump_context(context: RunContext | None) -> dict[str, str] | None:
|
def _context_to_state(context: RunContext | None) -> _ContextState | None:
|
||||||
if context is None or (context.project is None and context.branch is None):
|
if context is None:
|
||||||
return None
|
return None
|
||||||
payload: dict[str, str] = {}
|
project = _normalize_text(context.project)
|
||||||
if context.project is not None:
|
branch = _normalize_text(context.branch)
|
||||||
payload["project"] = context.project
|
if project is None and branch is None:
|
||||||
if context.branch is not None:
|
return None
|
||||||
payload["branch"] = context.branch
|
return _ContextState(project=project, branch=branch)
|
||||||
return payload or None
|
|
||||||
|
|
||||||
|
|
||||||
class TopicStateStore:
|
class TopicStateStore:
|
||||||
@@ -71,10 +88,7 @@ class TopicStateStore:
|
|||||||
self._lock = anyio.Lock()
|
self._lock = anyio.Lock()
|
||||||
self._loaded = False
|
self._loaded = False
|
||||||
self._mtime_ns: int | None = None
|
self._mtime_ns: int | None = None
|
||||||
self._data: dict[str, Any] = {
|
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||||
"version": STATE_VERSION,
|
|
||||||
"threads": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def get_thread(
|
async def get_thread(
|
||||||
self, chat_id: int, thread_id: int
|
self, chat_id: int, thread_id: int
|
||||||
@@ -92,7 +106,7 @@ class TopicStateStore:
|
|||||||
thread = self._get_thread_locked(chat_id, thread_id)
|
thread = self._get_thread_locked(chat_id, thread_id)
|
||||||
if thread is None:
|
if thread is None:
|
||||||
return None
|
return None
|
||||||
return _parse_context(thread.get("context"))
|
return _context_from_state(thread.context)
|
||||||
|
|
||||||
async def set_context(
|
async def set_context(
|
||||||
self,
|
self,
|
||||||
@@ -105,9 +119,9 @@ class TopicStateStore:
|
|||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._reload_locked_if_needed()
|
self._reload_locked_if_needed()
|
||||||
thread = self._ensure_thread_locked(chat_id, thread_id)
|
thread = self._ensure_thread_locked(chat_id, thread_id)
|
||||||
thread["context"] = _dump_context(context)
|
thread.context = _context_to_state(context)
|
||||||
if topic_title is not None:
|
if topic_title is not None:
|
||||||
thread["topic_title"] = topic_title
|
thread.topic_title = topic_title
|
||||||
self._save_locked()
|
self._save_locked()
|
||||||
|
|
||||||
async def clear_context(self, chat_id: int, thread_id: int) -> None:
|
async def clear_context(self, chat_id: int, thread_id: int) -> None:
|
||||||
@@ -116,7 +130,7 @@ class TopicStateStore:
|
|||||||
thread = self._get_thread_locked(chat_id, thread_id)
|
thread = self._get_thread_locked(chat_id, thread_id)
|
||||||
if thread is None:
|
if thread is None:
|
||||||
return
|
return
|
||||||
thread.pop("context", None)
|
thread.context = None
|
||||||
self._save_locked()
|
self._save_locked()
|
||||||
|
|
||||||
async def get_session_resume(
|
async def get_session_resume(
|
||||||
@@ -127,16 +141,10 @@ class TopicStateStore:
|
|||||||
thread = self._get_thread_locked(chat_id, thread_id)
|
thread = self._get_thread_locked(chat_id, thread_id)
|
||||||
if thread is None:
|
if thread is None:
|
||||||
return None
|
return None
|
||||||
sessions = thread.get("sessions")
|
entry = thread.sessions.get(engine)
|
||||||
if not isinstance(sessions, dict):
|
if entry is None or not entry.resume:
|
||||||
return None
|
return None
|
||||||
entry = sessions.get(engine)
|
return ResumeToken(engine=engine, value=entry.resume)
|
||||||
if not isinstance(entry, dict):
|
|
||||||
return None
|
|
||||||
value = entry.get("resume")
|
|
||||||
if not isinstance(value, str) or not value:
|
|
||||||
return None
|
|
||||||
return ResumeToken(engine=engine, value=value)
|
|
||||||
|
|
||||||
async def set_session_resume(
|
async def set_session_resume(
|
||||||
self, chat_id: int, thread_id: int, token: ResumeToken
|
self, chat_id: int, thread_id: int, token: ResumeToken
|
||||||
@@ -144,13 +152,7 @@ class TopicStateStore:
|
|||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._reload_locked_if_needed()
|
self._reload_locked_if_needed()
|
||||||
thread = self._ensure_thread_locked(chat_id, thread_id)
|
thread = self._ensure_thread_locked(chat_id, thread_id)
|
||||||
sessions = thread.get("sessions")
|
thread.sessions[token.engine] = _SessionState(resume=token.value)
|
||||||
if not isinstance(sessions, dict):
|
|
||||||
sessions = {}
|
|
||||||
thread["sessions"] = sessions
|
|
||||||
sessions[token.engine] = {
|
|
||||||
"resume": token.value,
|
|
||||||
}
|
|
||||||
self._save_locked()
|
self._save_locked()
|
||||||
|
|
||||||
async def clear_sessions(self, chat_id: int, thread_id: int) -> None:
|
async def clear_sessions(self, chat_id: int, thread_id: int) -> None:
|
||||||
@@ -159,7 +161,7 @@ class TopicStateStore:
|
|||||||
thread = self._get_thread_locked(chat_id, thread_id)
|
thread = self._get_thread_locked(chat_id, thread_id)
|
||||||
if thread is None:
|
if thread is None:
|
||||||
return
|
return
|
||||||
thread.pop("sessions", None)
|
thread.sessions = {}
|
||||||
self._save_locked()
|
self._save_locked()
|
||||||
|
|
||||||
async def find_thread_for_context(
|
async def find_thread_for_context(
|
||||||
@@ -167,47 +169,37 @@ class TopicStateStore:
|
|||||||
) -> int | None:
|
) -> int | None:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._reload_locked_if_needed()
|
self._reload_locked_if_needed()
|
||||||
threads = self._data.get("threads")
|
target_project = _normalize_text(context.project)
|
||||||
if not isinstance(threads, dict):
|
target_branch = _normalize_text(context.branch)
|
||||||
return None
|
for raw_key, thread in self._state.threads.items():
|
||||||
for raw_key, payload in threads.items():
|
if not raw_key.startswith(f"{chat_id}:"):
|
||||||
if not isinstance(raw_key, str) or not isinstance(payload, dict):
|
|
||||||
continue
|
continue
|
||||||
parsed = _parse_context(payload.get("context"))
|
parsed = _context_from_state(thread.context)
|
||||||
if parsed is None:
|
if parsed is None:
|
||||||
continue
|
continue
|
||||||
if parsed.project != context.project or parsed.branch != context.branch:
|
if parsed.project != target_project or parsed.branch != target_branch:
|
||||||
continue
|
|
||||||
if not raw_key.startswith(f"{chat_id}:"):
|
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
_, thread_str = raw_key.split(":", 1)
|
_, thread_str = raw_key.split(":", 1)
|
||||||
return int(thread_str)
|
return int(thread_str)
|
||||||
except (ValueError, TypeError):
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _snapshot_locked(
|
def _snapshot_locked(
|
||||||
self, thread: dict[str, Any], chat_id: int, thread_id: int
|
self, thread: _ThreadState, chat_id: int, thread_id: int
|
||||||
) -> TopicThreadSnapshot:
|
) -> TopicThreadSnapshot:
|
||||||
sessions: dict[str, str] = {}
|
sessions = {
|
||||||
raw_sessions = thread.get("sessions")
|
engine: entry.resume
|
||||||
if isinstance(raw_sessions, dict):
|
for engine, entry in thread.sessions.items()
|
||||||
for engine, entry in raw_sessions.items():
|
if entry.resume
|
||||||
if not isinstance(engine, str) or not isinstance(entry, dict):
|
}
|
||||||
continue
|
|
||||||
value = entry.get("resume")
|
|
||||||
if isinstance(value, str) and value:
|
|
||||||
sessions[engine] = value
|
|
||||||
topic_title = thread.get("topic_title")
|
|
||||||
if not isinstance(topic_title, str):
|
|
||||||
topic_title = None
|
|
||||||
return TopicThreadSnapshot(
|
return TopicThreadSnapshot(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
context=_parse_context(thread.get("context")),
|
context=_context_from_state(thread.context),
|
||||||
sessions=sessions,
|
sessions=sessions,
|
||||||
topic_title=topic_title,
|
topic_title=thread.topic_title,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _stat_mtime_ns(self) -> int | None:
|
def _stat_mtime_ns(self) -> int | None:
|
||||||
@@ -226,10 +218,10 @@ class TopicStateStore:
|
|||||||
self._loaded = True
|
self._loaded = True
|
||||||
self._mtime_ns = self._stat_mtime_ns()
|
self._mtime_ns = self._stat_mtime_ns()
|
||||||
if self._mtime_ns is None:
|
if self._mtime_ns is None:
|
||||||
self._data = {"version": STATE_VERSION, "threads": {}}
|
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
payload = json.loads(self._path.read_text(encoding="utf-8"))
|
payload = msgspec.json.decode(self._path.read_bytes(), type=_TopicState)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"telegram.topic_state.load_failed",
|
"telegram.topic_state.load_failed",
|
||||||
@@ -237,29 +229,22 @@ class TopicStateStore:
|
|||||||
error=str(exc),
|
error=str(exc),
|
||||||
error_type=exc.__class__.__name__,
|
error_type=exc.__class__.__name__,
|
||||||
)
|
)
|
||||||
self._data = {"version": STATE_VERSION, "threads": {}}
|
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||||
return
|
return
|
||||||
if not isinstance(payload, dict):
|
if payload.version != STATE_VERSION:
|
||||||
self._data = {"version": STATE_VERSION, "threads": {}}
|
|
||||||
return
|
|
||||||
version = payload.get("version")
|
|
||||||
if version != STATE_VERSION:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"telegram.topic_state.version_mismatch",
|
"telegram.topic_state.version_mismatch",
|
||||||
path=str(self._path),
|
path=str(self._path),
|
||||||
version=version,
|
version=payload.version,
|
||||||
expected=STATE_VERSION,
|
expected=STATE_VERSION,
|
||||||
)
|
)
|
||||||
self._data = {"version": STATE_VERSION, "threads": {}}
|
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||||
return
|
return
|
||||||
threads = payload.get("threads")
|
self._state = payload
|
||||||
if not isinstance(threads, dict):
|
|
||||||
threads = {}
|
|
||||||
self._data = {"version": STATE_VERSION, "threads": threads}
|
|
||||||
|
|
||||||
def _save_locked(self) -> None:
|
def _save_locked(self) -> None:
|
||||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
payload = {"version": STATE_VERSION, "threads": self._data.get("threads", {})}
|
payload = msgspec.to_builtins(self._state)
|
||||||
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
|
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
|
||||||
with open(tmp_path, "w", encoding="utf-8") as handle:
|
with open(tmp_path, "w", encoding="utf-8") as handle:
|
||||||
json.dump(payload, handle, indent=2, sort_keys=True)
|
json.dump(payload, handle, indent=2, sort_keys=True)
|
||||||
@@ -267,22 +252,14 @@ class TopicStateStore:
|
|||||||
os.replace(tmp_path, self._path)
|
os.replace(tmp_path, self._path)
|
||||||
self._mtime_ns = self._stat_mtime_ns()
|
self._mtime_ns = self._stat_mtime_ns()
|
||||||
|
|
||||||
def _get_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any] | None:
|
def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None:
|
||||||
threads = self._data.get("threads")
|
return self._state.threads.get(_thread_key(chat_id, thread_id))
|
||||||
if not isinstance(threads, dict):
|
|
||||||
return None
|
|
||||||
entry = threads.get(_thread_key(chat_id, thread_id))
|
|
||||||
return entry if isinstance(entry, dict) else None
|
|
||||||
|
|
||||||
def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any]:
|
def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState:
|
||||||
threads = self._data.get("threads")
|
|
||||||
if not isinstance(threads, dict):
|
|
||||||
threads = {}
|
|
||||||
self._data["threads"] = threads
|
|
||||||
key = _thread_key(chat_id, thread_id)
|
key = _thread_key(chat_id, thread_id)
|
||||||
entry = threads.get(key)
|
entry = self._state.threads.get(key)
|
||||||
if isinstance(entry, dict):
|
if entry is not None:
|
||||||
return entry
|
return entry
|
||||||
entry = {}
|
entry = _ThreadState()
|
||||||
threads[key] = entry
|
self._state.threads[key] = entry
|
||||||
return entry
|
return entry
|
||||||
|
|||||||
@@ -185,9 +185,9 @@ async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None:
|
|||||||
if not cfg.topics.enabled:
|
if not cfg.topics.enabled:
|
||||||
return
|
return
|
||||||
me = await cfg.bot.get_me()
|
me = await cfg.bot.get_me()
|
||||||
bot_id = me.get("id") if isinstance(me, dict) else None
|
if me is None:
|
||||||
if not isinstance(bot_id, int):
|
|
||||||
raise ConfigError("failed to fetch bot id for topics validation.")
|
raise ConfigError("failed to fetch bot id for topics validation.")
|
||||||
|
bot_id = me.id
|
||||||
scope, chat_ids = _resolve_topics_scope(cfg)
|
scope, chat_ids = _resolve_topics_scope(cfg)
|
||||||
if scope == "projects" and not chat_ids:
|
if scope == "projects" and not chat_ids:
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
@@ -197,37 +197,34 @@ async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None:
|
|||||||
|
|
||||||
for chat_id in chat_ids:
|
for chat_id in chat_ids:
|
||||||
chat = await cfg.bot.get_chat(chat_id)
|
chat = await cfg.bot.get_chat(chat_id)
|
||||||
if not isinstance(chat, dict):
|
if chat is None:
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
f"failed to fetch chat info for topics validation ({chat_id})."
|
f"failed to fetch chat info for topics validation ({chat_id})."
|
||||||
)
|
)
|
||||||
chat_type = chat.get("type")
|
if chat.type != "supergroup":
|
||||||
is_forum = chat.get("is_forum")
|
|
||||||
if chat_type != "supergroup":
|
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"topics enabled but chat is not a supergroup "
|
"topics enabled but chat is not a supergroup "
|
||||||
f"(chat_id={chat_id}); convert the group and enable topics."
|
f"(chat_id={chat_id}); convert the group and enable topics."
|
||||||
)
|
)
|
||||||
if is_forum is not True:
|
if chat.is_forum is not True:
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"topics enabled but chat does not have topics enabled "
|
"topics enabled but chat does not have topics enabled "
|
||||||
f"(chat_id={chat_id}); turn on topics in group settings."
|
f"(chat_id={chat_id}); turn on topics in group settings."
|
||||||
)
|
)
|
||||||
member = await cfg.bot.get_chat_member(chat_id, bot_id)
|
member = await cfg.bot.get_chat_member(chat_id, bot_id)
|
||||||
if not isinstance(member, dict):
|
if member is None:
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"failed to fetch bot permissions "
|
"failed to fetch bot permissions "
|
||||||
f"(chat_id={chat_id}); promote the bot to admin with manage topics."
|
f"(chat_id={chat_id}); promote the bot to admin with manage topics."
|
||||||
)
|
)
|
||||||
status = member.get("status")
|
if member.status == "creator":
|
||||||
if status == "creator":
|
|
||||||
continue
|
continue
|
||||||
if status != "administrator":
|
if member.status != "administrator":
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"topics enabled but bot is not an admin "
|
"topics enabled but bot is not an admin "
|
||||||
f"(chat_id={chat_id}); promote it and grant manage topics."
|
f"(chat_id={chat_id}); promote it and grant manage topics."
|
||||||
)
|
)
|
||||||
if member.get("can_manage_topics") is not True:
|
if member.can_manage_topics is not True:
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"topics enabled but bot lacks manage topics permission "
|
"topics enabled but bot lacks manage topics permission "
|
||||||
f"(chat_id={chat_id}); grant can_manage_topics."
|
f"(chat_id={chat_id}); grant can_manage_topics."
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import io
|
import io
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from ..logging import get_logger
|
from ..logging import get_logger
|
||||||
from openai import AsyncOpenAI, OpenAIError
|
from openai import AsyncOpenAI, OpenAIError
|
||||||
@@ -29,6 +28,7 @@ async def transcribe_voice(
|
|||||||
bot: BotClient,
|
bot: BotClient,
|
||||||
msg: TelegramIncomingMessage,
|
msg: TelegramIncomingMessage,
|
||||||
enabled: bool,
|
enabled: bool,
|
||||||
|
max_bytes: int | None = None,
|
||||||
reply: Callable[..., Awaitable[None]],
|
reply: Callable[..., Awaitable[None]],
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
voice = msg.voice
|
voice = msg.voice
|
||||||
@@ -37,9 +37,24 @@ async def transcribe_voice(
|
|||||||
if not enabled:
|
if not enabled:
|
||||||
await reply(text=VOICE_TRANSCRIPTION_DISABLED_HINT)
|
await reply(text=VOICE_TRANSCRIPTION_DISABLED_HINT)
|
||||||
return None
|
return None
|
||||||
file_info = cast(dict[str, object], await bot.get_file(voice.file_id))
|
if (
|
||||||
file_path = cast(str, file_info["file_path"])
|
max_bytes is not None
|
||||||
audio_bytes = cast(bytes, await bot.download_file(file_path))
|
and voice.file_size is not None
|
||||||
|
and voice.file_size > max_bytes
|
||||||
|
):
|
||||||
|
await reply(text="voice message is too large to transcribe.")
|
||||||
|
return None
|
||||||
|
file_info = await bot.get_file(voice.file_id)
|
||||||
|
if file_info is None:
|
||||||
|
await reply(text="failed to fetch voice file.")
|
||||||
|
return None
|
||||||
|
audio_bytes = await bot.download_file(file_info.file_path)
|
||||||
|
if audio_bytes is None:
|
||||||
|
await reply(text="failed to download voice file.")
|
||||||
|
return None
|
||||||
|
if max_bytes is not None and len(audio_bytes) > max_bytes:
|
||||||
|
await reply(text="voice message is too large to transcribe.")
|
||||||
|
return None
|
||||||
audio_file = io.BytesIO(audio_bytes)
|
audio_file = io.BytesIO(audio_bytes)
|
||||||
audio_file.name = "voice.ogg"
|
audio_file.name = "voice.ogg"
|
||||||
async with AsyncOpenAI(timeout=120) as client:
|
async with AsyncOpenAI(timeout=120) as client:
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from .context import RunContext
|
|||||||
from .directives import format_context_line, parse_context_line, parse_directives
|
from .directives import format_context_line, parse_context_line, parse_directives
|
||||||
from .model import EngineId, ResumeToken
|
from .model import EngineId, ResumeToken
|
||||||
from .plugins import normalize_allowlist
|
from .plugins import normalize_allowlist
|
||||||
from .router import AutoRouter
|
from .router import AutoRouter, EngineStatus
|
||||||
from .runner import Runner
|
from .runner import Runner
|
||||||
from .worktrees import WorktreeError, resolve_run_cwd
|
from .worktrees import WorktreeError, resolve_run_cwd
|
||||||
|
|
||||||
@@ -108,11 +108,14 @@ class TransportRuntime:
|
|||||||
def available_engine_ids(self) -> tuple[EngineId, ...]:
|
def available_engine_ids(self) -> tuple[EngineId, ...]:
|
||||||
return tuple(entry.engine for entry in self._router.available_entries)
|
return tuple(entry.engine for entry in self._router.available_entries)
|
||||||
|
|
||||||
def missing_engine_ids(self) -> tuple[EngineId, ...]:
|
def engine_ids_with_status(self, status: EngineStatus) -> tuple[EngineId, ...]:
|
||||||
return tuple(
|
return tuple(
|
||||||
entry.engine for entry in self._router.entries if not entry.available
|
entry.engine for entry in self._router.entries if entry.status == status
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def missing_engine_ids(self) -> tuple[EngineId, ...]:
|
||||||
|
return self.engine_ids_with_status("missing_cli")
|
||||||
|
|
||||||
def project_aliases(self) -> tuple[str, ...]:
|
def project_aliases(self) -> tuple[str, ...]:
|
||||||
return tuple(project.alias for project in self._projects.projects.values())
|
return tuple(project.alias for project in self._projects.projects.values())
|
||||||
|
|
||||||
|
|||||||
@@ -41,13 +41,13 @@ class TransportBackend(Protocol):
|
|||||||
def interactive_setup(self, *, force: bool) -> bool: ...
|
def interactive_setup(self, *, force: bool) -> bool: ...
|
||||||
|
|
||||||
def lock_token(
|
def lock_token(
|
||||||
self, *, transport_config: dict[str, object], config_path: Path
|
self, *, transport_config: object, config_path: Path
|
||||||
) -> str | None: ...
|
) -> str | None: ...
|
||||||
|
|
||||||
def build_and_run(
|
def build_and_run(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
transport_config: dict[str, object],
|
transport_config: object,
|
||||||
config_path: Path,
|
config_path: Path,
|
||||||
runtime: TransportRuntime,
|
runtime: TransportRuntime,
|
||||||
final_notify: bool,
|
final_notify: bool,
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from takopi.backends import EngineBackend
|
||||||
from takopi.config import dump_toml
|
from takopi.config import dump_toml
|
||||||
from takopi.telegram import onboarding
|
from takopi.telegram import onboarding
|
||||||
from takopi.backends import EngineBackend
|
from takopi.telegram.api_models import User
|
||||||
|
|
||||||
|
|
||||||
def test_mask_token_short() -> None:
|
def test_mask_token_short() -> None:
|
||||||
@@ -91,7 +92,7 @@ def test_interactive_setup_writes_config(monkeypatch, tmp_path) -> None:
|
|||||||
|
|
||||||
def _fake_run(func, *args, **kwargs):
|
def _fake_run(func, *args, **kwargs):
|
||||||
if func is onboarding.get_bot_info:
|
if func is onboarding.get_bot_info:
|
||||||
return {"username": "my_bot"}
|
return User(id=1, username="my_bot")
|
||||||
if func is onboarding.wait_for_chat:
|
if func is onboarding.wait_for_chat:
|
||||||
return onboarding.ChatInfo(
|
return onboarding.ChatInfo(
|
||||||
chat_id=123,
|
chat_id=123,
|
||||||
@@ -136,7 +137,7 @@ def test_interactive_setup_preserves_projects(monkeypatch, tmp_path) -> None:
|
|||||||
|
|
||||||
def _fake_run(func, *args, **kwargs):
|
def _fake_run(func, *args, **kwargs):
|
||||||
if func is onboarding.get_bot_info:
|
if func is onboarding.get_bot_info:
|
||||||
return {"username": "my_bot"}
|
return User(id=1, username="my_bot")
|
||||||
if func is onboarding.wait_for_chat:
|
if func is onboarding.wait_for_chat:
|
||||||
return onboarding.ChatInfo(
|
return onboarding.ChatInfo(
|
||||||
chat_id=123,
|
chat_id=123,
|
||||||
@@ -173,7 +174,7 @@ def test_interactive_setup_no_agents_aborts(monkeypatch, tmp_path) -> None:
|
|||||||
|
|
||||||
def _fake_run(func, *args, **kwargs):
|
def _fake_run(func, *args, **kwargs):
|
||||||
if func is onboarding.get_bot_info:
|
if func is onboarding.get_bot_info:
|
||||||
return {"username": "my_bot"}
|
return User(id=1, username="my_bot")
|
||||||
if func is onboarding.wait_for_chat:
|
if func is onboarding.wait_for_chat:
|
||||||
return onboarding.ChatInfo(
|
return onboarding.ChatInfo(
|
||||||
chat_id=123,
|
chat_id=123,
|
||||||
@@ -211,7 +212,7 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) -
|
|||||||
|
|
||||||
def _fake_run(func, *args, **kwargs):
|
def _fake_run(func, *args, **kwargs):
|
||||||
if func is onboarding.get_bot_info:
|
if func is onboarding.get_bot_info:
|
||||||
return {"username": "my_bot"}
|
return User(id=1, username="my_bot")
|
||||||
if func is onboarding.wait_for_chat:
|
if func is onboarding.wait_for_chat:
|
||||||
return onboarding.ChatInfo(
|
return onboarding.ChatInfo(
|
||||||
chat_id=123,
|
chat_id=123,
|
||||||
@@ -239,7 +240,7 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) -
|
|||||||
def test_capture_chat_id_with_token(monkeypatch) -> None:
|
def test_capture_chat_id_with_token(monkeypatch) -> None:
|
||||||
def _fake_run(func, *args, **kwargs):
|
def _fake_run(func, *args, **kwargs):
|
||||||
if func is onboarding.get_bot_info:
|
if func is onboarding.get_bot_info:
|
||||||
return {"username": "my_bot"}
|
return User(id=1, username="my_bot")
|
||||||
if func is onboarding.wait_for_chat:
|
if func is onboarding.wait_for_chat:
|
||||||
return onboarding.ChatInfo(
|
return onboarding.ChatInfo(
|
||||||
chat_id=456,
|
chat_id=456,
|
||||||
@@ -261,7 +262,9 @@ def test_capture_chat_id_with_token(monkeypatch) -> None:
|
|||||||
|
|
||||||
def test_capture_chat_id_prompts_for_token(monkeypatch) -> None:
|
def test_capture_chat_id_prompts_for_token(monkeypatch) -> None:
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
onboarding, "_prompt_token", lambda _console: ("token", {"username": "bot"})
|
onboarding,
|
||||||
|
"_prompt_token",
|
||||||
|
lambda _console: ("token", User(id=1, username="bot")),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _fake_run(func, *args, **kwargs):
|
def _fake_run(func, *args, **kwargs):
|
||||||
|
|||||||
@@ -9,6 +9,11 @@ from takopi.config import ProjectsConfig
|
|||||||
from takopi.model import EngineId
|
from takopi.model import EngineId
|
||||||
from takopi.router import AutoRouter, RunnerEntry
|
from takopi.router import AutoRouter, RunnerEntry
|
||||||
from takopi.runners.mock import Return, ScriptRunner
|
from takopi.runners.mock import Return, ScriptRunner
|
||||||
|
from takopi.settings import (
|
||||||
|
TelegramFilesSettings,
|
||||||
|
TelegramTopicsSettings,
|
||||||
|
TelegramTransportSettings,
|
||||||
|
)
|
||||||
from takopi.telegram import backend as telegram_backend
|
from takopi.telegram import backend as telegram_backend
|
||||||
from takopi.transport_runtime import TransportRuntime
|
from takopi.transport_runtime import TransportRuntime
|
||||||
|
|
||||||
@@ -20,8 +25,13 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
|
|||||||
missing = ScriptRunner([Return(answer="ok")], engine=pi)
|
missing = ScriptRunner([Return(answer="ok")], engine=pi)
|
||||||
router = AutoRouter(
|
router = AutoRouter(
|
||||||
entries=[
|
entries=[
|
||||||
RunnerEntry(engine=codex, runner=runner, available=True),
|
RunnerEntry(engine=codex, runner=runner),
|
||||||
RunnerEntry(engine=pi, runner=missing, available=False, issue="missing"),
|
RunnerEntry(
|
||||||
|
engine=pi,
|
||||||
|
runner=missing,
|
||||||
|
status="missing_cli",
|
||||||
|
issue="missing",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
default_engine=codex,
|
default_engine=codex,
|
||||||
)
|
)
|
||||||
@@ -40,6 +50,44 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
|
|||||||
assert "projects: `none`" in message
|
assert "projects: `none`" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_startup_message_surfaces_unavailable_engine_reasons(
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
codex = EngineId("codex")
|
||||||
|
pi = EngineId("pi")
|
||||||
|
claude = EngineId("claude")
|
||||||
|
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
||||||
|
bad_cfg = ScriptRunner([Return(answer="ok")], engine=pi)
|
||||||
|
load_err = ScriptRunner([Return(answer="ok")], engine=claude)
|
||||||
|
|
||||||
|
router = AutoRouter(
|
||||||
|
entries=[
|
||||||
|
RunnerEntry(engine=codex, runner=runner),
|
||||||
|
RunnerEntry(engine=pi, runner=bad_cfg, status="bad_config", issue="bad"),
|
||||||
|
RunnerEntry(
|
||||||
|
engine=claude,
|
||||||
|
runner=load_err,
|
||||||
|
status="load_error",
|
||||||
|
issue="failed",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
default_engine=codex,
|
||||||
|
)
|
||||||
|
runtime = TransportRuntime(
|
||||||
|
router=router,
|
||||||
|
projects=ProjectsConfig(projects={}, default_project=None),
|
||||||
|
watch_config=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = telegram_backend._build_startup_message(
|
||||||
|
runtime, startup_pwd=str(tmp_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "agents: `codex" in message
|
||||||
|
assert "misconfigured: pi" in message
|
||||||
|
assert "failed to load: claude" in message
|
||||||
|
|
||||||
|
|
||||||
def test_telegram_backend_build_and_run_wires_config(
|
def test_telegram_backend_build_and_run_wires_config(
|
||||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -55,7 +103,7 @@ def test_telegram_backend_build_and_run_wires_config(
|
|||||||
codex = EngineId("codex")
|
codex = EngineId("codex")
|
||||||
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
||||||
router = AutoRouter(
|
router = AutoRouter(
|
||||||
entries=[RunnerEntry(engine=codex, runner=runner, available=True)],
|
entries=[RunnerEntry(engine=codex, runner=runner)],
|
||||||
default_engine=codex,
|
default_engine=codex,
|
||||||
)
|
)
|
||||||
runtime = TransportRuntime(
|
runtime = TransportRuntime(
|
||||||
@@ -80,13 +128,14 @@ def test_telegram_backend_build_and_run_wires_config(
|
|||||||
monkeypatch.setattr(telegram_backend, "run_main_loop", fake_run_main_loop)
|
monkeypatch.setattr(telegram_backend, "run_main_loop", fake_run_main_loop)
|
||||||
monkeypatch.setattr(telegram_backend, "TelegramClient", _FakeClient)
|
monkeypatch.setattr(telegram_backend, "TelegramClient", _FakeClient)
|
||||||
|
|
||||||
transport_config = {
|
transport_config = TelegramTransportSettings(
|
||||||
"bot_token": "token",
|
bot_token="token",
|
||||||
"chat_id": 321,
|
chat_id=321,
|
||||||
"voice_transcription": True,
|
voice_transcription=True,
|
||||||
"files": {"enabled": True, "allowed_user_ids": [1, 2]},
|
voice_max_bytes=1234,
|
||||||
"topics": {"enabled": True, "scope": "main"},
|
files=TelegramFilesSettings(enabled=True, allowed_user_ids=[1, 2]),
|
||||||
}
|
topics=TelegramTopicsSettings(enabled=True, scope="main"),
|
||||||
|
)
|
||||||
|
|
||||||
telegram_backend.TelegramBackend().build_and_run(
|
telegram_backend.TelegramBackend().build_and_run(
|
||||||
transport_config=transport_config,
|
transport_config=transport_config,
|
||||||
@@ -100,18 +149,19 @@ def test_telegram_backend_build_and_run_wires_config(
|
|||||||
kwargs = captured["kwargs"]
|
kwargs = captured["kwargs"]
|
||||||
assert cfg.chat_id == 321
|
assert cfg.chat_id == 321
|
||||||
assert cfg.voice_transcription is True
|
assert cfg.voice_transcription is True
|
||||||
|
assert cfg.voice_max_bytes == 1234
|
||||||
assert cfg.files.enabled is True
|
assert cfg.files.enabled is True
|
||||||
assert cfg.files.allowed_user_ids == frozenset({1, 2})
|
assert cfg.files.allowed_user_ids == [1, 2]
|
||||||
assert cfg.topics.enabled is True
|
assert cfg.topics.enabled is True
|
||||||
assert cfg.bot.token == "token"
|
assert cfg.bot.token == "token"
|
||||||
assert kwargs["watch_config"] is True
|
assert kwargs["watch_config"] is True
|
||||||
assert kwargs["transport_id"] == "telegram"
|
assert kwargs["transport_id"] == "telegram"
|
||||||
|
|
||||||
|
|
||||||
def test_build_files_config_defaults() -> None:
|
def test_telegram_files_settings_defaults() -> None:
|
||||||
cfg = telegram_backend._build_files_config({})
|
cfg = TelegramFilesSettings()
|
||||||
|
|
||||||
assert cfg.enabled is False
|
assert cfg.enabled is False
|
||||||
assert cfg.auto_put is True
|
assert cfg.auto_put is True
|
||||||
assert cfg.uploads_dir == "incoming"
|
assert cfg.uploads_dir == "incoming"
|
||||||
assert cfg.allowed_user_ids == frozenset()
|
assert cfg.allowed_user_ids == []
|
||||||
|
|||||||
@@ -6,22 +6,30 @@ import anyio
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from takopi import commands, plugins
|
from takopi import commands, plugins
|
||||||
import takopi.telegram.bridge as bridge
|
|
||||||
import takopi.telegram.loop as telegram_loop
|
import takopi.telegram.loop as telegram_loop
|
||||||
import takopi.telegram.commands as telegram_commands
|
import takopi.telegram.commands as telegram_commands
|
||||||
import takopi.telegram.topics as telegram_topics
|
import takopi.telegram.topics as telegram_topics
|
||||||
from takopi.directives import parse_directives
|
from takopi.directives import parse_directives
|
||||||
|
from takopi.telegram.api_models import (
|
||||||
|
Chat,
|
||||||
|
ChatMember,
|
||||||
|
File,
|
||||||
|
ForumTopic,
|
||||||
|
Message,
|
||||||
|
Update,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
from takopi.settings import TelegramFilesSettings, TelegramTopicsSettings
|
||||||
from takopi.telegram.bridge import (
|
from takopi.telegram.bridge import (
|
||||||
TelegramBridgeConfig,
|
TelegramBridgeConfig,
|
||||||
TelegramFilesConfig,
|
|
||||||
TelegramPresenter,
|
TelegramPresenter,
|
||||||
TelegramTransport,
|
TelegramTransport,
|
||||||
build_bot_commands,
|
build_bot_commands,
|
||||||
handle_callback_cancel,
|
handle_callback_cancel,
|
||||||
handle_cancel,
|
handle_cancel,
|
||||||
is_cancel_command,
|
is_cancel_command,
|
||||||
send_with_resume,
|
|
||||||
run_main_loop,
|
run_main_loop,
|
||||||
|
send_with_resume,
|
||||||
)
|
)
|
||||||
from takopi.telegram.client import BotClient
|
from takopi.telegram.client import BotClient
|
||||||
from takopi.telegram.topic_state import TopicStateStore, resolve_state_path
|
from takopi.telegram.topic_state import TopicStateStore, resolve_state_path
|
||||||
@@ -122,13 +130,13 @@ class _FakeBot(BotClient):
|
|||||||
offset: int | None,
|
offset: int | None,
|
||||||
timeout_s: int = 50,
|
timeout_s: int = 50,
|
||||||
allowed_updates: list[str] | None = None,
|
allowed_updates: list[str] | None = None,
|
||||||
) -> list[dict[str, Any]] | None:
|
) -> list[Update] | None:
|
||||||
_ = offset
|
_ = offset
|
||||||
_ = timeout_s
|
_ = timeout_s
|
||||||
_ = allowed_updates
|
_ = allowed_updates
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_file(self, file_id: str) -> dict[str, Any] | None:
|
async def get_file(self, file_id: str) -> File | None:
|
||||||
_ = file_id
|
_ = file_id
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -148,7 +156,7 @@ class _FakeBot(BotClient):
|
|||||||
reply_markup: dict | None = None,
|
reply_markup: dict | None = None,
|
||||||
*,
|
*,
|
||||||
replace_message_id: int | None = None,
|
replace_message_id: int | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> Message:
|
||||||
self.send_calls.append(
|
self.send_calls.append(
|
||||||
{
|
{
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
@@ -162,7 +170,7 @@ class _FakeBot(BotClient):
|
|||||||
"replace_message_id": replace_message_id,
|
"replace_message_id": replace_message_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return {"message_id": 1}
|
return Message(message_id=1)
|
||||||
|
|
||||||
async def send_document(
|
async def send_document(
|
||||||
self,
|
self,
|
||||||
@@ -173,7 +181,7 @@ class _FakeBot(BotClient):
|
|||||||
message_thread_id: int | None = None,
|
message_thread_id: int | None = None,
|
||||||
disable_notification: bool | None = False,
|
disable_notification: bool | None = False,
|
||||||
caption: str | None = None,
|
caption: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> Message:
|
||||||
self.document_calls.append(
|
self.document_calls.append(
|
||||||
{
|
{
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
@@ -185,7 +193,7 @@ class _FakeBot(BotClient):
|
|||||||
"caption": caption,
|
"caption": caption,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return {"message_id": 2}
|
return Message(message_id=2)
|
||||||
|
|
||||||
async def edit_message_text(
|
async def edit_message_text(
|
||||||
self,
|
self,
|
||||||
@@ -197,7 +205,7 @@ class _FakeBot(BotClient):
|
|||||||
reply_markup: dict | None = None,
|
reply_markup: dict | None = None,
|
||||||
*,
|
*,
|
||||||
wait: bool = True,
|
wait: bool = True,
|
||||||
) -> dict[str, Any]:
|
) -> Message:
|
||||||
self.edit_calls.append(
|
self.edit_calls.append(
|
||||||
{
|
{
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
@@ -209,7 +217,7 @@ class _FakeBot(BotClient):
|
|||||||
"wait": wait,
|
"wait": wait,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return {"message_id": message_id}
|
return Message(message_id=message_id)
|
||||||
|
|
||||||
async def delete_message(self, chat_id: int, message_id: int) -> bool:
|
async def delete_message(self, chat_id: int, message_id: int) -> bool:
|
||||||
self.delete_calls.append({"chat_id": chat_id, "message_id": message_id})
|
self.delete_calls.append({"chat_id": chat_id, "message_id": message_id})
|
||||||
@@ -231,26 +239,22 @@ class _FakeBot(BotClient):
|
|||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def get_me(self) -> dict[str, Any] | None:
|
async def get_me(self) -> User | None:
|
||||||
return {"id": 1}
|
return User(id=1, username="bot")
|
||||||
|
|
||||||
async def get_chat(self, chat_id: int) -> dict[str, Any] | None:
|
async def get_chat(self, chat_id: int) -> Chat | None:
|
||||||
_ = chat_id
|
_ = chat_id
|
||||||
return {"id": chat_id, "type": "supergroup", "is_forum": True}
|
return Chat(id=chat_id, type="supergroup", is_forum=True)
|
||||||
|
|
||||||
async def get_chat_member(
|
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
|
||||||
self, chat_id: int, user_id: int
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
_ = chat_id
|
_ = chat_id
|
||||||
_ = user_id
|
_ = user_id
|
||||||
return {"status": "administrator", "can_manage_topics": True}
|
return ChatMember(status="administrator", can_manage_topics=True)
|
||||||
|
|
||||||
async def create_forum_topic(
|
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
|
||||||
self, chat_id: int, name: str
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
_ = chat_id
|
_ = chat_id
|
||||||
_ = name
|
_ = name
|
||||||
return {"message_thread_id": 1}
|
return ForumTopic(message_thread_id=1)
|
||||||
|
|
||||||
async def edit_forum_topic(
|
async def edit_forum_topic(
|
||||||
self, chat_id: int, message_thread_id: int, name: str
|
self, chat_id: int, message_thread_id: int, name: str
|
||||||
@@ -539,10 +543,10 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
|||||||
offset: int | None,
|
offset: int | None,
|
||||||
timeout_s: int = 50,
|
timeout_s: int = 50,
|
||||||
allowed_updates: list[str] | None = None,
|
allowed_updates: list[str] | None = None,
|
||||||
) -> list[dict[str, Any]] | None:
|
) -> list[Update] | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_file(self, file_id: str) -> dict[str, Any] | None:
|
async def get_file(self, file_id: str) -> File | None:
|
||||||
_ = file_id
|
_ = file_id
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -562,7 +566,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
|||||||
reply_markup: dict | None = None,
|
reply_markup: dict | None = None,
|
||||||
*,
|
*,
|
||||||
replace_message_id: int | None = None,
|
replace_message_id: int | None = None,
|
||||||
) -> dict | None:
|
) -> Message | None:
|
||||||
_ = reply_markup
|
_ = reply_markup
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -575,7 +579,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
|||||||
message_thread_id: int | None = None,
|
message_thread_id: int | None = None,
|
||||||
disable_notification: bool | None = False,
|
disable_notification: bool | None = False,
|
||||||
caption: str | None = None,
|
caption: str | None = None,
|
||||||
) -> dict | None:
|
) -> Message | None:
|
||||||
_ = (
|
_ = (
|
||||||
chat_id,
|
chat_id,
|
||||||
filename,
|
filename,
|
||||||
@@ -597,7 +601,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
|||||||
reply_markup: dict | None = None,
|
reply_markup: dict | None = None,
|
||||||
*,
|
*,
|
||||||
wait: bool = True,
|
wait: bool = True,
|
||||||
) -> dict | None:
|
) -> Message | None:
|
||||||
self.edit_calls.append(
|
self.edit_calls.append(
|
||||||
{
|
{
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
@@ -611,7 +615,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
|||||||
)
|
)
|
||||||
if not wait:
|
if not wait:
|
||||||
return None
|
return None
|
||||||
return {"message_id": message_id}
|
return Message(message_id=message_id)
|
||||||
|
|
||||||
async def delete_message(
|
async def delete_message(
|
||||||
self,
|
self,
|
||||||
@@ -629,7 +633,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_me(self) -> dict[str, Any] | None:
|
async def get_me(self) -> User | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
@@ -778,9 +782,9 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
|
|||||||
payload = b"hello"
|
payload = b"hello"
|
||||||
|
|
||||||
class _FileBot(_FakeBot):
|
class _FileBot(_FakeBot):
|
||||||
async def get_file(self, file_id: str) -> dict[str, Any] | None:
|
async def get_file(self, file_id: str) -> File | None:
|
||||||
_ = file_id
|
_ = file_id
|
||||||
return {"file_path": "files/hello.txt"}
|
return File(file_path="files/hello.txt")
|
||||||
|
|
||||||
async def download_file(self, file_path: str) -> bytes | None:
|
async def download_file(self, file_path: str) -> bytes | None:
|
||||||
_ = file_path
|
_ = file_path
|
||||||
@@ -811,7 +815,7 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
|
|||||||
chat_id=123,
|
chat_id=123,
|
||||||
startup_msg="",
|
startup_msg="",
|
||||||
exec_cfg=exec_cfg,
|
exec_cfg=exec_cfg,
|
||||||
files=TelegramFilesConfig(enabled=True),
|
files=TelegramFilesSettings(enabled=True),
|
||||||
)
|
)
|
||||||
msg = TelegramIncomingMessage(
|
msg = TelegramIncomingMessage(
|
||||||
transport="telegram",
|
transport="telegram",
|
||||||
@@ -876,9 +880,9 @@ async def test_handle_file_get_sends_document_for_allowed_user(
|
|||||||
chat_id=123,
|
chat_id=123,
|
||||||
startup_msg="",
|
startup_msg="",
|
||||||
exec_cfg=exec_cfg,
|
exec_cfg=exec_cfg,
|
||||||
files=TelegramFilesConfig(
|
files=TelegramFilesSettings(
|
||||||
enabled=True,
|
enabled=True,
|
||||||
allowed_user_ids=frozenset({42}),
|
allowed_user_ids=[42],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
msg = TelegramIncomingMessage(
|
msg = TelegramIncomingMessage(
|
||||||
@@ -1006,7 +1010,7 @@ def test_topic_title_projects_scope_includes_project() -> None:
|
|||||||
transport = _FakeTransport()
|
transport = _FakeTransport()
|
||||||
cfg = replace(
|
cfg = replace(
|
||||||
_make_cfg(transport),
|
_make_cfg(transport),
|
||||||
topics=bridge.TelegramTopicsConfig(
|
topics=TelegramTopicsSettings(
|
||||||
enabled=True,
|
enabled=True,
|
||||||
scope="projects",
|
scope="projects",
|
||||||
),
|
),
|
||||||
@@ -1277,7 +1281,7 @@ async def test_run_main_loop_persists_topic_sessions_in_project_scope(
|
|||||||
chat_id=123,
|
chat_id=123,
|
||||||
startup_msg="",
|
startup_msg="",
|
||||||
exec_cfg=exec_cfg,
|
exec_cfg=exec_cfg,
|
||||||
topics=bridge.TelegramTopicsConfig(
|
topics=TelegramTopicsSettings(
|
||||||
enabled=True,
|
enabled=True,
|
||||||
scope="projects",
|
scope="projects",
|
||||||
),
|
),
|
||||||
@@ -1363,11 +1367,11 @@ async def test_run_main_loop_batches_media_group_upload(
|
|||||||
}
|
}
|
||||||
|
|
||||||
class _MediaBot(_FakeBot):
|
class _MediaBot(_FakeBot):
|
||||||
async def get_file(self, file_id: str) -> dict[str, Any] | None:
|
async def get_file(self, file_id: str) -> File | None:
|
||||||
file_path = file_map.get(file_id)
|
file_path = file_map.get(file_id)
|
||||||
if file_path is None:
|
if file_path is None:
|
||||||
return None
|
return None
|
||||||
return {"file_path": file_path}
|
return File(file_path=file_path)
|
||||||
|
|
||||||
async def download_file(self, file_path: str) -> bytes | None:
|
async def download_file(self, file_path: str) -> bytes | None:
|
||||||
return payloads.get(file_path)
|
return payloads.get(file_path)
|
||||||
@@ -1397,7 +1401,7 @@ async def test_run_main_loop_batches_media_group_upload(
|
|||||||
chat_id=123,
|
chat_id=123,
|
||||||
startup_msg="",
|
startup_msg="",
|
||||||
exec_cfg=exec_cfg,
|
exec_cfg=exec_cfg,
|
||||||
files=TelegramFilesConfig(enabled=True, auto_put=True),
|
files=TelegramFilesSettings(enabled=True, auto_put=True),
|
||||||
)
|
)
|
||||||
msg1 = TelegramIncomingMessage(
|
msg1 = TelegramIncomingMessage(
|
||||||
transport="telegram",
|
transport="telegram",
|
||||||
|
|||||||
@@ -57,3 +57,175 @@ async def test_no_token_in_logs_on_http_error(
|
|||||||
out = capsys.readouterr().out
|
out = capsys.readouterr().out
|
||||||
assert token not in out
|
assert token not in out
|
||||||
assert "bot[REDACTED]" in out
|
assert "bot[REDACTED]" in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_telegram_429_no_retry_post_form() -> None:
|
||||||
|
calls: list[int] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
calls.append(1)
|
||||||
|
return httpx.Response(
|
||||||
|
429,
|
||||||
|
json={
|
||||||
|
"ok": False,
|
||||||
|
"description": "retry",
|
||||||
|
"parameters": {"retry_after": 2},
|
||||||
|
},
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
|
||||||
|
client = httpx.AsyncClient(transport=transport)
|
||||||
|
try:
|
||||||
|
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||||
|
with pytest.raises(TelegramRetryAfter) as exc:
|
||||||
|
await tg._post_form(
|
||||||
|
"sendDocument",
|
||||||
|
{"chat_id": 1},
|
||||||
|
files={"document": ("note.txt", b"hi")},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
assert exc.value.retry_after == 2
|
||||||
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_telegram_429_defaults_retry_after_on_bad_body() -> None:
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(429, text="nope", request=request)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
|
||||||
|
client = httpx.AsyncClient(transport=transport)
|
||||||
|
try:
|
||||||
|
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||||
|
with pytest.raises(TelegramRetryAfter) as exc:
|
||||||
|
await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||||
|
finally:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
assert exc.value.retry_after == 5.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_telegram_ok_false_returns_none() -> None:
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={"ok": False, "error_code": 400, "description": "bad"},
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
|
||||||
|
client = httpx.AsyncClient(transport=transport)
|
||||||
|
try:
|
||||||
|
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||||
|
result = await tg._post("getUpdates", {"timeout": 1})
|
||||||
|
finally:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_telegram_invalid_payload_returns_none() -> None:
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(200, json=["not", "a", "dict"], request=request)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
|
||||||
|
client = httpx.AsyncClient(transport=transport)
|
||||||
|
try:
|
||||||
|
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||||
|
result = await tg._post("getUpdates", {"timeout": 1})
|
||||||
|
finally:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_telegram_decode_failure_returns_none() -> None:
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={"ok": True, "result": {"username": "bot-only"}},
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
|
||||||
|
client = httpx.AsyncClient(transport=transport)
|
||||||
|
try:
|
||||||
|
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||||
|
result = await tg.get_me()
|
||||||
|
finally:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_telegram_download_file_retries_on_429() -> None:
|
||||||
|
calls: list[int] = []
|
||||||
|
sleeps: list[float] = []
|
||||||
|
|
||||||
|
async def sleep(delay: float) -> None:
|
||||||
|
sleeps.append(delay)
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
calls.append(1)
|
||||||
|
if len(calls) == 1:
|
||||||
|
return httpx.Response(
|
||||||
|
429,
|
||||||
|
json={"ok": False, "parameters": {"retry_after": 3}},
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
return httpx.Response(200, content=b"ok", request=request)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
|
||||||
|
client = httpx.AsyncClient(transport=transport)
|
||||||
|
try:
|
||||||
|
tg = TelegramClient("123:abcDEF_ghij", http_client=client, sleep=sleep)
|
||||||
|
payload = await tg.download_file("path/to/file")
|
||||||
|
finally:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
assert payload == b"ok"
|
||||||
|
assert sleeps == [3.0]
|
||||||
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_telegram_download_file_429_defaults_retry_after_on_bad_body() -> None:
|
||||||
|
sleeps: list[float] = []
|
||||||
|
|
||||||
|
async def sleep(delay: float) -> None:
|
||||||
|
sleeps.append(delay)
|
||||||
|
|
||||||
|
calls: list[int] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
calls.append(1)
|
||||||
|
if len(calls) == 1:
|
||||||
|
return httpx.Response(429, text="nope", request=request)
|
||||||
|
return httpx.Response(200, content=b"ok", request=request)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
|
||||||
|
client = httpx.AsyncClient(transport=transport)
|
||||||
|
try:
|
||||||
|
tg = TelegramClient("123:abcDEF_ghij", http_client=client, sleep=sleep)
|
||||||
|
payload = await tg.download_file("path")
|
||||||
|
finally:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
assert payload == b"ok"
|
||||||
|
assert sleeps == [5.0]
|
||||||
|
assert len(calls) == 2
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Any
|
|||||||
import anyio
|
import anyio
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from takopi.telegram.api_models import File, Message, Update, User
|
||||||
from takopi.telegram.client import BotClient, TelegramClient, TelegramRetryAfter
|
from takopi.telegram.client import BotClient, TelegramClient, TelegramRetryAfter
|
||||||
|
|
||||||
|
|
||||||
@@ -29,7 +30,7 @@ class _FakeBot(BotClient):
|
|||||||
reply_markup: dict | None = None,
|
reply_markup: dict | None = None,
|
||||||
*,
|
*,
|
||||||
replace_message_id: int | None = None,
|
replace_message_id: int | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> Message | None:
|
||||||
_ = reply_to_message_id
|
_ = reply_to_message_id
|
||||||
_ = disable_notification
|
_ = disable_notification
|
||||||
_ = message_thread_id
|
_ = message_thread_id
|
||||||
@@ -38,7 +39,7 @@ class _FakeBot(BotClient):
|
|||||||
_ = reply_markup
|
_ = reply_markup
|
||||||
_ = replace_message_id
|
_ = replace_message_id
|
||||||
self.calls.append("send_message")
|
self.calls.append("send_message")
|
||||||
return {"message_id": 1}
|
return Message(message_id=1)
|
||||||
|
|
||||||
async def send_document(
|
async def send_document(
|
||||||
self,
|
self,
|
||||||
@@ -49,7 +50,7 @@ class _FakeBot(BotClient):
|
|||||||
message_thread_id: int | None = None,
|
message_thread_id: int | None = None,
|
||||||
disable_notification: bool | None = False,
|
disable_notification: bool | None = False,
|
||||||
caption: str | None = None,
|
caption: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> Message | None:
|
||||||
_ = (
|
_ = (
|
||||||
chat_id,
|
chat_id,
|
||||||
filename,
|
filename,
|
||||||
@@ -60,7 +61,7 @@ class _FakeBot(BotClient):
|
|||||||
caption,
|
caption,
|
||||||
)
|
)
|
||||||
self.calls.append("send_document")
|
self.calls.append("send_document")
|
||||||
return {"message_id": 1}
|
return Message(message_id=1)
|
||||||
|
|
||||||
async def edit_message_text(
|
async def edit_message_text(
|
||||||
self,
|
self,
|
||||||
@@ -72,7 +73,7 @@ class _FakeBot(BotClient):
|
|||||||
reply_markup: dict | None = None,
|
reply_markup: dict | None = None,
|
||||||
*,
|
*,
|
||||||
wait: bool = True,
|
wait: bool = True,
|
||||||
) -> dict[str, Any]:
|
) -> Message | None:
|
||||||
_ = chat_id
|
_ = chat_id
|
||||||
_ = message_id
|
_ = message_id
|
||||||
_ = entities
|
_ = entities
|
||||||
@@ -85,7 +86,7 @@ class _FakeBot(BotClient):
|
|||||||
self._edit_attempts += 1
|
self._edit_attempts += 1
|
||||||
raise TelegramRetryAfter(self.retry_after)
|
raise TelegramRetryAfter(self.retry_after)
|
||||||
self._edit_attempts += 1
|
self._edit_attempts += 1
|
||||||
return {"message_id": message_id}
|
return Message(message_id=message_id)
|
||||||
|
|
||||||
async def delete_message(
|
async def delete_message(
|
||||||
self,
|
self,
|
||||||
@@ -113,7 +114,7 @@ class _FakeBot(BotClient):
|
|||||||
offset: int | None,
|
offset: int | None,
|
||||||
timeout_s: int = 50,
|
timeout_s: int = 50,
|
||||||
allowed_updates: list[str] | None = None,
|
allowed_updates: list[str] | None = None,
|
||||||
) -> list[dict[str, Any]] | None:
|
) -> list[Update] | None:
|
||||||
_ = offset
|
_ = offset
|
||||||
_ = timeout_s
|
_ = timeout_s
|
||||||
_ = allowed_updates
|
_ = allowed_updates
|
||||||
@@ -123,7 +124,7 @@ class _FakeBot(BotClient):
|
|||||||
self._updates_attempts += 1
|
self._updates_attempts += 1
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_file(self, file_id: str) -> dict[str, Any] | None:
|
async def get_file(self, file_id: str) -> File | None:
|
||||||
_ = file_id
|
_ = file_id
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -134,8 +135,8 @@ class _FakeBot(BotClient):
|
|||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_me(self) -> dict[str, Any] | None:
|
async def get_me(self) -> User | None:
|
||||||
return {"id": 1}
|
return User(id=1)
|
||||||
|
|
||||||
async def answer_callback_query(
|
async def answer_callback_query(
|
||||||
self,
|
self,
|
||||||
@@ -187,7 +188,7 @@ async def test_edits_coalesce_latest() -> None:
|
|||||||
reply_markup: dict | None = None,
|
reply_markup: dict | None = None,
|
||||||
*,
|
*,
|
||||||
wait: bool = True,
|
wait: bool = True,
|
||||||
) -> dict:
|
) -> Message | None:
|
||||||
if self._block_first:
|
if self._block_first:
|
||||||
self._block_first = False
|
self._block_first = False
|
||||||
self.edit_started.set()
|
self.edit_started.set()
|
||||||
@@ -305,7 +306,8 @@ async def test_retry_after_retries_once() -> None:
|
|||||||
text="retry",
|
text="retry",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result == {"message_id": 1}
|
assert result is not None
|
||||||
|
assert result.message_id == 1
|
||||||
assert bot._edit_attempts == 2
|
assert bot._edit_attempts == 2
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,243 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from takopi.telegram.api_models import (
|
||||||
|
Chat,
|
||||||
|
ChatMember,
|
||||||
|
File,
|
||||||
|
ForumTopic,
|
||||||
|
Message,
|
||||||
|
Update,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
from takopi.telegram.client import BotClient
|
||||||
|
from takopi.telegram.types import TelegramIncomingMessage, TelegramVoice
|
||||||
|
from takopi.telegram.voice import transcribe_voice
|
||||||
|
|
||||||
|
|
||||||
|
class _Bot(BotClient):
|
||||||
|
def __init__(self, *, file_info: File | None, audio: bytes | None) -> None:
|
||||||
|
self._file_info = file_info
|
||||||
|
self._audio = audio
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_updates(
|
||||||
|
self,
|
||||||
|
offset: int | None,
|
||||||
|
timeout_s: int = 50,
|
||||||
|
allowed_updates: list[str] | None = None,
|
||||||
|
) -> list[Update] | None:
|
||||||
|
_ = offset, timeout_s, allowed_updates
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_file(self, file_id: str) -> File | None:
|
||||||
|
_ = file_id
|
||||||
|
return self._file_info
|
||||||
|
|
||||||
|
async def download_file(self, file_path: str) -> bytes | None:
|
||||||
|
_ = file_path
|
||||||
|
return self._audio
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
text: str,
|
||||||
|
reply_to_message_id: int | None = None,
|
||||||
|
disable_notification: bool | None = False,
|
||||||
|
message_thread_id: int | None = None,
|
||||||
|
entities: list[dict] | None = None,
|
||||||
|
parse_mode: str | None = None,
|
||||||
|
reply_markup: dict | None = None,
|
||||||
|
*,
|
||||||
|
replace_message_id: int | None = None,
|
||||||
|
) -> Message | None:
|
||||||
|
_ = (
|
||||||
|
chat_id,
|
||||||
|
text,
|
||||||
|
reply_to_message_id,
|
||||||
|
disable_notification,
|
||||||
|
message_thread_id,
|
||||||
|
entities,
|
||||||
|
parse_mode,
|
||||||
|
reply_markup,
|
||||||
|
replace_message_id,
|
||||||
|
)
|
||||||
|
raise AssertionError("send_message should not be called")
|
||||||
|
|
||||||
|
async def send_document(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
filename: str,
|
||||||
|
content: bytes,
|
||||||
|
reply_to_message_id: int | None = None,
|
||||||
|
message_thread_id: int | None = None,
|
||||||
|
disable_notification: bool | None = False,
|
||||||
|
caption: str | None = None,
|
||||||
|
) -> Message | None:
|
||||||
|
_ = (
|
||||||
|
chat_id,
|
||||||
|
filename,
|
||||||
|
content,
|
||||||
|
reply_to_message_id,
|
||||||
|
message_thread_id,
|
||||||
|
disable_notification,
|
||||||
|
caption,
|
||||||
|
)
|
||||||
|
raise AssertionError("send_document should not be called")
|
||||||
|
|
||||||
|
async def edit_message_text(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
message_id: int,
|
||||||
|
text: str,
|
||||||
|
entities: list[dict] | None = None,
|
||||||
|
parse_mode: str | None = None,
|
||||||
|
reply_markup: dict | None = None,
|
||||||
|
*,
|
||||||
|
wait: bool = True,
|
||||||
|
) -> Message | None:
|
||||||
|
_ = (
|
||||||
|
chat_id,
|
||||||
|
message_id,
|
||||||
|
text,
|
||||||
|
entities,
|
||||||
|
parse_mode,
|
||||||
|
reply_markup,
|
||||||
|
wait,
|
||||||
|
)
|
||||||
|
raise AssertionError("edit_message_text should not be called")
|
||||||
|
|
||||||
|
async def delete_message(self, chat_id: int, message_id: int) -> bool:
|
||||||
|
_ = chat_id, message_id
|
||||||
|
raise AssertionError("delete_message should not be called")
|
||||||
|
|
||||||
|
async def set_my_commands(
|
||||||
|
self,
|
||||||
|
commands: list[dict],
|
||||||
|
*,
|
||||||
|
scope: dict | None = None,
|
||||||
|
language_code: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
_ = commands, scope, language_code
|
||||||
|
raise AssertionError("set_my_commands should not be called")
|
||||||
|
|
||||||
|
async def get_me(self) -> User | None:
|
||||||
|
raise AssertionError("get_me should not be called")
|
||||||
|
|
||||||
|
async def answer_callback_query(
|
||||||
|
self,
|
||||||
|
callback_query_id: str,
|
||||||
|
text: str | None = None,
|
||||||
|
show_alert: bool | None = None,
|
||||||
|
) -> bool:
|
||||||
|
_ = callback_query_id, text, show_alert
|
||||||
|
raise AssertionError("answer_callback_query should not be called")
|
||||||
|
|
||||||
|
async def get_chat(self, chat_id: int) -> Chat | None:
|
||||||
|
_ = chat_id
|
||||||
|
raise AssertionError("get_chat should not be called")
|
||||||
|
|
||||||
|
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
|
||||||
|
_ = chat_id, user_id
|
||||||
|
raise AssertionError("get_chat_member should not be called")
|
||||||
|
|
||||||
|
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
|
||||||
|
_ = chat_id, name
|
||||||
|
raise AssertionError("create_forum_topic should not be called")
|
||||||
|
|
||||||
|
async def edit_forum_topic(
|
||||||
|
self, chat_id: int, message_thread_id: int, name: str
|
||||||
|
) -> bool:
|
||||||
|
_ = chat_id, message_thread_id, name
|
||||||
|
raise AssertionError("edit_forum_topic should not be called")
|
||||||
|
|
||||||
|
|
||||||
|
def _voice_message(*, file_size: int = 123) -> TelegramIncomingMessage:
|
||||||
|
voice = TelegramVoice(
|
||||||
|
file_id="voice-id",
|
||||||
|
mime_type="audio/ogg",
|
||||||
|
file_size=file_size,
|
||||||
|
duration=1,
|
||||||
|
raw={},
|
||||||
|
)
|
||||||
|
return TelegramIncomingMessage(
|
||||||
|
transport="telegram",
|
||||||
|
chat_id=1,
|
||||||
|
message_id=1,
|
||||||
|
text="",
|
||||||
|
reply_to_message_id=None,
|
||||||
|
reply_to_text=None,
|
||||||
|
sender_id=1,
|
||||||
|
voice=voice,
|
||||||
|
raw={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_transcribe_voice_handles_missing_file() -> None:
|
||||||
|
replies: list[str] = []
|
||||||
|
|
||||||
|
async def reply(**kwargs) -> None:
|
||||||
|
replies.append(kwargs["text"])
|
||||||
|
|
||||||
|
bot = _Bot(file_info=None, audio=None)
|
||||||
|
result = await transcribe_voice(
|
||||||
|
bot=bot,
|
||||||
|
msg=_voice_message(),
|
||||||
|
enabled=True,
|
||||||
|
reply=reply,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
assert replies[-1] == "failed to fetch voice file."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_transcribe_voice_handles_missing_download() -> None:
|
||||||
|
replies: list[str] = []
|
||||||
|
|
||||||
|
async def reply(**kwargs) -> None:
|
||||||
|
replies.append(kwargs["text"])
|
||||||
|
|
||||||
|
bot = _Bot(file_info=File(file_path="voice.ogg"), audio=None)
|
||||||
|
result = await transcribe_voice(
|
||||||
|
bot=bot,
|
||||||
|
msg=_voice_message(),
|
||||||
|
enabled=True,
|
||||||
|
reply=reply,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
assert replies[-1] == "failed to download voice file."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_transcribe_voice_rejects_large_voice_without_downloading() -> None:
|
||||||
|
replies: list[str] = []
|
||||||
|
|
||||||
|
async def reply(**kwargs) -> None:
|
||||||
|
replies.append(kwargs["text"])
|
||||||
|
|
||||||
|
class _NoFetchBot(_Bot):
|
||||||
|
async def get_file(self, file_id: str) -> File | None: # type: ignore[override]
|
||||||
|
_ = file_id
|
||||||
|
raise AssertionError("get_file should not be called")
|
||||||
|
|
||||||
|
async def download_file(self, file_path: str) -> bytes | None: # type: ignore[override]
|
||||||
|
_ = file_path
|
||||||
|
raise AssertionError("download_file should not be called")
|
||||||
|
|
||||||
|
bot = _NoFetchBot(file_info=None, audio=None)
|
||||||
|
result = await transcribe_voice(
|
||||||
|
bot=bot,
|
||||||
|
msg=_voice_message(file_size=10_000),
|
||||||
|
enabled=True,
|
||||||
|
max_bytes=100,
|
||||||
|
reply=reply,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
assert replies[-1] == "voice message is too large to transcribe."
|
||||||
@@ -15,14 +15,14 @@ class DummyTransport:
|
|||||||
def interactive_setup(self, *, force: bool) -> bool:
|
def interactive_setup(self, *, force: bool) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def lock_token(self, *, transport_config: dict[str, object], config_path):
|
def lock_token(self, *, transport_config: object, config_path):
|
||||||
_ = transport_config, config_path
|
_ = transport_config, config_path
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def build_and_run(
|
def build_and_run(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
transport_config: dict[str, object],
|
transport_config: object,
|
||||||
config_path,
|
config_path,
|
||||||
runtime,
|
runtime,
|
||||||
final_notify: bool,
|
final_notify: bool,
|
||||||
|
|||||||
Reference in New Issue
Block a user