diff --git a/src/takopi/cli.py b/src/takopi/cli.py deleted file mode 100644 index a680d47..0000000 --- a/src/takopi/cli.py +++ /dev/null @@ -1,1090 +0,0 @@ -from __future__ import annotations - -import os -import re -import sys -import tomllib -from dataclasses import dataclass -from collections.abc import Callable -from importlib.metadata import EntryPoint -from pathlib import Path -from typing import Any, Literal - -import anyio -from functools import partial -from pydantic import BaseModel -import typer - -from . import __version__ -from .config import ( - ConfigError, - HOME_CONFIG_PATH, - dump_toml, - load_or_init_config, - read_config, - write_config, -) -from .config_migrations import migrate_config -from .commands import get_command -from .backends import EngineBackend -from .engines import get_backend, list_backend_ids -from .ids import RESERVED_CHAT_COMMANDS, RESERVED_COMMAND_IDS, RESERVED_ENGINE_IDS -from .lockfile import LockError, LockHandle, acquire_lock, token_fingerprint -from .logging import get_logger, setup_logging -from .runtime_loader import build_runtime_spec, resolve_plugins_allowlist -from .settings import ( - TakopiSettings, - TelegramTopicsSettings, - load_settings, - load_settings_if_exists, - validate_settings_data, -) -from .plugins import ( - COMMAND_GROUP, - ENGINE_GROUP, - TRANSPORT_GROUP, - entrypoint_distribution_name, - get_load_errors, - is_entrypoint_allowed, - list_entrypoints, - normalize_allowlist, -) -from .transports import SetupResult, get_transport -from .utils.git import resolve_default_base, resolve_main_worktree_root -from .telegram import onboarding -from .telegram.client import TelegramClient -from .telegram.topics import _validate_topics_setup_for - -logger = get_logger(__name__) - -_KEY_SEGMENT_RE = re.compile(r"^[A-Za-z0-9_-]+$") -_MISSING = object() -_CONFIG_PATH_OPTION = typer.Option( - None, - "--config-path", - help="Override the default config path.", -) - - -def _load_settings_optional() -> tuple[TakopiSettings | None, Path | None]: - try: - loaded = load_settings_if_exists() - except ConfigError: - return None, None - if loaded is None: - return None, None - return loaded - - -DoctorStatus = Literal["ok", "warning", "error"] - - -@dataclass(frozen=True, slots=True) -class DoctorCheck: - label: str - status: DoctorStatus - detail: str | None = None - - def render(self) -> str: - if self.detail: - return f"- {self.label}: {self.status} ({self.detail})" - return f"- {self.label}: {self.status}" - - -def _print_version_and_exit() -> None: - typer.echo(__version__) - raise typer.Exit() - - -def _version_callback(value: bool) -> None: - if value: - _print_version_and_exit() - - -def _resolve_transport_id(override: str | None) -> str: - if override is not None: - value = override.strip() - if not value: - raise ConfigError("Invalid `--transport`; expected a non-empty string.") - return value - try: - config, _ = load_or_init_config() - except ConfigError: - return "telegram" - raw = config.get("transport") - if not isinstance(raw, str) or not raw.strip(): - return "telegram" - return raw.strip() - - -def acquire_config_lock(config_path: Path, token: str | None) -> LockHandle: - fingerprint = token_fingerprint(token) if token else None - try: - return acquire_lock( - config_path=config_path, - token_fingerprint=fingerprint, - ) - except LockError as exc: - lines = str(exc).splitlines() - if lines: - typer.echo(lines[0], err=True) - if len(lines) > 1: - typer.echo("\n".join(lines[1:]), err=True) - else: - typer.echo("error: unknown error", err=True) - raise typer.Exit(code=1) from exc - - -def _default_engine_for_setup( - override: str | None, - *, - settings: TakopiSettings | None, - config_path: Path | None, -) -> str: - if override: - return override - if settings is None or config_path is None: - return "codex" - value = settings.default_engine - return value - - -def _resolve_setup_engine( - default_engine_override: str | None, -) -> tuple[ - TakopiSettings | None, - Path | None, - list[str] | None, - str, - EngineBackend, -]: - settings_hint, config_hint = _load_settings_optional() - allowlist = resolve_plugins_allowlist(settings_hint) - default_engine = _default_engine_for_setup( - default_engine_override, - settings=settings_hint, - config_path=config_hint, - ) - engine_backend = get_backend(default_engine, allowlist=allowlist) - return settings_hint, config_hint, allowlist, default_engine, engine_backend - - -def _config_path_display(path: Path) -> str: - home = Path.home() - try: - return f"~/{path.relative_to(home)}" - except ValueError: - return str(path) - - -def _should_run_interactive() -> bool: - if os.environ.get("TAKOPI_NO_INTERACTIVE"): - return False - return sys.stdin.isatty() and sys.stdout.isatty() - - -def _setup_needs_config(setup: SetupResult) -> bool: - config_titles = {"create a config", "configure telegram"} - return any(issue.title in config_titles for issue in setup.issues) - - -def _fail_missing_config(path: Path) -> None: - display = _config_path_display(path) - if path.exists(): - typer.echo(f"error: invalid takopi config at {display}", err=True) - else: - typer.echo(f"error: missing takopi config at {display}", err=True) - - -def _doctor_file_checks(settings: TakopiSettings) -> list[DoctorCheck]: - files = settings.transports.telegram.files - if not files.enabled: - return [DoctorCheck("file transfer", "ok", "disabled")] - if files.allowed_user_ids: - count = len(files.allowed_user_ids) - detail = f"restricted to {count} user id(s)" - return [DoctorCheck("file transfer", "ok", detail)] - return [DoctorCheck("file transfer", "warning", "enabled for all users")] - - -def _doctor_voice_checks(settings: TakopiSettings) -> list[DoctorCheck]: - if not settings.transports.telegram.voice_transcription: - return [DoctorCheck("voice transcription", "ok", "disabled")] - if os.environ.get("OPENAI_API_KEY"): - return [DoctorCheck("voice transcription", "ok", "OPENAI_API_KEY set")] - return [DoctorCheck("voice transcription", "error", "OPENAI_API_KEY not set")] - - -async def _doctor_telegram_checks( - token: str, - chat_id: int, - topics: TelegramTopicsSettings, - project_chat_ids: tuple[int, ...], -) -> list[DoctorCheck]: - checks: list[DoctorCheck] = [] - bot = TelegramClient(token) - try: - me = await bot.get_me() - if me is None: - checks.append( - DoctorCheck("telegram token", "error", "failed to fetch bot info") - ) - checks.append(DoctorCheck("chat_id", "error", "skipped (token invalid)")) - if topics.enabled: - checks.append(DoctorCheck("topics", "error", "skipped (token invalid)")) - else: - checks.append(DoctorCheck("topics", "ok", "disabled")) - return checks - bot_label = f"@{me.username}" if me.username else f"id={me.id}" - checks.append(DoctorCheck("telegram token", "ok", bot_label)) - chat = await bot.get_chat(chat_id) - if chat is None: - checks.append(DoctorCheck("chat_id", "error", f"unreachable ({chat_id})")) - else: - checks.append(DoctorCheck("chat_id", "ok", f"{chat.type} ({chat_id})")) - if topics.enabled: - try: - await _validate_topics_setup_for( - bot=bot, - topics=topics, - chat_id=chat_id, - project_chat_ids=project_chat_ids, - ) - checks.append(DoctorCheck("topics", "ok", f"scope={topics.scope}")) - except ConfigError as exc: - checks.append(DoctorCheck("topics", "error", str(exc))) - else: - checks.append(DoctorCheck("topics", "ok", "disabled")) - except Exception as exc: # noqa: BLE001 - checks.append(DoctorCheck("telegram", "error", str(exc))) - finally: - await bot.close() - return checks - - -def _run_auto_router( - *, - default_engine_override: str | None, - transport_override: str | None, - final_notify: bool, - debug: bool, - onboard: bool, -) -> None: - if debug: - os.environ.setdefault("TAKOPI_LOG_FILE", "debug.log") - setup_logging(debug=debug) - lock_handle: LockHandle | None = None - try: - ( - settings_hint, - config_hint, - allowlist, - default_engine, - engine_backend, - ) = _resolve_setup_engine(default_engine_override) - transport_id = _resolve_transport_id(transport_override) - transport_backend = get_transport(transport_id, allowlist=allowlist) - except ConfigError as e: - typer.echo(f"error: {e}", err=True) - raise typer.Exit(code=1) from e - if onboard: - if not _should_run_interactive(): - typer.echo("error: --onboard requires a TTY", err=True) - raise typer.Exit(code=1) - if not anyio.run(partial(transport_backend.interactive_setup, force=True)): - raise typer.Exit(code=1) - ( - settings_hint, - config_hint, - allowlist, - default_engine, - engine_backend, - ) = _resolve_setup_engine(default_engine_override) - setup = transport_backend.check_setup( - engine_backend, - transport_override=transport_override, - ) - if not setup.ok: - if _setup_needs_config(setup) and _should_run_interactive(): - if setup.config_path.exists(): - display = _config_path_display(setup.config_path) - run_onboard = typer.confirm( - f"config at {display} is missing/invalid for " - f"{transport_backend.id}, run onboarding now?", - default=False, - ) - if run_onboard and anyio.run( - partial(transport_backend.interactive_setup, force=True) - ): - ( - settings_hint, - config_hint, - allowlist, - default_engine, - engine_backend, - ) = _resolve_setup_engine(default_engine_override) - setup = transport_backend.check_setup( - engine_backend, - transport_override=transport_override, - ) - elif anyio.run(partial(transport_backend.interactive_setup, force=False)): - ( - settings_hint, - config_hint, - allowlist, - default_engine, - engine_backend, - ) = _resolve_setup_engine(default_engine_override) - setup = transport_backend.check_setup( - engine_backend, - transport_override=transport_override, - ) - if not setup.ok: - if _setup_needs_config(setup): - _fail_missing_config(setup.config_path) - else: - first = setup.issues[0] - typer.echo(f"error: {first.title}", err=True) - raise typer.Exit(code=1) - try: - settings, config_path = load_settings() - if transport_override and transport_override != settings.transport: - settings = settings.model_copy(update={"transport": transport_override}) - spec = build_runtime_spec( - settings=settings, - config_path=config_path, - default_engine_override=default_engine_override, - reserved=RESERVED_CHAT_COMMANDS, - ) - if settings.transport == "telegram": - transport_config = settings.transports.telegram - else: - transport_config = settings.transport_config( - settings.transport, config_path=config_path - ) - lock_token = transport_backend.lock_token( - transport_config=transport_config, - _config_path=config_path, - ) - lock_handle = acquire_config_lock(config_path, lock_token) - runtime = spec.to_runtime(config_path=config_path) - transport_backend.build_and_run( - final_notify=final_notify, - default_engine_override=default_engine_override, - config_path=config_path, - transport_config=transport_config, - runtime=runtime, - ) - except ConfigError as e: - typer.echo(f"error: {e}", err=True) - raise typer.Exit(code=1) from e - except KeyboardInterrupt: - logger.info("shutdown.interrupted") - raise typer.Exit(code=130) from None - finally: - if lock_handle is not None: - lock_handle.release() - - -def _prompt_alias(value: str | None, *, default_alias: str | None = None) -> str: - if value is not None: - alias = value - elif default_alias: - alias = typer.prompt("project alias", default=default_alias) - else: - alias = typer.prompt("project alias") - alias = alias.strip() - if not alias: - typer.echo("error: project alias cannot be empty", err=True) - raise typer.Exit(code=1) - return alias - - -def _default_alias_from_path(path: Path) -> str | None: - name = path.name - if not name: - return None - name = name.removesuffix(".git") - return name or None - - -def _ensure_projects_table(config: dict, config_path: Path) -> dict: - projects = config.setdefault("projects", {}) - if not isinstance(projects, dict): - raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.") - return projects - - -def init( - alias: str | None = typer.Argument( - None, help="Project alias (used as /alias in messages)." - ), - default: bool = typer.Option( - False, - "--default", - help="Set this project as the default_project.", - ), -) -> None: - """Register the current repo as a Takopi project.""" - config, config_path = load_or_init_config() - if config_path.exists(): - applied = migrate_config(config, config_path=config_path) - if applied: - write_config(config, config_path) - - cwd = Path.cwd() - project_path = resolve_main_worktree_root(cwd) or cwd - default_alias = _default_alias_from_path(project_path) - alias = _prompt_alias(alias, default_alias=default_alias) - - settings = validate_settings_data(config, config_path=config_path) - allowlist = resolve_plugins_allowlist(settings) - engine_ids = list_backend_ids(allowlist=allowlist) - projects_cfg = settings.to_projects_config( - config_path=config_path, - engine_ids=engine_ids, - reserved=RESERVED_CHAT_COMMANDS, - ) - - alias_key = alias.lower() - if alias_key in {engine.lower() for engine in engine_ids}: - raise ConfigError( - f"Invalid project alias {alias!r}; aliases must not match engine ids." - ) - if alias_key in RESERVED_CHAT_COMMANDS: - raise ConfigError( - f"Invalid project alias {alias!r}; aliases must not match reserved commands." - ) - - existing = projects_cfg.projects.get(alias_key) - if existing is not None: - overwrite = typer.confirm( - f"project {existing.alias!r} already exists, overwrite?", - default=False, - ) - if not overwrite: - raise typer.Exit(code=1) - - projects = _ensure_projects_table(config, config_path) - if existing is not None and existing.alias in projects: - projects.pop(existing.alias, None) - - default_engine = settings.default_engine - worktree_base = resolve_default_base(project_path) - - entry: dict[str, object] = { - "path": str(project_path), - "worktrees_dir": ".worktrees", - "default_engine": default_engine, - } - if worktree_base: - entry["worktree_base"] = worktree_base - - projects[alias] = entry - if default: - config["default_project"] = alias - - write_config(config, config_path) - typer.echo(f"saved project {alias!r} to {_config_path_display(config_path)}") - - -def chat_id( - token: str | None = typer.Option( - None, - "--token", - help="Telegram bot token (defaults to config if available).", - ), - project: str | None = typer.Option( - None, - "--project", - help="Project alias to print a chat_id snippet for.", - ), -) -> None: - """Capture a Telegram chat id and exit.""" - setup_logging(debug=False, cache_logger_on_first_use=False) - if token is None: - settings, _ = _load_settings_optional() - if settings is not None: - tg = settings.transports.telegram - token = tg.bot_token or None - chat = anyio.run(partial(onboarding.capture_chat_id, token=token)) - if chat is None: - raise typer.Exit(code=1) - if project: - project = project.strip() - if not project: - raise ConfigError("Invalid `--project`; expected a non-empty string.") - - config, config_path = load_or_init_config() - if config_path.exists(): - applied = migrate_config(config, config_path=config_path) - if applied: - write_config(config, config_path) - - projects = _ensure_projects_table(config, config_path) - entry = projects.get(project) - if entry is None: - lowered = project.lower() - for key, value in projects.items(): - if isinstance(key, str) and key.lower() == lowered: - entry = value - project = key - break - if entry is None: - raise ConfigError( - f"Unknown project {project!r}; run `takopi init {project}` first." - ) - if not isinstance(entry, dict): - raise ConfigError( - f"Invalid `projects.{project}` in {config_path}; expected a table." - ) - entry["chat_id"] = chat.chat_id - write_config(config, config_path) - typer.echo(f"updated projects.{project}.chat_id = {chat.chat_id}") - return - - typer.echo(f"chat_id = {chat.chat_id}") - - -def onboarding_paths() -> None: - """Print all possible onboarding paths.""" - setup_logging(debug=False, cache_logger_on_first_use=False) - onboarding.debug_onboarding_paths() - - -def doctor() -> None: - """Run configuration checks for the active transport.""" - setup_logging(debug=False, cache_logger_on_first_use=False) - try: - settings, config_path = load_settings() - except ConfigError as exc: - typer.echo(f"error: {exc}", err=True) - raise typer.Exit(code=1) from exc - - if settings.transport != "telegram": - typer.echo( - "error: takopi doctor currently supports the telegram transport only.", - err=True, - ) - raise typer.Exit(code=1) - - allowlist = resolve_plugins_allowlist(settings) - engine_ids = list_backend_ids(allowlist=allowlist) - try: - projects_cfg = settings.to_projects_config( - config_path=config_path, - engine_ids=engine_ids, - reserved=RESERVED_CHAT_COMMANDS, - ) - except ConfigError as exc: - typer.echo(f"error: {exc}", err=True) - raise typer.Exit(code=1) from exc - - tg = settings.transports.telegram - project_chat_ids = projects_cfg.project_chat_ids() - telegram_checks = anyio.run( - _doctor_telegram_checks, - tg.bot_token, - tg.chat_id, - tg.topics, - project_chat_ids, - ) - if telegram_checks is None: - telegram_checks = [] - checks = [ - *telegram_checks, - *_doctor_file_checks(settings), - *_doctor_voice_checks(settings), - ] - typer.echo("takopi doctor") - for check in checks: - typer.echo(check.render()) - if any(check.status == "error" for check in checks): - raise typer.Exit(code=1) - - -def _print_entrypoints( - label: str, entrypoints: list[EntryPoint], *, allowlist: set[str] | None -) -> None: - typer.echo(f"{label}:") - if not entrypoints: - typer.echo(" (none)") - return - for ep in entrypoints: - dist = entrypoint_distribution_name(ep) or "unknown" - status = "" - if allowlist is not None: - allowed = is_entrypoint_allowed(ep, allowlist) - status = " enabled" if allowed else " disabled" - typer.echo(f" {ep.name} ({dist}){status}") - - -def plugins_cmd( - load: bool = typer.Option( - False, - "--load/--no-load", - help="Load plugins to validate and surface import errors.", - ), -) -> None: - """List discovered plugins and optionally validate them.""" - settings_hint, _ = _load_settings_optional() - allowlist = resolve_plugins_allowlist(settings_hint) - - allowlist_set = normalize_allowlist(allowlist) - engine_eps = list_entrypoints( - ENGINE_GROUP, - reserved_ids=RESERVED_ENGINE_IDS, - ) - transport_eps = list_entrypoints(TRANSPORT_GROUP) - command_eps = list_entrypoints( - COMMAND_GROUP, - reserved_ids=RESERVED_COMMAND_IDS, - ) - - _print_entrypoints("engine backends", engine_eps, allowlist=allowlist_set) - _print_entrypoints("transport backends", transport_eps, allowlist=allowlist_set) - _print_entrypoints("command backends", command_eps, allowlist=allowlist_set) - - if load: - for ep in engine_eps: - if allowlist_set is not None and not is_entrypoint_allowed( - ep, allowlist_set - ): - continue - try: - get_backend(ep.name, allowlist=allowlist) - except ConfigError: - continue - for ep in transport_eps: - if allowlist_set is not None and not is_entrypoint_allowed( - ep, allowlist_set - ): - continue - try: - get_transport(ep.name, allowlist=allowlist) - except ConfigError: - continue - for ep in command_eps: - if allowlist_set is not None and not is_entrypoint_allowed( - ep, allowlist_set - ): - continue - try: - get_command(ep.name, allowlist=allowlist) - except ConfigError: - continue - - errors = get_load_errors() - if errors: - typer.echo("errors:") - for err in errors: - group = err.group - if group == ENGINE_GROUP: - group = "engine" - elif group == TRANSPORT_GROUP: - group = "transport" - elif group == COMMAND_GROUP: - group = "command" - dist = err.distribution or "unknown" - typer.echo(f" {group} {err.name} ({dist}): {err.error}") - - -def _resolve_config_path_override(value: Path | None) -> Path: - if value is None: - return HOME_CONFIG_PATH - return value.expanduser() - - -def _exit_config_error(exc: ConfigError, *, code: int = 2) -> None: - typer.echo(f"error: {exc}", err=True) - raise typer.Exit(code=code) from exc - - -def _parse_key_path(raw: str) -> list[str]: - value = raw.strip() - if not value: - raise ConfigError("Invalid key path; expected a non-empty value.") - segments = value.split(".") - for segment in segments: - if not segment: - raise ConfigError(f"Invalid key path {raw!r}; empty segment.") - if not _KEY_SEGMENT_RE.fullmatch(segment): - raise ConfigError( - f"Invalid key segment {segment!r} in {raw!r}; " - "use only letters, numbers, '_' or '-'." - ) - return segments - - -def _parse_value(raw: str) -> Any: - value = raw.strip() - if not value: - return "" - try: - return tomllib.loads(f"__v__ = {value}")["__v__"] - except tomllib.TOMLDecodeError: - return value - - -def _toml_literal(value: Any) -> str: - dumped = dump_toml({"__v__": value}) - prefix = "__v__ = " - if dumped.startswith(prefix): - return dumped[len(prefix) :].rstrip("\n") - raise ConfigError("Unsupported config value; unable to render TOML literal.") - - -def _normalized_value_from_settings( - settings: TakopiSettings, segments: list[str] -) -> Any: - node: Any = settings - for segment in segments: - if isinstance(node, BaseModel): - if segment in node.__class__.model_fields: - node = getattr(node, segment) - else: - extra = node.model_extra or {} - node = extra.get(segment, _MISSING) - elif isinstance(node, dict): - node = node.get(segment, _MISSING) - else: - return _MISSING - if node is _MISSING: - return _MISSING - if isinstance(node, BaseModel): - return node.model_dump(exclude_unset=True) - return node - - -def _flatten_config(config: dict[str, Any]) -> list[tuple[str, Any]]: - items: list[tuple[str, Any]] = [] - - def _walk(node: Any, prefix: str) -> None: - if isinstance(node, dict): - for key in sorted(node): - value = node[key] - path = f"{prefix}.{key}" if prefix else key - if isinstance(value, dict): - _walk(value, path) - else: - items.append((path, value)) - elif prefix: - items.append((prefix, node)) - - _walk(config, "") - return items - - -def _load_config_or_exit(path: Path, *, missing_code: int) -> dict[str, Any]: - if not path.exists(): - _fail_missing_config(path) - raise typer.Exit(code=missing_code) - try: - return read_config(path) - except ConfigError as exc: - _exit_config_error(exc) - return {} - - -def config_path_cmd( - config_path: Path | None = _CONFIG_PATH_OPTION, -) -> None: - """Print the resolved config path.""" - path = _resolve_config_path_override(config_path) - typer.echo(_config_path_display(path)) - - -def config_list( - config_path: Path | None = _CONFIG_PATH_OPTION, -) -> None: - """List config keys as flattened dot-paths.""" - path = _resolve_config_path_override(config_path) - config = _load_config_or_exit(path, missing_code=1) - try: - for key, value in _flatten_config(config): - literal = _toml_literal(value) - typer.echo(f"{key} = {literal}") - except ConfigError as exc: - _exit_config_error(exc) - - -def config_get( - key: str = typer.Argument(..., help="Dot-path key to fetch."), - config_path: Path | None = _CONFIG_PATH_OPTION, -) -> None: - """Fetch a single config key.""" - path = _resolve_config_path_override(config_path) - config = _load_config_or_exit(path, missing_code=2) - try: - segments = _parse_key_path(key) - except ConfigError as exc: - _exit_config_error(exc) - - node: Any = config - for index, segment in enumerate(segments): - if not isinstance(node, dict): - prefix = ".".join(segments[:index]) - _exit_config_error( - ConfigError(f"Invalid `{prefix}` in {path}; expected a table.") - ) - if segment not in node: - raise typer.Exit(code=1) - node = node[segment] - - if isinstance(node, dict): - typer.echo( - f"error: {'.'.join(segments)!r} is a table; pick a leaf node.", - err=True, - ) - raise typer.Exit(code=2) - - try: - typer.echo(_toml_literal(node)) - except ConfigError as exc: - _exit_config_error(exc) - - -def config_set( - key: str = typer.Argument(..., help="Dot-path key to set."), - value: str = typer.Argument(..., help="Value to assign (auto-parsed)."), - config_path: Path | None = _CONFIG_PATH_OPTION, -) -> None: - """Set a config value.""" - path = _resolve_config_path_override(config_path) - config = _load_config_or_exit(path, missing_code=2) - try: - segments = _parse_key_path(key) - except ConfigError as exc: - _exit_config_error(exc) - - try: - migrate_config(config, config_path=path) - except ConfigError as exc: - _exit_config_error(exc) - - parsed = _parse_value(value) - node: Any = config - for index, segment in enumerate(segments[:-1]): - next_node = node.get(segment) - if next_node is None: - created: dict[str, Any] = {} - node[segment] = created - node = created - continue - if not isinstance(next_node, dict): - prefix = ".".join(segments[: index + 1]) - _exit_config_error( - ConfigError(f"Invalid `{prefix}` in {path}; expected a table.") - ) - node = next_node - node[segments[-1]] = parsed - - try: - settings = validate_settings_data(config, config_path=path) - except ConfigError as exc: - _exit_config_error(exc) - - normalized = _normalized_value_from_settings(settings, segments) - if normalized is not _MISSING: - node[segments[-1]] = normalized - parsed = normalized - - try: - write_config(config, path) - except ConfigError as exc: - _exit_config_error(exc) - - try: - rendered = _toml_literal(parsed) - except ConfigError as exc: - _exit_config_error(exc) - typer.echo(f"updated {'.'.join(segments)} = {rendered}") - - -def config_unset( - key: str = typer.Argument(..., help="Dot-path key to remove."), - config_path: Path | None = _CONFIG_PATH_OPTION, -) -> None: - """Remove a config key.""" - path = _resolve_config_path_override(config_path) - config = _load_config_or_exit(path, missing_code=2) - try: - segments = _parse_key_path(key) - except ConfigError as exc: - _exit_config_error(exc) - - try: - migrate_config(config, config_path=path) - except ConfigError as exc: - _exit_config_error(exc) - - node: Any = config - stack: list[tuple[dict[str, Any], str]] = [] - for index, segment in enumerate(segments[:-1]): - if not isinstance(node, dict): - prefix = ".".join(segments[:index]) - _exit_config_error( - ConfigError(f"Invalid `{prefix}` in {path}; expected a table.") - ) - next_node = node.get(segment) - if next_node is None: - raise typer.Exit(code=1) - if not isinstance(next_node, dict): - prefix = ".".join(segments[: index + 1]) - _exit_config_error( - ConfigError(f"Invalid `{prefix}` in {path}; expected a table.") - ) - stack.append((node, segment)) - node = next_node - - if not isinstance(node, dict): - prefix = ".".join(segments[:-1]) - _exit_config_error( - ConfigError(f"Invalid `{prefix}` in {path}; expected a table.") - ) - leaf = segments[-1] - if leaf not in node: - raise typer.Exit(code=1) - node.pop(leaf, None) - - while stack and not node: - parent, key_name = stack.pop() - parent.pop(key_name, None) - node = parent - - try: - validate_settings_data(config, config_path=path) - write_config(config, path) - except ConfigError as exc: - _exit_config_error(exc) - - -def app_main( - ctx: typer.Context, - version: bool = typer.Option( - False, - "--version", - help="Show the version and exit.", - callback=_version_callback, - is_eager=True, - ), - final_notify: bool = typer.Option( - True, - "--final-notify/--no-final-notify", - help="Send the final response as a new message (not an edit).", - ), - onboard: bool = typer.Option( - False, - "--onboard/--no-onboard", - help="Run the interactive setup wizard before starting.", - ), - transport: str | None = typer.Option( - None, - "--transport", - help="Override the transport backend id.", - ), - debug: bool = typer.Option( - False, - "--debug/--no-debug", - help="Log engine JSONL, Telegram requests, and rendered messages.", - ), -) -> None: - """Takopi CLI.""" - if ctx.invoked_subcommand is None: - _run_auto_router( - default_engine_override=None, - transport_override=transport, - final_notify=final_notify, - debug=debug, - onboard=onboard, - ) - raise typer.Exit() - - -def make_engine_cmd(engine_id: str) -> Callable[..., None]: - def _cmd( - final_notify: bool = typer.Option( - True, - "--final-notify/--no-final-notify", - help="Send the final response as a new message (not an edit).", - ), - onboard: bool = typer.Option( - False, - "--onboard/--no-onboard", - help="Run the interactive setup wizard before starting.", - ), - transport: str | None = typer.Option( - None, - "--transport", - help="Override the transport backend id.", - ), - debug: bool = typer.Option( - False, - "--debug/--no-debug", - help="Log engine JSONL, Telegram requests, and rendered messages.", - ), - ) -> None: - _run_auto_router( - default_engine_override=engine_id, - transport_override=transport, - final_notify=final_notify, - debug=debug, - onboard=onboard, - ) - - _cmd.__name__ = f"run_{engine_id}" - return _cmd - - -def _engine_ids_for_cli() -> list[str]: - allowlist: list[str] | None = None - try: - config, _ = load_or_init_config() - except ConfigError: - return list_backend_ids() - raw_plugins = config.get("plugins") - if isinstance(raw_plugins, dict): - enabled = raw_plugins.get("enabled") - if isinstance(enabled, list): - allowlist = [ - value.strip() - for value in enabled - if isinstance(value, str) and value.strip() - ] - if not allowlist: - allowlist = None - return list_backend_ids(allowlist=allowlist) - - -def create_app() -> typer.Typer: - app = typer.Typer( - add_completion=False, - invoke_without_command=True, - help="Telegram bridge for coding agents. Docs: https://takopi.dev/", - ) - config_app = typer.Typer(help="Read and modify takopi config.") - config_app.command(name="path")(config_path_cmd) - config_app.command(name="list")(config_list) - config_app.command(name="get")(config_get) - config_app.command(name="set")(config_set) - config_app.command(name="unset")(config_unset) - app.command(name="init")(init) - app.command(name="chat-id")(chat_id) - app.command(name="doctor")(doctor) - app.command(name="onboarding-paths")(onboarding_paths) - app.command(name="plugins")(plugins_cmd) - app.add_typer(config_app, name="config") - app.callback()(app_main) - for engine_id in _engine_ids_for_cli(): - help_text = f"Run with the {engine_id} engine." - app.command(name=engine_id, help=help_text)(make_engine_cmd(engine_id)) - return app - - -def main() -> None: - app = create_app() - app() - - -if __name__ == "__main__": - main() diff --git a/src/takopi/cli/__init__.py b/src/takopi/cli/__init__.py new file mode 100644 index 0000000..a9d7ab0 --- /dev/null +++ b/src/takopi/cli/__init__.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +# ruff: noqa: F401 + +from collections.abc import Callable +import sys +from pathlib import Path + +import typer + +from .. import __version__ +from ..config import ( + ConfigError, + HOME_CONFIG_PATH, + load_or_init_config, + write_config, +) +from ..config_migrations import migrate_config +from ..commands import get_command +from ..engines import get_backend, list_backend_ids +from ..ids import RESERVED_CHAT_COMMANDS, RESERVED_COMMAND_IDS, RESERVED_ENGINE_IDS +from ..lockfile import LockError, LockHandle, acquire_lock, token_fingerprint +from ..logging import setup_logging +from ..runtime_loader import build_runtime_spec, resolve_plugins_allowlist +from ..settings import ( + TakopiSettings, + load_settings, + load_settings_if_exists, + validate_settings_data, +) +from ..plugins import ( + COMMAND_GROUP, + ENGINE_GROUP, + TRANSPORT_GROUP, + entrypoint_distribution_name, + get_load_errors, + is_entrypoint_allowed, + list_entrypoints, + normalize_allowlist, +) +from ..transports import get_transport +from ..utils.git import resolve_default_base, resolve_main_worktree_root +from ..telegram import onboarding +from ..telegram.client import TelegramClient +from ..telegram.topics import _validate_topics_setup_for +from .doctor import ( + DoctorCheck, + DoctorStatus, + _doctor_file_checks, + _doctor_telegram_checks, + _doctor_voice_checks, + run_doctor, +) +from .init import ( + _default_alias_from_path, + _ensure_projects_table, + _prompt_alias, + run_init, +) +from .onboarding_cmd import chat_id, onboarding_paths +from .plugins import plugins_cmd +from .run import ( + _default_engine_for_setup, + _print_version_and_exit, + _resolve_setup_engine, + _resolve_transport_id, + _run_auto_router, + _setup_needs_config, + _should_run_interactive, + _version_callback, + acquire_config_lock, + app_main, + make_engine_cmd, +) +from .config import ( + _CONFIG_PATH_OPTION, + _config_path_display, + _exit_config_error, + _fail_missing_config, + _flatten_config, + _load_config_or_exit, + _normalized_value_from_settings, + _parse_key_path, + _parse_value, + _resolve_config_path_override, + _toml_literal, + config_get, + config_list, + config_path_cmd, + config_set, + config_unset, +) + + +def _load_settings_optional() -> tuple[TakopiSettings | None, Path | None]: + try: + loaded = load_settings_if_exists() + except ConfigError: + return None, None + if loaded is None: + return None, None + return loaded + + +def init( + alias: str | None = typer.Argument( + None, help="Project alias (used as /alias in messages)." + ), + default: bool = typer.Option( + False, + "--default", + help="Set this project as the default_project.", + ), +) -> None: + """Register the current repo as a Takopi project.""" + run_init( + alias=alias, + default=default, + load_or_init_config_fn=load_or_init_config, + resolve_main_worktree_root_fn=resolve_main_worktree_root, + resolve_default_base_fn=resolve_default_base, + list_backend_ids_fn=list_backend_ids, + resolve_plugins_allowlist_fn=resolve_plugins_allowlist, + ) + + +def doctor() -> None: + """Run configuration checks for the active transport.""" + setup_logging(debug=False, cache_logger_on_first_use=False) + run_doctor( + load_settings_fn=load_settings, + telegram_checks=_doctor_telegram_checks, + file_checks=_doctor_file_checks, + voice_checks=_doctor_voice_checks, + ) + + +def _engine_ids_for_cli() -> list[str]: + allowlist: list[str] | None = None + try: + config, _ = load_or_init_config() + except ConfigError: + return list_backend_ids() + raw_plugins = config.get("plugins") + if isinstance(raw_plugins, dict): + enabled = raw_plugins.get("enabled") + if isinstance(enabled, list): + allowlist = [ + value.strip() + for value in enabled + if isinstance(value, str) and value.strip() + ] + if not allowlist: + allowlist = None + return list_backend_ids(allowlist=allowlist) + + +def create_app() -> typer.Typer: + app = typer.Typer( + add_completion=False, + invoke_without_command=True, + help="Telegram bridge for coding agents. Docs: https://takopi.dev/", + ) + config_app = typer.Typer(help="Read and modify takopi config.") + config_app.command(name="path")(config_path_cmd) + config_app.command(name="list")(config_list) + config_app.command(name="get")(config_get) + config_app.command(name="set")(config_set) + config_app.command(name="unset")(config_unset) + app.command(name="init")(init) + app.command(name="chat-id")(chat_id) + app.command(name="doctor")(doctor) + app.command(name="onboarding-paths")(onboarding_paths) + app.command(name="plugins")(plugins_cmd) + app.add_typer(config_app, name="config") + app.callback()(app_main) + for engine_id in _engine_ids_for_cli(): + help_text = f"Run with the {engine_id} engine." + app.command(name=engine_id, help=help_text)(make_engine_cmd(engine_id)) + return app + + +def main() -> None: + app = create_app() + app() + + +if __name__ == "__main__": + main() diff --git a/src/takopi/cli/config.py b/src/takopi/cli/config.py new file mode 100644 index 0000000..eb4620b --- /dev/null +++ b/src/takopi/cli/config.py @@ -0,0 +1,320 @@ +from __future__ import annotations + +import re +import sys +import tomllib +from pathlib import Path +from typing import Any + +import typer +from pydantic import BaseModel + +from ..config import ( + ConfigError, + HOME_CONFIG_PATH, + dump_toml, + read_config, + write_config, +) +from ..config_migrations import migrate_config +from ..settings import TakopiSettings, validate_settings_data + +_KEY_SEGMENT_RE = re.compile(r"^[A-Za-z0-9_-]+$") +_MISSING = object() +_CONFIG_PATH_OPTION = typer.Option( + None, + "--config-path", + help="Override the default config path.", +) + + +def _config_path_display(path: Path) -> str: + home = Path.home() + try: + return f"~/{path.relative_to(home)}" + except ValueError: + return str(path) + + +def _fail_missing_config(path: Path) -> None: + display = _config_path_display(path) + if path.exists(): + typer.echo(f"error: invalid takopi config at {display}", err=True) + else: + typer.echo(f"error: missing takopi config at {display}", err=True) + + +def _resolve_config_path_override(value: Path | None) -> Path: + if value is None: + return _resolve_home_config_path() + return value.expanduser() + + +def _resolve_home_config_path() -> Path: + cli_module = sys.modules.get("takopi.cli") + if cli_module is not None: + override = getattr(cli_module, "HOME_CONFIG_PATH", None) + if override is not None: + return Path(override) + return HOME_CONFIG_PATH + + +def _exit_config_error(exc: ConfigError, *, code: int = 2) -> None: + typer.echo(f"error: {exc}", err=True) + raise typer.Exit(code=code) from exc + + +def _parse_key_path(raw: str) -> list[str]: + value = raw.strip() + if not value: + raise ConfigError("Invalid key path; expected a non-empty value.") + segments = value.split(".") + for segment in segments: + if not segment: + raise ConfigError(f"Invalid key path {raw!r}; empty segment.") + if not _KEY_SEGMENT_RE.fullmatch(segment): + raise ConfigError( + f"Invalid key segment {segment!r} in {raw!r}; " + "use only letters, numbers, '_' or '-'." + ) + return segments + + +def _parse_value(raw: str) -> Any: + value = raw.strip() + if not value: + return "" + try: + return tomllib.loads(f"__v__ = {value}")["__v__"] + except tomllib.TOMLDecodeError: + return value + + +def _toml_literal(value: Any) -> str: + dumped = dump_toml({"__v__": value}) + prefix = "__v__ = " + if dumped.startswith(prefix): + return dumped[len(prefix) :].rstrip("\n") + raise ConfigError("Unsupported config value; unable to render TOML literal.") + + +def _normalized_value_from_settings( + settings: TakopiSettings, segments: list[str] +) -> Any: + node: Any = settings + for segment in segments: + if isinstance(node, BaseModel): + if segment in node.__class__.model_fields: + node = getattr(node, segment) + else: + extra = node.model_extra or {} + node = extra.get(segment, _MISSING) + elif isinstance(node, dict): + node = node.get(segment, _MISSING) + else: + return _MISSING + if node is _MISSING: + return _MISSING + if isinstance(node, BaseModel): + return node.model_dump(exclude_unset=True) + return node + + +def _flatten_config(config: dict[str, Any]) -> list[tuple[str, Any]]: + items: list[tuple[str, Any]] = [] + + def _walk(node: Any, prefix: str) -> None: + if isinstance(node, dict): + for key in sorted(node): + value = node[key] + path = f"{prefix}.{key}" if prefix else key + if isinstance(value, dict): + _walk(value, path) + else: + items.append((path, value)) + elif prefix: + items.append((prefix, node)) + + _walk(config, "") + return items + + +def _load_config_or_exit(path: Path, *, missing_code: int) -> dict[str, Any]: + if not path.exists(): + _fail_missing_config(path) + raise typer.Exit(code=missing_code) + try: + return read_config(path) + except ConfigError as exc: + _exit_config_error(exc) + return {} + + +def config_path_cmd( + config_path: Path | None = _CONFIG_PATH_OPTION, +) -> None: + """Print the resolved config path.""" + path = _resolve_config_path_override(config_path) + typer.echo(_config_path_display(path)) + + +def config_list( + config_path: Path | None = _CONFIG_PATH_OPTION, +) -> None: + """List config keys as flattened dot-paths.""" + path = _resolve_config_path_override(config_path) + config = _load_config_or_exit(path, missing_code=1) + try: + for key, value in _flatten_config(config): + literal = _toml_literal(value) + typer.echo(f"{key} = {literal}") + except ConfigError as exc: + _exit_config_error(exc) + + +def config_get( + key: str = typer.Argument(..., help="Dot-path key to fetch."), + config_path: Path | None = _CONFIG_PATH_OPTION, +) -> None: + """Fetch a single config key.""" + path = _resolve_config_path_override(config_path) + config = _load_config_or_exit(path, missing_code=2) + try: + segments = _parse_key_path(key) + except ConfigError as exc: + _exit_config_error(exc) + + node: Any = config + for index, segment in enumerate(segments): + if not isinstance(node, dict): + prefix = ".".join(segments[:index]) + _exit_config_error( + ConfigError(f"Invalid `{prefix}` in {path}; expected a table.") + ) + if segment not in node: + raise typer.Exit(code=1) + node = node[segment] + + if isinstance(node, dict): + typer.echo( + f"error: {'.'.join(segments)!r} is a table; pick a leaf node.", + err=True, + ) + raise typer.Exit(code=2) + + try: + typer.echo(_toml_literal(node)) + except ConfigError as exc: + _exit_config_error(exc) + + +def config_set( + key: str = typer.Argument(..., help="Dot-path key to set."), + value: str = typer.Argument(..., help="Value to assign (auto-parsed)."), + config_path: Path | None = _CONFIG_PATH_OPTION, +) -> None: + """Set a config value.""" + path = _resolve_config_path_override(config_path) + config = _load_config_or_exit(path, missing_code=2) + try: + segments = _parse_key_path(key) + except ConfigError as exc: + _exit_config_error(exc) + + try: + migrate_config(config, config_path=path) + except ConfigError as exc: + _exit_config_error(exc) + + parsed = _parse_value(value) + node: Any = config + for index, segment in enumerate(segments[:-1]): + next_node = node.get(segment) + if next_node is None: + created: dict[str, Any] = {} + node[segment] = created + node = created + continue + if not isinstance(next_node, dict): + prefix = ".".join(segments[: index + 1]) + _exit_config_error( + ConfigError(f"Invalid `{prefix}` in {path}; expected a table.") + ) + node = next_node + node[segments[-1]] = parsed + + try: + settings = validate_settings_data(config, config_path=path) + except ConfigError as exc: + _exit_config_error(exc) + + normalized = _normalized_value_from_settings(settings, segments) + if normalized is not _MISSING: + node[segments[-1]] = normalized + parsed = normalized + + try: + write_config(config, path) + except ConfigError as exc: + _exit_config_error(exc) + + rendered = _toml_literal(parsed) + typer.echo(f"updated {'.'.join(segments)} = {rendered}") + + +def config_unset( + key: str = typer.Argument(..., help="Dot-path key to remove."), + config_path: Path | None = _CONFIG_PATH_OPTION, +) -> None: + """Remove a config key.""" + path = _resolve_config_path_override(config_path) + config = _load_config_or_exit(path, missing_code=2) + try: + segments = _parse_key_path(key) + except ConfigError as exc: + _exit_config_error(exc) + + try: + migrate_config(config, config_path=path) + except ConfigError as exc: + _exit_config_error(exc) + + node: Any = config + stack: list[tuple[dict[str, Any], str]] = [] + for index, segment in enumerate(segments[:-1]): + if not isinstance(node, dict): + prefix = ".".join(segments[:index]) + _exit_config_error( + ConfigError(f"Invalid `{prefix}` in {path}; expected a table.") + ) + next_node = node.get(segment) + if next_node is None: + raise typer.Exit(code=1) + if not isinstance(next_node, dict): + prefix = ".".join(segments[: index + 1]) + _exit_config_error( + ConfigError(f"Invalid `{prefix}` in {path}; expected a table.") + ) + stack.append((node, segment)) + node = next_node + + if not isinstance(node, dict): + prefix = ".".join(segments[:-1]) + _exit_config_error( + ConfigError(f"Invalid `{prefix}` in {path}; expected a table.") + ) + leaf = segments[-1] + if leaf not in node: + raise typer.Exit(code=1) + node.pop(leaf, None) + + while stack and not node: + parent, key_name = stack.pop() + parent.pop(key_name, None) + node = parent + + try: + validate_settings_data(config, config_path=path) + write_config(config, path) + except ConfigError as exc: + _exit_config_error(exc) diff --git a/src/takopi/cli/doctor.py b/src/takopi/cli/doctor.py new file mode 100644 index 0000000..f91c3c2 --- /dev/null +++ b/src/takopi/cli/doctor.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import os +import sys +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import anyio +import typer + +from ..config import ConfigError +from ..engines import list_backend_ids +from ..ids import RESERVED_CHAT_COMMANDS +from ..runtime_loader import resolve_plugins_allowlist +from ..settings import TakopiSettings, TelegramTopicsSettings +from ..telegram.client import TelegramClient +from ..telegram.topics import _validate_topics_setup_for + +DoctorStatus = Literal["ok", "warning", "error"] + + +@dataclass(frozen=True, slots=True) +class DoctorCheck: + label: str + status: DoctorStatus + detail: str | None = None + + def render(self) -> str: + if self.detail: + return f"- {self.label}: {self.status} ({self.detail})" + return f"- {self.label}: {self.status}" + + +def _doctor_file_checks(settings: TakopiSettings) -> list[DoctorCheck]: + files = settings.transports.telegram.files + if not files.enabled: + return [DoctorCheck("file transfer", "ok", "disabled")] + if files.allowed_user_ids: + count = len(files.allowed_user_ids) + detail = f"restricted to {count} user id(s)" + return [DoctorCheck("file transfer", "ok", detail)] + return [DoctorCheck("file transfer", "warning", "enabled for all users")] + + +def _doctor_voice_checks(settings: TakopiSettings) -> list[DoctorCheck]: + if not settings.transports.telegram.voice_transcription: + return [DoctorCheck("voice transcription", "ok", "disabled")] + if os.environ.get("OPENAI_API_KEY"): + return [DoctorCheck("voice transcription", "ok", "OPENAI_API_KEY set")] + return [DoctorCheck("voice transcription", "error", "OPENAI_API_KEY not set")] + + +async def _doctor_telegram_checks( + token: str, + chat_id: int, + topics: TelegramTopicsSettings, + project_chat_ids: tuple[int, ...], +) -> list[DoctorCheck]: + checks: list[DoctorCheck] = [] + client_factory = _resolve_cli_attr("TelegramClient") or TelegramClient + validate_topics = ( + _resolve_cli_attr("_validate_topics_setup_for") or _validate_topics_setup_for + ) + bot = client_factory(token) + try: + me = await bot.get_me() + if me is None: + checks.append( + DoctorCheck("telegram token", "error", "failed to fetch bot info") + ) + checks.append(DoctorCheck("chat_id", "error", "skipped (token invalid)")) + if topics.enabled: + checks.append(DoctorCheck("topics", "error", "skipped (token invalid)")) + else: + checks.append(DoctorCheck("topics", "ok", "disabled")) + return checks + bot_label = f"@{me.username}" if me.username else f"id={me.id}" + checks.append(DoctorCheck("telegram token", "ok", bot_label)) + chat = await bot.get_chat(chat_id) + if chat is None: + checks.append(DoctorCheck("chat_id", "error", f"unreachable ({chat_id})")) + else: + checks.append(DoctorCheck("chat_id", "ok", f"{chat.type} ({chat_id})")) + if topics.enabled: + try: + await validate_topics( + bot=bot, + topics=topics, + chat_id=chat_id, + project_chat_ids=project_chat_ids, + ) + checks.append(DoctorCheck("topics", "ok", f"scope={topics.scope}")) + except ConfigError as exc: + checks.append(DoctorCheck("topics", "error", str(exc))) + else: + checks.append(DoctorCheck("topics", "ok", "disabled")) + except Exception as exc: # noqa: BLE001 + checks.append(DoctorCheck("telegram", "error", str(exc))) + finally: + await bot.close() + return checks + + +def run_doctor( + *, + load_settings_fn: Callable[[], tuple[TakopiSettings, Path]], + telegram_checks: Callable[ + [str, int, TelegramTopicsSettings, tuple[int, ...]], + Awaitable[list[DoctorCheck]], + ], + file_checks: Callable[[TakopiSettings], list[DoctorCheck]], + voice_checks: Callable[[TakopiSettings], list[DoctorCheck]], +) -> None: + try: + settings, config_path = load_settings_fn() + except ConfigError as exc: + typer.echo(f"error: {exc}", err=True) + raise typer.Exit(code=1) from exc + + if settings.transport != "telegram": + typer.echo( + "error: takopi doctor currently supports the telegram transport only.", + err=True, + ) + raise typer.Exit(code=1) + + allowlist = resolve_plugins_allowlist(settings) + engine_ids = list_backend_ids(allowlist=allowlist) + try: + projects_cfg = settings.to_projects_config( + config_path=config_path, + engine_ids=engine_ids, + reserved=RESERVED_CHAT_COMMANDS, + ) + except ConfigError as exc: + typer.echo(f"error: {exc}", err=True) + raise typer.Exit(code=1) from exc + + tg = settings.transports.telegram + project_chat_ids = projects_cfg.project_chat_ids() + telegram_checks_result = anyio.run( + telegram_checks, + tg.bot_token, + tg.chat_id, + tg.topics, + project_chat_ids, + ) + if telegram_checks_result is None: + telegram_checks_result = [] + checks = [ + *telegram_checks_result, + *file_checks(settings), + *voice_checks(settings), + ] + typer.echo("takopi doctor") + for check in checks: + typer.echo(check.render()) + if any(check.status == "error" for check in checks): + raise typer.Exit(code=1) + + +def _resolve_cli_attr(name: str) -> object | None: + cli_module = sys.modules.get("takopi.cli") + if cli_module is None: + return None + return getattr(cli_module, name, None) diff --git a/src/takopi/cli/init.py b/src/takopi/cli/init.py new file mode 100644 index 0000000..f345c25 --- /dev/null +++ b/src/takopi/cli/init.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from collections.abc import Callable +from pathlib import Path + +import typer + +from ..config import ConfigError, write_config +from ..config_migrations import migrate_config +from ..ids import RESERVED_CHAT_COMMANDS +from ..settings import TakopiSettings, validate_settings_data +from .config import _config_path_display + + +def _prompt_alias(value: str | None, *, default_alias: str | None = None) -> str: + if value is not None: + alias = value + elif default_alias: + alias = typer.prompt("project alias", default=default_alias) + else: + alias = typer.prompt("project alias") + alias = alias.strip() + if not alias: + typer.echo("error: project alias cannot be empty", err=True) + raise typer.Exit(code=1) + return alias + + +def _default_alias_from_path(path: Path) -> str | None: + name = path.name + if not name: + return None + name = name.removesuffix(".git") + return name or None + + +def _ensure_projects_table(config: dict, config_path: Path) -> dict: + projects = config.setdefault("projects", {}) + if not isinstance(projects, dict): + raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.") + return projects + + +def run_init( + *, + alias: str | None, + default: bool, + load_or_init_config_fn: Callable[[], tuple[dict, Path]], + resolve_main_worktree_root_fn: Callable[[Path], Path | None], + resolve_default_base_fn: Callable[[Path], str | None], + list_backend_ids_fn: Callable[..., list[str]], + resolve_plugins_allowlist_fn: Callable[[TakopiSettings], list[str] | None], +) -> None: + config, config_path = load_or_init_config_fn() + if config_path.exists(): + applied = migrate_config(config, config_path=config_path) + if applied: + write_config(config, config_path) + + cwd = Path.cwd() + project_path = resolve_main_worktree_root_fn(cwd) or cwd + default_alias = _default_alias_from_path(project_path) + alias = _prompt_alias(alias, default_alias=default_alias) + + settings = validate_settings_data(config, config_path=config_path) + allowlist = resolve_plugins_allowlist_fn(settings) + engine_ids = list_backend_ids_fn(allowlist=allowlist) + projects_cfg = settings.to_projects_config( + config_path=config_path, + engine_ids=engine_ids, + reserved=RESERVED_CHAT_COMMANDS, + ) + + alias_key = alias.lower() + if alias_key in {engine.lower() for engine in engine_ids}: + raise ConfigError( + f"Invalid project alias {alias!r}; aliases must not match engine ids." + ) + if alias_key in RESERVED_CHAT_COMMANDS: + raise ConfigError( + f"Invalid project alias {alias!r}; aliases must not match reserved commands." + ) + + existing = projects_cfg.projects.get(alias_key) + if existing is not None: + overwrite = typer.confirm( + f"project {existing.alias!r} already exists, overwrite?", + default=False, + ) + if not overwrite: + raise typer.Exit(code=1) + + projects = _ensure_projects_table(config, config_path) + if existing is not None and existing.alias in projects: + projects.pop(existing.alias, None) + + default_engine = settings.default_engine + worktree_base = resolve_default_base_fn(project_path) + + entry: dict[str, object] = { + "path": str(project_path), + "worktrees_dir": ".worktrees", + "default_engine": default_engine, + } + if worktree_base: + entry["worktree_base"] = worktree_base + + projects[alias] = entry + if default: + config["default_project"] = alias + + write_config(config, config_path) + typer.echo(f"saved project {alias!r} to {_config_path_display(config_path)}") diff --git a/src/takopi/cli/onboarding_cmd.py b/src/takopi/cli/onboarding_cmd.py new file mode 100644 index 0000000..a9b3980 --- /dev/null +++ b/src/takopi/cli/onboarding_cmd.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import sys +from collections.abc import Callable +from functools import partial +from pathlib import Path +from typing import Any, cast + +import anyio +import typer + +from ..config import ConfigError, load_or_init_config, write_config +from ..config_migrations import migrate_config +from ..logging import setup_logging +from ..settings import TakopiSettings +from ..telegram import onboarding +from .init import _ensure_projects_table +from .run import _load_settings_optional + + +def chat_id( + token: str | None = typer.Option( + None, + "--token", + help="Telegram bot token (defaults to config if available).", + ), + project: str | None = typer.Option( + None, + "--project", + help="Project alias to print a chat_id snippet for.", + ), +) -> None: + """Capture a Telegram chat id and exit.""" + setup_logging_fn = cast( + Callable[..., None], + _resolve_cli_attr("setup_logging") or setup_logging, + ) + load_settings_optional_fn = cast( + Callable[[], tuple[TakopiSettings | None, Path | None]], + _resolve_cli_attr("_load_settings_optional") or _load_settings_optional, + ) + onboarding_mod = cast( + Any, + _resolve_cli_attr("onboarding") or onboarding, + ) + load_or_init_config_fn = cast( + Callable[[], tuple[dict, Path]], + _resolve_cli_attr("load_or_init_config") or load_or_init_config, + ) + ensure_projects_table_fn = cast( + Callable[[dict, Path], dict], + _resolve_cli_attr("_ensure_projects_table") or _ensure_projects_table, + ) + migrate_config_fn = cast( + Callable[..., object], + _resolve_cli_attr("migrate_config") or migrate_config, + ) + write_config_fn = cast( + Callable[[dict, Path], None], + _resolve_cli_attr("write_config") or write_config, + ) + + setup_logging_fn(debug=False, cache_logger_on_first_use=False) + if token is None: + settings, _ = load_settings_optional_fn() + if settings is not None: + tg = settings.transports.telegram + token = tg.bot_token or None + chat = anyio.run(partial(onboarding_mod.capture_chat_id, token=token)) + if chat is None: + raise typer.Exit(code=1) + if project: + project = project.strip() + if not project: + raise ConfigError("Invalid `--project`; expected a non-empty string.") + + config, config_path = load_or_init_config_fn() + if config_path.exists(): + applied = migrate_config_fn(config, config_path=config_path) + if applied: + write_config_fn(config, config_path) + + projects = ensure_projects_table_fn(config, config_path) + entry = projects.get(project) + if entry is None: + lowered = project.lower() + for key, value in projects.items(): + if isinstance(key, str) and key.lower() == lowered: + entry = value + project = key + break + if entry is None: + raise ConfigError( + f"Unknown project {project!r}; run `takopi init {project}` first." + ) + if not isinstance(entry, dict): + raise ConfigError( + f"Invalid `projects.{project}` in {config_path}; expected a table." + ) + entry["chat_id"] = chat.chat_id + write_config_fn(config, config_path) + typer.echo(f"updated projects.{project}.chat_id = {chat.chat_id}") + return + + typer.echo(f"chat_id = {chat.chat_id}") + + +def onboarding_paths() -> None: + """Print all possible onboarding paths.""" + setup_logging_fn = cast( + Callable[..., None], + _resolve_cli_attr("setup_logging") or setup_logging, + ) + onboarding_mod = cast( + Any, + _resolve_cli_attr("onboarding") or onboarding, + ) + setup_logging_fn(debug=False, cache_logger_on_first_use=False) + onboarding_mod.debug_onboarding_paths() + + +def _resolve_cli_attr(name: str) -> object | None: + cli_module = sys.modules.get("takopi.cli") + if cli_module is None: + return None + return getattr(cli_module, name, None) diff --git a/src/takopi/cli/plugins.py b/src/takopi/cli/plugins.py new file mode 100644 index 0000000..20723fd --- /dev/null +++ b/src/takopi/cli/plugins.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import sys +from collections.abc import Callable +from importlib.metadata import EntryPoint +from pathlib import Path +from typing import cast + +import typer + +from ..commands import get_command +from ..config import ConfigError +from ..engines import get_backend +from ..ids import RESERVED_COMMAND_IDS, RESERVED_ENGINE_IDS +from ..plugins import ( + COMMAND_GROUP, + ENGINE_GROUP, + PluginLoadError, + TRANSPORT_GROUP, + entrypoint_distribution_name, + get_load_errors, + is_entrypoint_allowed, + list_entrypoints, + normalize_allowlist, +) +from ..runtime_loader import resolve_plugins_allowlist +from ..settings import TakopiSettings, load_settings_if_exists +from ..transports import get_transport + + +def _load_settings_optional() -> tuple[TakopiSettings | None, Path | None]: + try: + loaded = load_settings_if_exists() + except ConfigError: + return None, None + if loaded is None: + return None, None + return loaded + + +def _print_entrypoints( + label: str, + entrypoints: list[EntryPoint], + *, + allowlist: set[str] | None, + entrypoint_distribution_name_fn: Callable[[EntryPoint], str | None], + is_entrypoint_allowed_fn: Callable[[EntryPoint, set[str] | None], bool], +) -> None: + typer.echo(f"{label}:") + if not entrypoints: + typer.echo(" (none)") + return + for ep in entrypoints: + dist = entrypoint_distribution_name_fn(ep) or "unknown" + status = "" + if allowlist is not None: + allowed = is_entrypoint_allowed_fn(ep, allowlist) + status = " enabled" if allowed else " disabled" + typer.echo(f" {ep.name} ({dist}){status}") + + +def plugins_cmd( + load: bool = typer.Option( + False, + "--load/--no-load", + help="Load plugins to validate and surface import errors.", + ), +) -> None: + """List discovered plugins and optionally validate them.""" + load_settings_optional = cast( + Callable[[], tuple[TakopiSettings | None, Path | None]], + _resolve_cli_attr("_load_settings_optional") or _load_settings_optional, + ) + resolve_plugins_allowlist_fn = cast( + Callable[[TakopiSettings | None], list[str] | None], + _resolve_cli_attr("resolve_plugins_allowlist") or resolve_plugins_allowlist, + ) + list_entrypoints_fn = cast( + Callable[..., list[EntryPoint]], + _resolve_cli_attr("list_entrypoints") or list_entrypoints, + ) + get_backend_fn = cast( + Callable[..., object], + _resolve_cli_attr("get_backend") or get_backend, + ) + get_transport_fn = cast( + Callable[..., object], + _resolve_cli_attr("get_transport") or get_transport, + ) + get_command_fn = cast( + Callable[..., object], + _resolve_cli_attr("get_command") or get_command, + ) + get_load_errors_fn = cast( + Callable[[], tuple[PluginLoadError, ...]], + _resolve_cli_attr("get_load_errors") or get_load_errors, + ) + entrypoint_distribution_name_fn = cast( + Callable[[EntryPoint], str | None], + _resolve_cli_attr("entrypoint_distribution_name") + or entrypoint_distribution_name, + ) + is_entrypoint_allowed_fn = cast( + Callable[[EntryPoint, set[str] | None], bool], + _resolve_cli_attr("is_entrypoint_allowed") or is_entrypoint_allowed, + ) + normalize_allowlist_fn = cast( + Callable[[list[str] | None], set[str] | None], + _resolve_cli_attr("normalize_allowlist") or normalize_allowlist, + ) + + settings_hint, _ = load_settings_optional() + allowlist = resolve_plugins_allowlist_fn(settings_hint) + + allowlist_set = normalize_allowlist_fn(allowlist) + engine_eps = list_entrypoints_fn( + ENGINE_GROUP, + reserved_ids=RESERVED_ENGINE_IDS, + ) + transport_eps = list_entrypoints_fn(TRANSPORT_GROUP) + command_eps = list_entrypoints_fn( + COMMAND_GROUP, + reserved_ids=RESERVED_COMMAND_IDS, + ) + + _print_entrypoints( + "engine backends", + engine_eps, + allowlist=allowlist_set, + entrypoint_distribution_name_fn=entrypoint_distribution_name_fn, + is_entrypoint_allowed_fn=is_entrypoint_allowed_fn, + ) + _print_entrypoints( + "transport backends", + transport_eps, + allowlist=allowlist_set, + entrypoint_distribution_name_fn=entrypoint_distribution_name_fn, + is_entrypoint_allowed_fn=is_entrypoint_allowed_fn, + ) + _print_entrypoints( + "command backends", + command_eps, + allowlist=allowlist_set, + entrypoint_distribution_name_fn=entrypoint_distribution_name_fn, + is_entrypoint_allowed_fn=is_entrypoint_allowed_fn, + ) + + if load: + for ep in engine_eps: + if allowlist_set is not None and not is_entrypoint_allowed_fn( + ep, allowlist_set + ): + continue + try: + get_backend_fn(ep.name, allowlist=allowlist) + except ConfigError: + continue + for ep in transport_eps: + if allowlist_set is not None and not is_entrypoint_allowed_fn( + ep, allowlist_set + ): + continue + try: + get_transport_fn(ep.name, allowlist=allowlist) + except ConfigError: + continue + for ep in command_eps: + if allowlist_set is not None and not is_entrypoint_allowed_fn( + ep, allowlist_set + ): + continue + try: + get_command_fn(ep.name, allowlist=allowlist) + except ConfigError: + continue + + errors = get_load_errors_fn() + if errors: + typer.echo("errors:") + for err in errors: + group = err.group + if group == ENGINE_GROUP: + group = "engine" + elif group == TRANSPORT_GROUP: + group = "transport" + elif group == COMMAND_GROUP: + group = "command" + dist = err.distribution or "unknown" + typer.echo(f" {group} {err.name} ({dist}): {err.error}") + + +def _resolve_cli_attr(name: str) -> object | None: + cli_module = sys.modules.get("takopi.cli") + if cli_module is None: + return None + return getattr(cli_module, name, None) diff --git a/src/takopi/cli/run.py b/src/takopi/cli/run.py new file mode 100644 index 0000000..6e2ac71 --- /dev/null +++ b/src/takopi/cli/run.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +import os +import sys +from collections.abc import Callable +from functools import partial +from pathlib import Path +from typing import Any, cast + +import anyio +import typer + +from .. import __version__ +from ..backends import EngineBackend +from ..config import ConfigError, load_or_init_config +from ..engines import get_backend +from ..ids import RESERVED_CHAT_COMMANDS +from ..lockfile import LockError, LockHandle, acquire_lock, token_fingerprint +from ..logging import get_logger, setup_logging +from ..runtime_loader import build_runtime_spec, resolve_plugins_allowlist +from ..settings import TakopiSettings, load_settings, load_settings_if_exists +from ..transports import SetupResult, get_transport +from .config import _config_path_display, _fail_missing_config + +logger = get_logger(__name__) + + +def _load_settings_optional() -> tuple[TakopiSettings | None, Path | None]: + try: + loaded = load_settings_if_exists() + except ConfigError: + return None, None + if loaded is None: + return None, None + return loaded + + +def _resolve_transport_id(override: str | None) -> str: + if override is not None: + value = override.strip() + if not value: + raise ConfigError("Invalid `--transport`; expected a non-empty string.") + return value + load_or_init_config_fn = cast( + Callable[[], tuple[dict, Path]], + _resolve_cli_attr("load_or_init_config") or load_or_init_config, + ) + try: + config, _ = load_or_init_config_fn() + except ConfigError: + return "telegram" + raw = config.get("transport") + if not isinstance(raw, str) or not raw.strip(): + return "telegram" + return raw.strip() + + +def acquire_config_lock(config_path: Path, token: str | None) -> LockHandle: + fingerprint = token_fingerprint(token) if token else None + acquire_lock_fn = cast( + Callable[..., LockHandle], + _resolve_cli_attr("acquire_lock") or acquire_lock, + ) + try: + return acquire_lock_fn( + config_path=config_path, + token_fingerprint=fingerprint, + ) + except LockError as exc: + lines = str(exc).splitlines() + if lines: + typer.echo(lines[0], err=True) + if len(lines) > 1: + typer.echo("\n".join(lines[1:]), err=True) + else: + typer.echo("error: unknown error", err=True) + raise typer.Exit(code=1) from exc + + +def _default_engine_for_setup( + override: str | None, + *, + settings: TakopiSettings | None, + config_path: Path | None, +) -> str: + if override: + return override + if settings is None or config_path is None: + return "codex" + value = settings.default_engine + return value + + +def _resolve_setup_engine( + default_engine_override: str | None, +) -> tuple[ + TakopiSettings | None, + Path | None, + list[str] | None, + str, + EngineBackend, +]: + load_settings_optional_fn = cast( + Callable[[], tuple[TakopiSettings | None, Path | None]], + _resolve_cli_attr("_load_settings_optional") or _load_settings_optional, + ) + resolve_plugins_allowlist_fn = cast( + Callable[[TakopiSettings | None], list[str] | None], + _resolve_cli_attr("resolve_plugins_allowlist") or resolve_plugins_allowlist, + ) + default_engine_for_setup_fn = cast( + Callable[..., str], + _resolve_cli_attr("_default_engine_for_setup") or _default_engine_for_setup, + ) + get_backend_fn = cast( + Callable[..., EngineBackend], + _resolve_cli_attr("get_backend") or get_backend, + ) + + settings_hint, config_hint = load_settings_optional_fn() + allowlist = resolve_plugins_allowlist_fn(settings_hint) + default_engine = default_engine_for_setup_fn( + default_engine_override, + settings=settings_hint, + config_path=config_hint, + ) + engine_backend = get_backend_fn(default_engine, allowlist=allowlist) + return settings_hint, config_hint, allowlist, default_engine, engine_backend + + +def _should_run_interactive() -> bool: + if os.environ.get("TAKOPI_NO_INTERACTIVE"): + return False + return sys.stdin.isatty() and sys.stdout.isatty() + + +def _setup_needs_config(setup: SetupResult) -> bool: + config_titles = {"create a config", "configure telegram"} + return any(issue.title in config_titles for issue in setup.issues) + + +def _run_auto_router( + *, + default_engine_override: str | None, + transport_override: str | None, + final_notify: bool, + debug: bool, + onboard: bool, +) -> None: + setup_logging_fn = cast( + Callable[..., None], + _resolve_cli_attr("setup_logging") or setup_logging, + ) + resolve_setup_engine_fn = cast( + Callable[ + [str | None], + tuple[ + TakopiSettings | None, + Path | None, + list[str] | None, + str, + EngineBackend, + ], + ], + _resolve_cli_attr("_resolve_setup_engine") or _resolve_setup_engine, + ) + resolve_transport_id_fn = cast( + Callable[[str | None], str], + _resolve_cli_attr("_resolve_transport_id") or _resolve_transport_id, + ) + get_transport_fn = cast( + Callable[..., Any], + _resolve_cli_attr("get_transport") or get_transport, + ) + should_run_interactive_fn = cast( + Callable[[], bool], + _resolve_cli_attr("_should_run_interactive") or _should_run_interactive, + ) + setup_needs_config_fn = cast( + Callable[[SetupResult], bool], + _resolve_cli_attr("_setup_needs_config") or _setup_needs_config, + ) + config_path_display_fn = cast( + Callable[[Path], str], + _resolve_cli_attr("_config_path_display") or _config_path_display, + ) + fail_missing_config_fn = cast( + Callable[[Path], None], + _resolve_cli_attr("_fail_missing_config") or _fail_missing_config, + ) + load_settings_fn = cast( + Callable[[], tuple[TakopiSettings, Path]], + _resolve_cli_attr("load_settings") or load_settings, + ) + build_runtime_spec_fn = cast( + Callable[..., Any], + _resolve_cli_attr("build_runtime_spec") or build_runtime_spec, + ) + acquire_config_lock_fn = cast( + Callable[[Path, str | None], LockHandle], + _resolve_cli_attr("acquire_config_lock") or acquire_config_lock, + ) + + if debug: + os.environ.setdefault("TAKOPI_LOG_FILE", "debug.log") + setup_logging_fn(debug=debug) + lock_handle: LockHandle | None = None + try: + ( + settings_hint, + config_hint, + allowlist, + default_engine, + engine_backend, + ) = resolve_setup_engine_fn(default_engine_override) + transport_id = resolve_transport_id_fn(transport_override) + transport_backend = get_transport_fn(transport_id, allowlist=allowlist) + except ConfigError as exc: + typer.echo(f"error: {exc}", err=True) + raise typer.Exit(code=1) from exc + if onboard: + if not should_run_interactive_fn(): + typer.echo("error: --onboard requires a TTY", err=True) + raise typer.Exit(code=1) + if not anyio.run(partial(transport_backend.interactive_setup, force=True)): + raise typer.Exit(code=1) + ( + settings_hint, + config_hint, + allowlist, + default_engine, + engine_backend, + ) = resolve_setup_engine_fn(default_engine_override) + setup = transport_backend.check_setup( + engine_backend, + transport_override=transport_override, + ) + if not setup.ok: + if setup_needs_config_fn(setup) and should_run_interactive_fn(): + if setup.config_path.exists(): + display = config_path_display_fn(setup.config_path) + run_onboard = typer.confirm( + f"config at {display} is missing/invalid for " + f"{transport_backend.id}, run onboarding now?", + default=False, + ) + if run_onboard and anyio.run( + partial(transport_backend.interactive_setup, force=True) + ): + ( + settings_hint, + config_hint, + allowlist, + default_engine, + engine_backend, + ) = resolve_setup_engine_fn(default_engine_override) + setup = transport_backend.check_setup( + engine_backend, + transport_override=transport_override, + ) + elif anyio.run(partial(transport_backend.interactive_setup, force=False)): + ( + settings_hint, + config_hint, + allowlist, + default_engine, + engine_backend, + ) = resolve_setup_engine_fn(default_engine_override) + setup = transport_backend.check_setup( + engine_backend, + transport_override=transport_override, + ) + if not setup.ok: + if setup_needs_config_fn(setup): + fail_missing_config_fn(setup.config_path) + else: + first = setup.issues[0] + typer.echo(f"error: {first.title}", err=True) + raise typer.Exit(code=1) + try: + settings, config_path = load_settings_fn() + if transport_override and transport_override != settings.transport: + settings = settings.model_copy(update={"transport": transport_override}) + spec = build_runtime_spec_fn( + settings=settings, + config_path=config_path, + default_engine_override=default_engine_override, + reserved=RESERVED_CHAT_COMMANDS, + ) + if settings.transport == "telegram": + transport_config = settings.transports.telegram + else: + transport_config = settings.transport_config( + settings.transport, config_path=config_path + ) + lock_token = transport_backend.lock_token( + transport_config=transport_config, + _config_path=config_path, + ) + lock_handle = acquire_config_lock_fn(config_path, lock_token) + runtime = spec.to_runtime(config_path=config_path) + transport_backend.build_and_run( + final_notify=final_notify, + default_engine_override=default_engine_override, + config_path=config_path, + transport_config=transport_config, + runtime=runtime, + ) + except ConfigError as exc: + typer.echo(f"error: {exc}", err=True) + raise typer.Exit(code=1) from exc + except KeyboardInterrupt: + logger.info("shutdown.interrupted") + raise typer.Exit(code=130) from None + finally: + if lock_handle is not None: + lock_handle.release() + + +def _print_version_and_exit() -> None: + typer.echo(__version__) + raise typer.Exit() + + +def _version_callback(value: bool) -> None: + if value: + _print_version_and_exit() + + +def app_main( + ctx: typer.Context, + version: bool = typer.Option( + False, + "--version", + help="Show the version and exit.", + callback=_version_callback, + is_eager=True, + ), + final_notify: bool = typer.Option( + True, + "--final-notify/--no-final-notify", + help="Send the final response as a new message (not an edit).", + ), + onboard: bool = typer.Option( + False, + "--onboard/--no-onboard", + help="Run the interactive setup wizard before starting.", + ), + transport: str | None = typer.Option( + None, + "--transport", + help="Override the transport backend id.", + ), + debug: bool = typer.Option( + False, + "--debug/--no-debug", + help="Log engine JSONL, Telegram requests, and rendered messages.", + ), +) -> None: + """Takopi CLI.""" + if ctx.invoked_subcommand is None: + run_auto_router = cast( + Callable[..., None], + _resolve_cli_attr("_run_auto_router") or _run_auto_router, + ) + run_auto_router( + default_engine_override=None, + transport_override=transport, + final_notify=final_notify, + debug=debug, + onboard=onboard, + ) + raise typer.Exit() + + +def make_engine_cmd(engine_id: str) -> Callable[..., None]: + def _cmd( + final_notify: bool = typer.Option( + True, + "--final-notify/--no-final-notify", + help="Send the final response as a new message (not an edit).", + ), + onboard: bool = typer.Option( + False, + "--onboard/--no-onboard", + help="Run the interactive setup wizard before starting.", + ), + transport: str | None = typer.Option( + None, + "--transport", + help="Override the transport backend id.", + ), + debug: bool = typer.Option( + False, + "--debug/--no-debug", + help="Log engine JSONL, Telegram requests, and rendered messages.", + ), + ) -> None: + run_auto_router = cast( + Callable[..., None], + _resolve_cli_attr("_run_auto_router") or _run_auto_router, + ) + run_auto_router( + default_engine_override=engine_id, + transport_override=transport, + final_notify=final_notify, + debug=debug, + onboard=onboard, + ) + + _cmd.__name__ = f"run_{engine_id}" + return _cmd + + +def _resolve_cli_attr(name: str) -> object | None: + cli_module = sys.modules.get("takopi.cli") + if cli_module is None: + return None + return getattr(cli_module, name, None) diff --git a/src/takopi/runner.py b/src/takopi/runner.py index a65649e..da1e974 100644 --- a/src/takopi/runner.py +++ b/src/takopi/runner.py @@ -131,6 +131,15 @@ class JsonlRunState: note_seq: int = 0 +@dataclass(slots=True) +class JsonlStreamState: + expected_session: ResumeToken | None + found_session: ResumeToken | None = None + did_emit_completed: bool = False + ignored_after_completed: bool = False + jsonl_seq: int = 0 + + class JsonlSubprocessRunner(BaseRunner): def get_logger(self) -> Any: return getattr(self, "logger", get_logger(__name__)) @@ -340,6 +349,250 @@ class JsonlSubprocessRunner(BaseRunner): raise RuntimeError(message) return found_session, False + async def _send_payload( + self, + proc: Any, + payload: bytes | None, + *, + logger: Any, + resume: ResumeToken | None, + ) -> None: + if payload is not None: + assert proc.stdin is not None + await proc.stdin.send(payload) + await proc.stdin.aclose() + logger.info( + "subprocess.stdin.send", + pid=proc.pid, + resume=resume.value if resume else None, + bytes=len(payload), + ) + elif proc.stdin is not None: + await proc.stdin.aclose() + + def _decode_jsonl_events( + self, + *, + raw_line: bytes, + line: bytes, + jsonl_seq: int, + state: Any, + resume: ResumeToken | None, + found_session: ResumeToken | None, + logger: Any, + pid: int, + ) -> list[TakopiEvent]: + raw_text = raw_line.decode("utf-8", errors="replace") + line_text = line.decode("utf-8", errors="replace") + try: + decoded = self.decode_jsonl(line=line) + except Exception as exc: # noqa: BLE001 + log_pipeline( + logger, + "jsonl.parse.error", + pid=pid, + jsonl_seq=jsonl_seq, + line=line_text, + error=str(exc), + ) + return self.decode_error_events( + raw=raw_text, + line=line_text, + error=exc, + state=state, + ) + if decoded is None: + log_pipeline( + logger, + "jsonl.parse.invalid", + pid=pid, + jsonl_seq=jsonl_seq, + line=line_text, + ) + logger.info( + "runner.jsonl.invalid", + pid=pid, + jsonl_seq=jsonl_seq, + line=line_text, + ) + return self.invalid_json_events( + raw=raw_text, + line=line_text, + state=state, + ) + try: + return self.translate( + decoded, + state=state, + resume=resume, + found_session=found_session, + ) + except Exception as exc: # noqa: BLE001 + log_pipeline( + logger, + "runner.translate.error", + pid=pid, + jsonl_seq=jsonl_seq, + error=str(exc), + ) + return self.translate_error_events( + data=decoded, + error=exc, + state=state, + ) + + def _process_started_event( + self, + event: StartedEvent, + *, + expected_session: ResumeToken | None, + found_session: ResumeToken | None, + logger: Any, + pid: int, + jsonl_seq: int, + ) -> tuple[ResumeToken | None, bool]: + prior_found = found_session + try: + found_session, emit = self.handle_started_event( + event, + expected_session=expected_session, + found_session=found_session, + ) + except Exception as exc: + log_pipeline( + logger, + "runner.started.error", + pid=pid, + jsonl_seq=jsonl_seq, + resume=event.resume.value, + expected_session=expected_session.value if expected_session else None, + found_session=prior_found.value if prior_found else None, + error=str(exc), + ) + raise + if prior_found is None and emit: + reason = ( + "matched_expected" if expected_session is not None else "first_seen" + ) + elif prior_found is not None and not emit: + reason = "duplicate" + else: + reason = "unknown" + log_pipeline( + logger, + "runner.started.seen", + pid=pid, + jsonl_seq=jsonl_seq, + resume=event.resume.value, + expected_session=expected_session.value if expected_session else None, + found_session=found_session.value if found_session else None, + emit=emit, + reason=reason, + ) + return found_session, emit + + def _log_completed_event( + self, + *, + logger: Any, + pid: int, + event: CompletedEvent, + jsonl_seq: int | None = None, + source: str | None = None, + ) -> None: + payload: dict[str, Any] = { + "pid": pid, + "ok": event.ok, + "has_answer": bool(event.answer.strip()), + "emit": True, + } + if jsonl_seq is not None: + payload["jsonl_seq"] = jsonl_seq + if source is not None: + payload["source"] = source + log_pipeline(logger, "runner.completed.seen", **payload) + + def _handle_jsonl_line( + self, + *, + raw_line: bytes, + stream: JsonlStreamState, + state: Any, + resume: ResumeToken | None, + logger: Any, + pid: int, + ) -> list[TakopiEvent]: + if stream.did_emit_completed: + if not stream.ignored_after_completed: + log_pipeline( + logger, + "runner.drop.jsonl_after_completed", + pid=pid, + ) + stream.ignored_after_completed = True + return [] + line = raw_line.strip() + if not line: + return [] + stream.jsonl_seq += 1 + seq = stream.jsonl_seq + events = self._decode_jsonl_events( + raw_line=raw_line, + line=line, + jsonl_seq=seq, + state=state, + resume=resume, + found_session=stream.found_session, + logger=logger, + pid=pid, + ) + output: list[TakopiEvent] = [] + for evt in events: + if isinstance(evt, StartedEvent): + stream.found_session, emit = self._process_started_event( + evt, + expected_session=stream.expected_session, + found_session=stream.found_session, + logger=logger, + pid=pid, + jsonl_seq=seq, + ) + if not emit: + continue + if isinstance(evt, CompletedEvent): + stream.did_emit_completed = True + self._log_completed_event( + logger=logger, + pid=pid, + event=evt, + jsonl_seq=seq, + ) + output.append(evt) + break + output.append(evt) + return output + + async def _iter_jsonl_events( + self, + *, + stdout: Any, + stream: JsonlStreamState, + state: Any, + resume: ResumeToken | None, + logger: Any, + pid: int, + ) -> AsyncIterator[TakopiEvent]: + async for raw_line in self.iter_json_lines(stdout): + for evt in self._handle_jsonl_line( + raw_line=raw_line, + stream=stream, + state=state, + resume=resume, + logger=logger, + pid=pid, + ): + yield evt + async def run_impl( self, prompt: str, resume: ResumeToken | None ) -> AsyncIterator[TakopiEvent]: @@ -381,25 +634,10 @@ class JsonlSubprocessRunner(BaseRunner): pid=proc.pid, ) - if payload is not None: - assert proc.stdin is not None - await proc.stdin.send(payload) - await proc.stdin.aclose() - logger.info( - "subprocess.stdin.send", - pid=proc.pid, - resume=resume.value if resume else None, - bytes=len(payload), - ) - elif proc.stdin is not None: - await proc.stdin.aclose() + await self._send_payload(proc, payload, logger=logger, resume=resume) rc: int | None = None - expected_session: ResumeToken | None = resume - found_session: ResumeToken | None = None - did_emit_completed = False - ignored_after_completed = False - jsonl_seq = 0 + stream = JsonlStreamState(expected_session=resume) async with anyio.create_task_group() as tg: tg.start_soon( @@ -408,154 +646,22 @@ class JsonlSubprocessRunner(BaseRunner): logger, tag, ) - async for raw_line in self.iter_json_lines(proc.stdout): - if did_emit_completed: - if not ignored_after_completed: - log_pipeline( - logger, - "runner.drop.jsonl_after_completed", - pid=proc.pid, - ) - ignored_after_completed = True - continue - line = raw_line.strip() - if not line: - continue - jsonl_seq += 1 - seq = jsonl_seq - raw_text = raw_line.decode("utf-8", errors="replace") - line_text = line.decode("utf-8", errors="replace") - try: - decoded = self.decode_jsonl(line=line) - except Exception as exc: # noqa: BLE001 - log_pipeline( - logger, - "jsonl.parse.error", - pid=proc.pid, - jsonl_seq=seq, - line=line_text, - error=str(exc), - ) - events = self.decode_error_events( - raw=raw_text, - line=line_text, - error=exc, - state=state, - ) - else: - if decoded is None: - log_pipeline( - logger, - "jsonl.parse.invalid", - pid=proc.pid, - jsonl_seq=seq, - line=line_text, - ) - logger.info( - "runner.jsonl.invalid", - pid=proc.pid, - jsonl_seq=seq, - line=line_text, - ) - events = self.invalid_json_events( - raw=raw_text, - line=line_text, - state=state, - ) - else: - try: - events = self.translate( - decoded, - state=state, - resume=resume, - found_session=found_session, - ) - except Exception as exc: # noqa: BLE001 - log_pipeline( - logger, - "runner.translate.error", - pid=proc.pid, - jsonl_seq=seq, - error=str(exc), - ) - events = self.translate_error_events( - data=decoded, - error=exc, - state=state, - ) - - for evt in events: - if isinstance(evt, StartedEvent): - prior_found = found_session - try: - found_session, emit = self.handle_started_event( - evt, - expected_session=expected_session, - found_session=found_session, - ) - except Exception as exc: - log_pipeline( - logger, - "runner.started.error", - pid=proc.pid, - jsonl_seq=seq, - resume=evt.resume.value, - expected_session=expected_session.value - if expected_session - else None, - found_session=prior_found.value - if prior_found - else None, - error=str(exc), - ) - raise - if prior_found is None and emit: - reason = ( - "matched_expected" - if expected_session is not None - else "first_seen" - ) - elif prior_found is not None and not emit: - reason = "duplicate" - else: - reason = "unknown" - log_pipeline( - logger, - "runner.started.seen", - pid=proc.pid, - jsonl_seq=seq, - resume=evt.resume.value, - expected_session=expected_session.value - if expected_session - else None, - found_session=found_session.value - if found_session - else None, - emit=emit, - reason=reason, - ) - if not emit: - continue - if isinstance(evt, CompletedEvent): - did_emit_completed = True - log_pipeline( - logger, - "runner.completed.seen", - pid=proc.pid, - jsonl_seq=seq, - ok=evt.ok, - has_answer=bool(evt.answer.strip()), - emit=True, - ) - yield evt - break - yield evt + async for evt in self._iter_jsonl_events( + stdout=proc.stdout, + stream=stream, + state=state, + resume=resume, + logger=logger, + pid=proc.pid, + ): + yield evt rc = await proc.wait() logger.info("subprocess.exit", pid=proc.pid, rc=rc) - if did_emit_completed: + if stream.did_emit_completed: return + found_session = stream.found_session if rc is not None and rc != 0: events = self.process_error_events( rc, @@ -565,13 +671,10 @@ class JsonlSubprocessRunner(BaseRunner): ) for evt in events: if isinstance(evt, CompletedEvent): - log_pipeline( - logger, - "runner.completed.seen", + self._log_completed_event( + logger=logger, pid=proc.pid, - ok=evt.ok, - has_answer=bool(evt.answer.strip()), - emit=True, + event=evt, source="process_error", ) yield evt @@ -584,13 +687,10 @@ class JsonlSubprocessRunner(BaseRunner): ) for evt in events: if isinstance(evt, CompletedEvent): - log_pipeline( - logger, - "runner.completed.seen", + self._log_completed_event( + logger=logger, pid=proc.pid, - ok=evt.ok, - has_answer=bool(evt.answer.strip()), - emit=True, + event=evt, source="stream_end", ) yield evt diff --git a/src/takopi/telegram/commands/agent.py b/src/takopi/telegram/commands/agent.py index 2c6bc6d..a01ae68 100644 --- a/src/takopi/telegram/commands/agent.py +++ b/src/takopi/telegram/commands/agent.py @@ -27,10 +27,7 @@ async def _check_agent_permissions( if sender_id is None: await reply(text="cannot verify sender for agent defaults.") return False - is_private = msg.chat_type == "private" - if msg.chat_type is None: - is_private = msg.chat_id > 0 - if is_private: + if msg.is_private: return True member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) if member is None: diff --git a/src/takopi/telegram/commands/file_transfer.py b/src/takopi/telegram/commands/file_transfer.py index 43266ec..74e897b 100644 --- a/src/takopi/telegram/commands/file_transfer.py +++ b/src/takopi/telegram/commands/file_transfer.py @@ -107,10 +107,7 @@ async def _check_file_permissions( await reply(text="file transfer is not allowed for this user.") return False return True - is_private = msg.chat_type == "private" - if msg.chat_type is None: - is_private = msg.chat_id > 0 - if is_private: + if msg.is_private: return True member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) if member is None: diff --git a/src/takopi/telegram/commands/handlers.py b/src/takopi/telegram/commands/handlers.py new file mode 100644 index 0000000..d03be2a --- /dev/null +++ b/src/takopi/telegram/commands/handlers.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +# ruff: noqa: F401 + +from .agent import _handle_agent_command as handle_agent_command +from .dispatch import _dispatch_command as dispatch_command +from .executor import _run_engine as run_engine +from .executor import _should_show_resume_line as should_show_resume_line +from .file_transfer import _handle_file_command as handle_file_command +from .file_transfer import _handle_file_put_default as handle_file_put_default +from .file_transfer import _save_file_put as save_file_put +from .media import _handle_media_group as handle_media_group +from .menu import _reserved_commands as get_reserved_commands +from .menu import _set_command_menu as set_command_menu +from .model import _handle_model_command as handle_model_command +from .parse import _parse_slash_command as parse_slash_command +from .reasoning import _handle_reasoning_command as handle_reasoning_command +from .topics import _handle_chat_new_command as handle_chat_new_command +from .topics import _handle_ctx_command as handle_ctx_command +from .topics import _handle_new_command as handle_new_command +from .topics import _handle_topic_command as handle_topic_command +from .trigger import _handle_trigger_command as handle_trigger_command + +__all__ = [ + "dispatch_command", + "get_reserved_commands", + "handle_agent_command", + "handle_chat_new_command", + "handle_ctx_command", + "handle_file_command", + "handle_file_put_default", + "handle_media_group", + "handle_model_command", + "handle_new_command", + "handle_reasoning_command", + "handle_topic_command", + "handle_trigger_command", + "parse_slash_command", + "run_engine", + "save_file_put", + "set_command_menu", + "should_show_resume_line", +] diff --git a/src/takopi/telegram/commands/model.py b/src/takopi/telegram/commands/model.py index e8ae395..81a2a97 100644 --- a/src/takopi/telegram/commands/model.py +++ b/src/takopi/telegram/commands/model.py @@ -3,14 +3,20 @@ from __future__ import annotations from typing import TYPE_CHECKING from ...context import RunContext -from ...directives import DirectiveError from ..chat_prefs import ChatPrefsStore -from ..engine_defaults import resolve_engine_for_message from ..engine_overrides import EngineOverrides, resolve_override_value from ..files import split_command_args from ..topic_state import TopicStateStore from ..topics import _topic_key from ..types import TelegramIncomingMessage +from .overrides import ( + ENGINE_SOURCE_LABELS, + OVERRIDE_SOURCE_LABELS, + apply_engine_override, + parse_set_args, + require_admin_or_private, + resolve_engine_selection, +) from .reply import make_reply if TYPE_CHECKING: @@ -22,79 +28,6 @@ MODEL_USAGE = ( ) -async def _check_model_permissions( - cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage -) -> bool: - reply = make_reply(cfg, msg) - sender_id = msg.sender_id - if sender_id is None: - await reply(text="cannot verify sender for model overrides.") - return False - is_private = msg.chat_type == "private" - if msg.chat_type is None: - is_private = msg.chat_id > 0 - if is_private: - return True - member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) - if member is None: - await reply(text="failed to verify model override permissions.") - return False - if member.status in {"creator", "administrator"}: - return True - await reply(text="changing model overrides is restricted to group admins.") - return False - - -async def _resolve_engine_selection( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - *, - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, - chat_prefs: ChatPrefsStore | None, - topic_key: tuple[int, int] | None, -) -> tuple[str, str] | None: - reply = make_reply(cfg, msg) - try: - resolved = cfg.runtime.resolve_message( - text="", - reply_text=msg.reply_to_text, - ambient_context=ambient_context, - chat_id=msg.chat_id, - ) - except DirectiveError as exc: - await reply(text=f"error:\n{exc}") - return None - selection = await resolve_engine_for_message( - runtime=cfg.runtime, - context=resolved.context, - explicit_engine=None, - chat_id=msg.chat_id, - topic_key=topic_key, - topic_store=topic_store, - chat_prefs=chat_prefs, - ) - return selection.engine, selection.source - - -def _parse_set_args( - tokens: tuple[str, ...], *, engine_ids: set[str] -) -> tuple[str | None, str | None]: - if len(tokens) < 2: - return None, None - if len(tokens) == 2: - maybe_engine = tokens[1].strip().lower() - if maybe_engine in engine_ids: - return None, None - return None, tokens[1].strip() - maybe_engine = tokens[1].strip().lower() - if maybe_engine in engine_ids: - model = " ".join(tokens[2:]).strip() - return maybe_engine, model or None - model = " ".join(tokens[1:]).strip() - return None, model or None - - async def _handle_model_command( cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, @@ -117,7 +50,7 @@ async def _handle_model_command( engine_ids = {engine.lower() for engine in cfg.runtime.engine_ids} if action in {"show", ""}: - selection = await _resolve_engine_selection( + selection = await resolve_engine_selection( cfg, msg, ambient_context=ambient_context, @@ -141,21 +74,11 @@ async def _handle_model_command( chat_override=chat_override, field="model", ) - source_labels = { - "directive": "directive", - "topic_default": "topic default", - "chat_default": "chat default", - "project_default": "project default", - "global_default": "global default", - } - override_labels = { - "topic_override": "topic override", - "chat_default": "chat default", - "default": "no override", - } - engine_line = f"engine: {engine} ({source_labels[engine_source]})" + engine_line = f"engine: {engine} ({ENGINE_SOURCE_LABELS[engine_source]})" model_value = resolution.value or "default" - model_line = f"model: {model_value} ({override_labels[resolution.source]})" + model_line = ( + f"model: {model_value} ({OVERRIDE_SOURCE_LABELS[resolution.source]})" + ) topic_label = resolution.topic_value or "none" if tkey is None: topic_label = "none" @@ -170,14 +93,20 @@ async def _handle_model_command( return if action == "set": - engine_arg, model = _parse_set_args(tokens, engine_ids=engine_ids) + engine_arg, model = parse_set_args(tokens, engine_ids=engine_ids) if model is None: await reply(text=MODEL_USAGE) return - if not await _check_model_permissions(cfg, msg): + if not await require_admin_or_private( + cfg, + msg, + missing_sender="cannot verify sender for model overrides.", + failed_member="failed to verify model override permissions.", + denied="changing model overrides is restricted to group admins.", + ): return if engine_arg is None: - selection = await _resolve_engine_selection( + selection = await resolve_engine_selection( cfg, msg, ambient_context=ambient_context, @@ -196,16 +125,23 @@ async def _handle_model_command( text=f"unknown engine `{engine}`.\navailable agents: `{available}`" ) return - if tkey is not None: - if topic_store is None: - await reply(text="topic model overrides are unavailable.") - return - current = await topic_store.get_engine_override(tkey[0], tkey[1], engine) - updated = EngineOverrides( + scope = await apply_engine_override( + reply=reply, + tkey=tkey, + topic_store=topic_store, + chat_prefs=chat_prefs, + chat_id=msg.chat_id, + engine=engine, + update=lambda current: EngineOverrides( model=model, reasoning=current.reasoning if current is not None else None, - ) - await topic_store.set_engine_override(tkey[0], tkey[1], engine, updated) + ), + topic_unavailable="topic model overrides are unavailable.", + chat_unavailable="chat model overrides are unavailable (no config path).", + ) + if scope is None: + return + if scope == "topic": await reply( text=( f"topic model override set to `{model}` for `{engine}`.\n" @@ -213,15 +149,6 @@ async def _handle_model_command( ) ) return - if chat_prefs is None: - await reply(text="chat model overrides are unavailable (no config path).") - return - current = await chat_prefs.get_engine_override(msg.chat_id, engine) - updated = EngineOverrides( - model=model, - reasoning=current.reasoning if current is not None else None, - ) - await chat_prefs.set_engine_override(msg.chat_id, engine, updated) await reply( text=( f"chat model override set to `{model}` for `{engine}`.\n" @@ -237,10 +164,16 @@ async def _handle_model_command( return if len(tokens) == 2: engine = tokens[1].strip().lower() or None - if not await _check_model_permissions(cfg, msg): + if not await require_admin_or_private( + cfg, + msg, + missing_sender="cannot verify sender for model overrides.", + failed_member="failed to verify model override permissions.", + denied="changing model overrides is restricted to group admins.", + ): return if engine is None: - selection = await _resolve_engine_selection( + selection = await resolve_engine_selection( cfg, msg, ambient_context=ambient_context, @@ -257,27 +190,25 @@ async def _handle_model_command( text=f"unknown engine `{engine}`.\navailable agents: `{available}`" ) return - if tkey is not None: - if topic_store is None: - await reply(text="topic model overrides are unavailable.") - return - current = await topic_store.get_engine_override(tkey[0], tkey[1], engine) - updated = EngineOverrides( + scope = await apply_engine_override( + reply=reply, + tkey=tkey, + topic_store=topic_store, + chat_prefs=chat_prefs, + chat_id=msg.chat_id, + engine=engine, + update=lambda current: EngineOverrides( model=None, reasoning=current.reasoning if current is not None else None, - ) - await topic_store.set_engine_override(tkey[0], tkey[1], engine, updated) + ), + topic_unavailable="topic model overrides are unavailable.", + chat_unavailable="chat model overrides are unavailable (no config path).", + ) + if scope is None: + return + if scope == "topic": await reply(text="topic model override cleared (using chat default).") return - if chat_prefs is None: - await reply(text="chat model overrides are unavailable (no config path).") - return - current = await chat_prefs.get_engine_override(msg.chat_id, engine) - updated = EngineOverrides( - model=None, - reasoning=current.reasoning if current is not None else None, - ) - await chat_prefs.set_engine_override(msg.chat_id, engine, updated) await reply(text="chat model override cleared.") return diff --git a/src/takopi/telegram/commands/overrides.py b/src/takopi/telegram/commands/overrides.py new file mode 100644 index 0000000..0176844 --- /dev/null +++ b/src/takopi/telegram/commands/overrides.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Literal + +from ...context import RunContext +from ...directives import DirectiveError +from ..chat_prefs import ChatPrefsStore +from ..engine_defaults import resolve_engine_for_message +from ..engine_overrides import EngineOverrides +from ..topic_state import TopicStateStore +from ..types import TelegramIncomingMessage +from .reply import make_reply + +if TYPE_CHECKING: + from ..bridge import TelegramBridgeConfig + +ENGINE_SOURCE_LABELS = { + "directive": "directive", + "topic_default": "topic default", + "chat_default": "chat default", + "project_default": "project default", + "global_default": "global default", +} +OVERRIDE_SOURCE_LABELS = { + "topic_override": "topic override", + "chat_default": "chat default", + "default": "no override", +} + + +async def require_admin_or_private( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + *, + missing_sender: str, + failed_member: str, + denied: str, +) -> bool: + reply = make_reply(cfg, msg) + sender_id = msg.sender_id + if sender_id is None: + await reply(text=missing_sender) + return False + if msg.is_private: + return True + member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) + if member is None: + await reply(text=failed_member) + return False + if member.status in {"creator", "administrator"}: + return True + await reply(text=denied) + return False + + +async def resolve_engine_selection( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + *, + ambient_context: RunContext | None, + topic_store: TopicStateStore | None, + chat_prefs: ChatPrefsStore | None, + topic_key: tuple[int, int] | None, +) -> tuple[str, str] | None: + reply = make_reply(cfg, msg) + try: + resolved = cfg.runtime.resolve_message( + text="", + reply_text=msg.reply_to_text, + ambient_context=ambient_context, + chat_id=msg.chat_id, + ) + except DirectiveError as exc: + await reply(text=f"error:\n{exc}") + return None + selection = await resolve_engine_for_message( + runtime=cfg.runtime, + context=resolved.context, + explicit_engine=None, + chat_id=msg.chat_id, + topic_key=topic_key, + topic_store=topic_store, + chat_prefs=chat_prefs, + ) + return selection.engine, selection.source + + +def parse_set_args( + tokens: tuple[str, ...], *, engine_ids: set[str] +) -> tuple[str | None, str | None]: + if len(tokens) < 2: + return None, None + if len(tokens) == 2: + maybe_engine = tokens[1].strip().lower() + if maybe_engine in engine_ids: + return None, None + return None, tokens[1].strip() + maybe_engine = tokens[1].strip().lower() + if maybe_engine in engine_ids: + value = " ".join(tokens[2:]).strip() + return maybe_engine, value or None + value = " ".join(tokens[1:]).strip() + return None, value or None + + +async def apply_engine_override( + *, + reply: Callable[..., Awaitable[None]], + tkey: tuple[int, int] | None, + topic_store: TopicStateStore | None, + chat_prefs: ChatPrefsStore | None, + chat_id: int, + engine: str, + update: Callable[[EngineOverrides | None], EngineOverrides], + topic_unavailable: str, + chat_unavailable: str, +) -> Literal["topic", "chat"] | None: + if tkey is not None: + if topic_store is None: + await reply(text=topic_unavailable) + return None + current = await topic_store.get_engine_override(tkey[0], tkey[1], engine) + updated = update(current) + await topic_store.set_engine_override(tkey[0], tkey[1], engine, updated) + return "topic" + if chat_prefs is None: + await reply(text=chat_unavailable) + return None + current = await chat_prefs.get_engine_override(chat_id, engine) + updated = update(current) + await chat_prefs.set_engine_override(chat_id, engine, updated) + return "chat" diff --git a/src/takopi/telegram/commands/reasoning.py b/src/takopi/telegram/commands/reasoning.py index d334ba5..be726f8 100644 --- a/src/takopi/telegram/commands/reasoning.py +++ b/src/takopi/telegram/commands/reasoning.py @@ -3,9 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from ...context import RunContext -from ...directives import DirectiveError from ..chat_prefs import ChatPrefsStore -from ..engine_defaults import resolve_engine_for_message from ..engine_overrides import ( EngineOverrides, allowed_reasoning_levels, @@ -15,6 +13,14 @@ from ..files import split_command_args from ..topic_state import TopicStateStore from ..topics import _topic_key from ..types import TelegramIncomingMessage +from .overrides import ( + ENGINE_SOURCE_LABELS, + OVERRIDE_SOURCE_LABELS, + apply_engine_override, + parse_set_args, + require_admin_or_private, + resolve_engine_selection, +) from .reply import make_reply if TYPE_CHECKING: @@ -26,79 +32,6 @@ REASONING_USAGE = ( ) -async def _check_reasoning_permissions( - cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage -) -> bool: - reply = make_reply(cfg, msg) - sender_id = msg.sender_id - if sender_id is None: - await reply(text="cannot verify sender for reasoning overrides.") - return False - is_private = msg.chat_type == "private" - if msg.chat_type is None: - is_private = msg.chat_id > 0 - if is_private: - return True - member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) - if member is None: - await reply(text="failed to verify reasoning override permissions.") - return False - if member.status in {"creator", "administrator"}: - return True - await reply(text="changing reasoning overrides is restricted to group admins.") - return False - - -async def _resolve_engine_selection( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - *, - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, - chat_prefs: ChatPrefsStore | None, - topic_key: tuple[int, int] | None, -) -> tuple[str, str] | None: - reply = make_reply(cfg, msg) - try: - resolved = cfg.runtime.resolve_message( - text="", - reply_text=msg.reply_to_text, - ambient_context=ambient_context, - chat_id=msg.chat_id, - ) - except DirectiveError as exc: - await reply(text=f"error:\n{exc}") - return None - selection = await resolve_engine_for_message( - runtime=cfg.runtime, - context=resolved.context, - explicit_engine=None, - chat_id=msg.chat_id, - topic_key=topic_key, - topic_store=topic_store, - chat_prefs=chat_prefs, - ) - return selection.engine, selection.source - - -def _parse_set_args( - tokens: tuple[str, ...], *, engine_ids: set[str] -) -> tuple[str | None, str | None]: - if len(tokens) < 2: - return None, None - if len(tokens) == 2: - maybe_engine = tokens[1].strip().lower() - if maybe_engine in engine_ids: - return None, None - return None, tokens[1].strip() - maybe_engine = tokens[1].strip().lower() - if maybe_engine in engine_ids: - level = " ".join(tokens[2:]).strip() - return maybe_engine, level or None - level = " ".join(tokens[1:]).strip() - return None, level or None - - async def _handle_reasoning_command( cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, @@ -121,7 +54,7 @@ async def _handle_reasoning_command( engine_ids = {engine.lower() for engine in cfg.runtime.engine_ids} if action in {"show", ""}: - selection = await _resolve_engine_selection( + selection = await resolve_engine_selection( cfg, msg, ambient_context=ambient_context, @@ -145,22 +78,11 @@ async def _handle_reasoning_command( chat_override=chat_override, field="reasoning", ) - source_labels = { - "directive": "directive", - "topic_default": "topic default", - "chat_default": "chat default", - "project_default": "project default", - "global_default": "global default", - } - override_labels = { - "topic_override": "topic override", - "chat_default": "chat default", - "default": "no override", - } - engine_line = f"engine: {engine} ({source_labels[engine_source]})" + engine_line = f"engine: {engine} ({ENGINE_SOURCE_LABELS[engine_source]})" reasoning_value = resolution.value or "default" reasoning_line = ( - f"reasoning: {reasoning_value} ({override_labels[resolution.source]})" + f"reasoning: {reasoning_value} " + f"({OVERRIDE_SOURCE_LABELS[resolution.source]})" ) topic_label = resolution.topic_value or "none" if tkey is None: @@ -179,14 +101,20 @@ async def _handle_reasoning_command( return if action == "set": - engine_arg, level = _parse_set_args(tokens, engine_ids=engine_ids) + engine_arg, level = parse_set_args(tokens, engine_ids=engine_ids) if level is None: await reply(text=REASONING_USAGE) return - if not await _check_reasoning_permissions(cfg, msg): + if not await require_admin_or_private( + cfg, + msg, + missing_sender="cannot verify sender for reasoning overrides.", + failed_member="failed to verify reasoning override permissions.", + denied="changing reasoning overrides is restricted to group admins.", + ): return if engine_arg is None: - selection = await _resolve_engine_selection( + selection = await resolve_engine_selection( cfg, msg, ambient_context=ambient_context, @@ -215,16 +143,23 @@ async def _handle_reasoning_command( ) ) return - if tkey is not None: - if topic_store is None: - await reply(text="topic reasoning overrides are unavailable.") - return - current = await topic_store.get_engine_override(tkey[0], tkey[1], engine) - updated = EngineOverrides( + scope = await apply_engine_override( + reply=reply, + tkey=tkey, + topic_store=topic_store, + chat_prefs=chat_prefs, + chat_id=msg.chat_id, + engine=engine, + update=lambda current: EngineOverrides( model=current.model if current is not None else None, reasoning=normalized_level, - ) - await topic_store.set_engine_override(tkey[0], tkey[1], engine, updated) + ), + topic_unavailable="topic reasoning overrides are unavailable.", + chat_unavailable="chat reasoning overrides are unavailable (no config path).", + ) + if scope is None: + return + if scope == "topic": await reply( text=( f"topic reasoning override set to `{normalized_level}` " @@ -233,17 +168,6 @@ async def _handle_reasoning_command( ) ) return - if chat_prefs is None: - await reply( - text="chat reasoning overrides are unavailable (no config path)." - ) - return - current = await chat_prefs.get_engine_override(msg.chat_id, engine) - updated = EngineOverrides( - model=current.model if current is not None else None, - reasoning=normalized_level, - ) - await chat_prefs.set_engine_override(msg.chat_id, engine, updated) await reply( text=( f"chat reasoning override set to `{normalized_level}` for `{engine}`.\n" @@ -259,10 +183,16 @@ async def _handle_reasoning_command( return if len(tokens) == 2: engine = tokens[1].strip().lower() or None - if not await _check_reasoning_permissions(cfg, msg): + if not await require_admin_or_private( + cfg, + msg, + missing_sender="cannot verify sender for reasoning overrides.", + failed_member="failed to verify reasoning override permissions.", + denied="changing reasoning overrides is restricted to group admins.", + ): return if engine is None: - selection = await _resolve_engine_selection( + selection = await resolve_engine_selection( cfg, msg, ambient_context=ambient_context, @@ -279,29 +209,25 @@ async def _handle_reasoning_command( text=f"unknown engine `{engine}`.\navailable agents: `{available}`" ) return - if tkey is not None: - if topic_store is None: - await reply(text="topic reasoning overrides are unavailable.") - return - current = await topic_store.get_engine_override(tkey[0], tkey[1], engine) - updated = EngineOverrides( + scope = await apply_engine_override( + reply=reply, + tkey=tkey, + topic_store=topic_store, + chat_prefs=chat_prefs, + chat_id=msg.chat_id, + engine=engine, + update=lambda current: EngineOverrides( model=current.model if current is not None else None, reasoning=None, - ) - await topic_store.set_engine_override(tkey[0], tkey[1], engine, updated) + ), + topic_unavailable="topic reasoning overrides are unavailable.", + chat_unavailable="chat reasoning overrides are unavailable (no config path).", + ) + if scope is None: + return + if scope == "topic": await reply(text="topic reasoning override cleared (using chat default).") return - if chat_prefs is None: - await reply( - text="chat reasoning overrides are unavailable (no config path)." - ) - return - current = await chat_prefs.get_engine_override(msg.chat_id, engine) - updated = EngineOverrides( - model=current.model if current is not None else None, - reasoning=None, - ) - await chat_prefs.set_engine_override(msg.chat_id, engine, updated) await reply(text="chat reasoning override cleared.") return diff --git a/src/takopi/telegram/commands/trigger.py b/src/takopi/telegram/commands/trigger.py index e6d5ea6..b36137f 100644 --- a/src/takopi/telegram/commands/trigger.py +++ b/src/takopi/telegram/commands/trigger.py @@ -8,6 +8,7 @@ from ..topic_state import TopicStateStore from ..topics import _topic_key from ..trigger_mode import resolve_trigger_mode from ..types import TelegramIncomingMessage +from .overrides import require_admin_or_private from .reply import make_reply if TYPE_CHECKING: @@ -18,29 +19,6 @@ TRIGGER_USAGE = ( ) -async def _check_trigger_permissions( - cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage -) -> bool: - reply = make_reply(cfg, msg) - sender_id = msg.sender_id - if sender_id is None: - await reply(text="cannot verify sender for trigger settings.") - return False - is_private = msg.chat_type == "private" - if msg.chat_type is None: - is_private = msg.chat_id > 0 - if is_private: - return True - member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) - if member is None: - await reply(text="failed to verify trigger permissions.") - return False - if member.status in {"creator", "administrator"}: - return True - await reply(text="changing trigger mode is restricted to group admins.") - return False - - async def _handle_trigger_command( cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, @@ -91,7 +69,13 @@ async def _handle_trigger_command( return if action in {"all", "mentions"}: - if not await _check_trigger_permissions(cfg, msg): + if not await require_admin_or_private( + cfg, + msg, + missing_sender="cannot verify sender for trigger settings.", + failed_member="failed to verify trigger permissions.", + denied="changing trigger mode is restricted to group admins.", + ): return if tkey is not None: if topic_store is None: @@ -108,7 +92,13 @@ async def _handle_trigger_command( return if action == "clear": - if not await _check_trigger_permissions(cfg, msg): + if not await require_admin_or_private( + cfg, + msg, + missing_sender="cannot verify sender for trigger settings.", + failed_member="failed to verify trigger permissions.", + denied="changing trigger mode is restricted to group admins.", + ): return if tkey is not None: if topic_store is None: diff --git a/src/takopi/telegram/loop.py b/src/takopi/telegram/loop.py index 8bb51b8..6511c7c 100644 --- a/src/takopi/telegram/loop.py +++ b/src/takopi/telegram/loop.py @@ -1,9 +1,9 @@ from __future__ import annotations -from collections.abc import AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterator, Awaitable, Callable, Mapping from dataclasses import dataclass from functools import partial -from typing import cast +from typing import TYPE_CHECKING, cast import anyio from anyio.abc import TaskGroup @@ -23,29 +23,30 @@ from ..transport_runtime import ResolvedMessage from ..context import RunContext from ..ids import RESERVED_CHAT_COMMANDS from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain -from .commands.agent import _handle_agent_command from .commands.cancel import handle_callback_cancel, handle_cancel -from .commands.dispatch import _dispatch_command -from .commands.executor import _run_engine, _should_show_resume_line -from .commands.file_transfer import ( - FILE_PUT_USAGE, - _handle_file_command, - _handle_file_put_default, - _save_file_put, +from .commands.file_transfer import FILE_PUT_USAGE +from .commands.handlers import ( + dispatch_command, + handle_agent_command, + handle_chat_new_command, + handle_ctx_command, + handle_file_command, + handle_file_put_default, + handle_media_group, + handle_model_command, + handle_new_command, + handle_reasoning_command, + handle_topic_command, + handle_trigger_command, + parse_slash_command, + get_reserved_commands, + run_engine, + save_file_put, + set_command_menu, + should_show_resume_line, ) -from .commands.media import _handle_media_group -from .commands.menu import _reserved_commands, _set_command_menu -from .commands.parse import _parse_slash_command, is_cancel_command +from .commands.parse import is_cancel_command from .commands.reply import make_reply -from .commands.topics import ( - _handle_chat_new_command, - _handle_ctx_command, - _handle_new_command, - _handle_topic_command, -) -from .commands.model import _handle_model_command -from .commands.reasoning import _handle_reasoning_command -from .commands.trigger import _handle_trigger_command from .context import _merge_topic_context, _usage_ctx_set, _usage_topic from .topics import ( _maybe_rename_topic, @@ -75,6 +76,8 @@ __all__ = ["poll_updates", "run_main_loop", "send_with_resume"] ForwardKey = tuple[int, int, int] +_handle_file_put_default = handle_file_put_default + def _chat_session_key( msg: TelegramIncomingMessage, *, store: ChatSessionStore | None @@ -135,84 +138,76 @@ async def _send_startup(cfg: TelegramBridgeConfig) -> None: def _dispatch_builtin_command( *, - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, + ctx: TelegramCommandContext, command_id: str, - args_text: str, - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, - chat_prefs: ChatPrefsStore | None, - resolved_scope: str | None, - scope_chat_ids: frozenset[int], - reply: Callable[..., Awaitable[None]], - task_group: TaskGroup, ) -> bool: - handlers: dict[str, Callable[[], Awaitable[None]]] = {} - + cfg = ctx.cfg + msg = ctx.msg + args_text = ctx.args_text + ambient_context = ctx.ambient_context + topic_store = ctx.topic_store + chat_prefs = ctx.chat_prefs + resolved_scope = ctx.resolved_scope + scope_chat_ids = ctx.scope_chat_ids + reply = ctx.reply + task_group = ctx.task_group if command_id == "file": if not cfg.files.enabled: - handlers["file"] = partial( + handler = partial( reply, text="file transfer disabled; enable `[transports.telegram.files]`.", ) else: - handlers["file"] = partial( - _handle_file_command, + handler = partial( + handle_file_command, cfg, msg, args_text, ambient_context, topic_store, ) + task_group.start_soon(handler) + return True if cfg.topics.enabled and topic_store is not None: - handlers.update( - { - "ctx": partial( - _handle_ctx_command, - cfg, - msg, - args_text, - topic_store, - resolved_scope=resolved_scope, - scope_chat_ids=scope_chat_ids, - ), - "new": partial( - _handle_new_command, - cfg, - msg, - topic_store, - resolved_scope=resolved_scope, - scope_chat_ids=scope_chat_ids, - ), - "topic": partial( - _handle_topic_command, - cfg, - msg, - args_text, - topic_store, - resolved_scope=resolved_scope, - scope_chat_ids=scope_chat_ids, - ), - } - ) - - if command_id == "agent": - handlers["agent"] = partial( - _handle_agent_command, - cfg, - msg, - args_text, - ambient_context, - topic_store, - chat_prefs, - resolved_scope=resolved_scope, - scope_chat_ids=scope_chat_ids, - ) + if command_id == "ctx": + handler = partial( + handle_ctx_command, + cfg, + msg, + args_text, + topic_store, + resolved_scope=resolved_scope, + scope_chat_ids=scope_chat_ids, + ) + elif command_id == "new": + handler = partial( + handle_new_command, + cfg, + msg, + topic_store, + resolved_scope=resolved_scope, + scope_chat_ids=scope_chat_ids, + ) + elif command_id == "topic": + handler = partial( + handle_topic_command, + cfg, + msg, + args_text, + topic_store, + resolved_scope=resolved_scope, + scope_chat_ids=scope_chat_ids, + ) + else: + handler = None + if handler is not None: + task_group.start_soon(handler) + return True if command_id == "model": - handlers["model"] = partial( - _handle_model_command, + handler = partial( + handle_model_command, cfg, msg, args_text, @@ -222,10 +217,27 @@ def _dispatch_builtin_command( resolved_scope=resolved_scope, scope_chat_ids=scope_chat_ids, ) + task_group.start_soon(handler) + return True + + if command_id == "agent": + handler = partial( + handle_agent_command, + cfg, + msg, + args_text, + ambient_context, + topic_store, + chat_prefs, + resolved_scope=resolved_scope, + scope_chat_ids=scope_chat_ids, + ) + task_group.start_soon(handler) + return True if command_id == "reasoning": - handlers["reasoning"] = partial( - _handle_reasoning_command, + handler = partial( + handle_reasoning_command, cfg, msg, args_text, @@ -235,10 +247,12 @@ def _dispatch_builtin_command( resolved_scope=resolved_scope, scope_chat_ids=scope_chat_ids, ) + task_group.start_soon(handler) + return True if command_id == "trigger": - handlers["trigger"] = partial( - _handle_trigger_command, + handler = partial( + handle_trigger_command, cfg, msg, args_text, @@ -248,12 +262,10 @@ def _dispatch_builtin_command( resolved_scope=resolved_scope, scope_chat_ids=scope_chat_ids, ) + task_group.start_soon(handler) + return True - handler = handlers.get(command_id) - if handler is None: - return False - task_group.start_soon(handler) - return True + return False async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int | None: @@ -312,6 +324,57 @@ class _PendingPrompt: cancel_scope: anyio.CancelScope | None = None +@dataclass(frozen=True, slots=True) +class TelegramMsgContext: + chat_id: int + thread_id: int | None + reply_id: int | None + reply_ref: MessageRef | None + topic_key: tuple[int, int] | None + chat_session_key: tuple[int, int | None] | None + stateful_mode: bool + chat_project: str | None + ambient_context: RunContext | None + + +@dataclass(frozen=True, slots=True) +class TelegramCommandContext: + cfg: TelegramBridgeConfig + msg: TelegramIncomingMessage + args_text: str + ambient_context: RunContext | None + topic_store: TopicStateStore | None + chat_prefs: ChatPrefsStore | None + resolved_scope: str | None + scope_chat_ids: frozenset[int] + reply: Callable[..., Awaitable[None]] + task_group: TaskGroup + + +@dataclass(slots=True) +class TelegramLoopState: + running_tasks: RunningTasks + pending_prompts: dict[ForwardKey, _PendingPrompt] + media_groups: dict[tuple[int, str], _MediaGroupState] + command_ids: set[str] + reserved_commands: set[str] + reserved_chat_commands: set[str] + transport_snapshot: dict[str, object] | None + topic_store: TopicStateStore | None + chat_session_store: ChatSessionStore | None + chat_prefs: ChatPrefsStore | None + resolved_topics_scope: str | None + topics_chat_ids: frozenset[int] + bot_username: str | None + forward_coalesce_s: float + media_group_debounce_s: float + transport_id: str | None + + +if TYPE_CHECKING: + from ..runner_bridge import RunningTasks + + _FORWARD_FIELDS = ( "forward_origin", "forward_from", @@ -350,6 +413,349 @@ def _format_forwarded_prompt(forwarded: list[str], prompt: str) -> str: return forward_block +class ForwardCoalescer: + def __init__( + self, + *, + task_group: TaskGroup, + debounce_s: float, + dispatch: Callable[[_PendingPrompt], Awaitable[None]], + pending: dict[ForwardKey, _PendingPrompt], + ) -> None: + self._task_group = task_group + self._debounce_s = debounce_s + self._dispatch = dispatch + self._pending = pending + + def cancel(self, key: ForwardKey) -> None: + pending = self._pending.pop(key, None) + if pending is None: + return + if pending.cancel_scope is not None: + pending.cancel_scope.cancel() + logger.debug( + "forward.prompt.cancelled", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + message_id=pending.msg.message_id, + forward_count=len(pending.forwards), + ) + + def schedule(self, pending: _PendingPrompt) -> None: + if pending.msg.sender_id is None: + logger.debug( + "forward.prompt.bypass", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + message_id=pending.msg.message_id, + reason="missing_sender", + ) + self._task_group.start_soon(self._dispatch, pending) + return + if self._debounce_s <= 0: + logger.debug( + "forward.prompt.bypass", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + message_id=pending.msg.message_id, + reason="disabled", + ) + self._task_group.start_soon(self._dispatch, pending) + return + key = _forward_key(pending.msg) + existing = self._pending.get(key) + if existing is not None: + if existing.cancel_scope is not None: + existing.cancel_scope.cancel() + if existing.forwards: + pending.forwards = list(existing.forwards) + logger.debug( + "forward.prompt.replace", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + old_message_id=existing.msg.message_id, + new_message_id=pending.msg.message_id, + forward_count=len(pending.forwards), + ) + self._pending[key] = pending + logger.debug( + "forward.prompt.schedule", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + message_id=pending.msg.message_id, + debounce_s=self._debounce_s, + ) + self._reschedule(key, pending) + + def attach_forward(self, msg: TelegramIncomingMessage) -> None: + if msg.sender_id is None: + logger.debug( + "forward.message.ignored", + chat_id=msg.chat_id, + thread_id=msg.thread_id, + sender_id=msg.sender_id, + message_id=msg.message_id, + reason="missing_sender", + ) + return + key = _forward_key(msg) + pending = self._pending.get(key) + if pending is None: + logger.debug( + "forward.message.ignored", + chat_id=msg.chat_id, + thread_id=msg.thread_id, + sender_id=msg.sender_id, + message_id=msg.message_id, + reason="no_pending_prompt", + ) + return + text = msg.text + if not text.strip(): + logger.debug( + "forward.message.ignored", + chat_id=msg.chat_id, + thread_id=msg.thread_id, + sender_id=msg.sender_id, + message_id=msg.message_id, + reason="empty_text", + ) + return + pending.forwards.append((msg.message_id, text)) + logger.debug( + "forward.message.attached", + chat_id=msg.chat_id, + thread_id=msg.thread_id, + sender_id=msg.sender_id, + message_id=msg.message_id, + prompt_message_id=pending.msg.message_id, + forward_count=len(pending.forwards), + forward_fields=_forward_fields_present(msg.raw), + forward_date=msg.raw.get("forward_date") if msg.raw else None, + message_date=msg.raw.get("date") if msg.raw else None, + text_len=len(text), + ) + self._reschedule(key, pending) + + def _reschedule(self, key: ForwardKey, pending: _PendingPrompt) -> None: + if pending.cancel_scope is not None: + pending.cancel_scope.cancel() + pending.cancel_scope = None + self._task_group.start_soon(self._debounce_prompt_run, key, pending) + + async def _debounce_prompt_run( + self, + key: ForwardKey, + pending: _PendingPrompt, + ) -> None: + try: + with anyio.CancelScope() as scope: + pending.cancel_scope = scope + await anyio.sleep(self._debounce_s) + except anyio.get_cancelled_exc_class(): + return + if self._pending.get(key) is not pending: + return + self._pending.pop(key, None) + logger.debug( + "forward.prompt.run", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + message_id=pending.msg.message_id, + forward_count=len(pending.forwards), + debounce_s=self._debounce_s, + ) + await self._dispatch(pending) + + +@dataclass(frozen=True, slots=True) +class ResumeDecision: + resume_token: ResumeToken | None + handled_by_running_task: bool + + +class ResumeResolver: + def __init__( + self, + *, + cfg: TelegramBridgeConfig, + task_group: TaskGroup, + running_tasks: Mapping[MessageRef, object], + enqueue_resume: Callable[ + [ + int, + int, + str, + ResumeToken, + RunContext | None, + int | None, + tuple[int, int | None] | None, + MessageRef | None, + ], + Awaitable[None], + ], + topic_store: TopicStateStore | None, + chat_session_store: ChatSessionStore | None, + ) -> None: + self._cfg = cfg + self._task_group = task_group + self._running_tasks = running_tasks + self._enqueue_resume = enqueue_resume + self._topic_store = topic_store + self._chat_session_store = chat_session_store + + async def resolve( + self, + *, + resume_token: ResumeToken | None, + reply_id: int | None, + chat_id: int, + user_msg_id: int, + thread_id: int | None, + chat_session_key: tuple[int, int | None] | None, + topic_key: tuple[int, int] | None, + engine_for_session: EngineId, + prompt_text: str, + ) -> ResumeDecision: + if resume_token is not None: + return ResumeDecision( + resume_token=resume_token, handled_by_running_task=False + ) + if reply_id is not None: + running_task = self._running_tasks.get( + MessageRef(channel_id=chat_id, message_id=reply_id) + ) + if running_task is not None: + self._task_group.start_soon( + send_with_resume, + self._cfg, + self._enqueue_resume, + running_task, + chat_id, + user_msg_id, + thread_id, + chat_session_key, + prompt_text, + ) + return ResumeDecision(resume_token=None, handled_by_running_task=True) + if self._topic_store is not None and topic_key is not None: + stored = await self._topic_store.get_session_resume( + topic_key[0], + topic_key[1], + engine_for_session, + ) + if stored is not None: + resume_token = stored + if ( + resume_token is None + and self._chat_session_store is not None + and chat_session_key is not None + ): + stored = await self._chat_session_store.get_session_resume( + chat_session_key[0], + chat_session_key[1], + engine_for_session, + ) + if stored is not None: + resume_token = stored + return ResumeDecision(resume_token=resume_token, handled_by_running_task=False) + + +class MediaGroupBuffer: + def __init__( + self, + *, + task_group: TaskGroup, + debounce_s: float, + cfg: TelegramBridgeConfig, + chat_prefs: ChatPrefsStore | None, + topic_store: TopicStateStore | None, + bot_username: str | None, + command_ids: Callable[[], set[str]], + reserved_chat_commands: set[str], + groups: dict[tuple[int, str], _MediaGroupState], + run_prompt_from_upload: Callable[ + [TelegramIncomingMessage, str, ResolvedMessage], Awaitable[None] + ], + resolve_prompt_message: Callable[ + [TelegramIncomingMessage, str, RunContext | None], + Awaitable[ResolvedMessage | None], + ], + ) -> None: + self._task_group = task_group + self._debounce_s = debounce_s + self._cfg = cfg + self._chat_prefs = chat_prefs + self._topic_store = topic_store + self._bot_username = bot_username + self._command_ids = command_ids + self._reserved_chat_commands = reserved_chat_commands + self._groups = groups + self._run_prompt_from_upload = run_prompt_from_upload + self._resolve_prompt_message = resolve_prompt_message + + def add(self, msg: TelegramIncomingMessage) -> None: + if msg.media_group_id is None: + return + key = (msg.chat_id, msg.media_group_id) + state = self._groups.get(key) + if state is None: + state = _MediaGroupState(messages=[]) + self._groups[key] = state + self._task_group.start_soon(self._flush_media_group, key) + state.messages.append(msg) + state.token += 1 + + async def _flush_media_group(self, key: tuple[int, str]) -> None: + while True: + state = self._groups.get(key) + if state is None: + return + token = state.token + await anyio.sleep(self._debounce_s) + state = self._groups.get(key) + if state is None: + return + if state.token != token: + continue + messages = list(state.messages) + del self._groups[key] + if not messages: + return + trigger_mode = await resolve_trigger_mode( + chat_id=messages[0].chat_id, + thread_id=messages[0].thread_id, + chat_prefs=self._chat_prefs, + topic_store=self._topic_store, + ) + command_ids = self._command_ids() + if trigger_mode == "mentions" and not any( + should_trigger_run( + msg, + bot_username=self._bot_username, + runtime=self._cfg.runtime, + command_ids=command_ids, + reserved_chat_commands=self._reserved_chat_commands, + ) + for msg in messages + ): + return + await handle_media_group( + self._cfg, + messages, + self._topic_store, + self._run_prompt_from_upload, + self._resolve_prompt_message, + ) + return + + def _diff_keys(old: dict[str, object], new: dict[str, object]) -> list[str]: keys = set(old) | set(new) return sorted(key for key in keys if old.get(key) != new.get(key)) @@ -475,49 +881,51 @@ async def run_main_loop( transport_id: str | None = None, transport_config: TelegramTransportSettings | None = None, ) -> None: - from ..runner_bridge import RunningTasks - - running_tasks: RunningTasks = {} - command_ids = { - command_id.lower() - for command_id in list_command_ids(allowlist=cfg.runtime.allowlist) - } - reserved_commands = _reserved_commands(cfg.runtime) - reserved_chat_commands = set(RESERVED_CHAT_COMMANDS) - transport_snapshot = ( - transport_config.model_dump() if transport_config is not None else None + state = TelegramLoopState( + running_tasks={}, + pending_prompts={}, + media_groups={}, + command_ids={ + command_id.lower() + for command_id in list_command_ids(allowlist=cfg.runtime.allowlist) + }, + reserved_commands=get_reserved_commands(cfg.runtime), + reserved_chat_commands=set(RESERVED_CHAT_COMMANDS), + transport_snapshot=( + transport_config.model_dump() if transport_config is not None else None + ), + topic_store=None, + chat_session_store=None, + chat_prefs=None, + resolved_topics_scope=None, + topics_chat_ids=frozenset(), + bot_username=None, + forward_coalesce_s=max(0.0, float(cfg.forward_coalesce_s)), + media_group_debounce_s=max(0.0, float(cfg.media_group_debounce_s)), + transport_id=transport_id, ) - topic_store: TopicStateStore | None = None - chat_session_store: ChatSessionStore | None = None - chat_prefs: ChatPrefsStore | None = None - media_groups: dict[tuple[int, str], _MediaGroupState] = {} - pending_prompts: dict[ForwardKey, _PendingPrompt] = {} - resolved_topics_scope: str | None = None - topics_chat_ids: frozenset[int] = frozenset() - bot_username: str | None = None - forward_coalesce_s = max(0.0, float(cfg.forward_coalesce_s)) - media_group_debounce_s = max(0.0, float(cfg.media_group_debounce_s)) def refresh_topics_scope() -> None: - nonlocal resolved_topics_scope, topics_chat_ids if cfg.topics.enabled: - resolved_topics_scope, topics_chat_ids = _resolve_topics_scope(cfg) + ( + state.resolved_topics_scope, + state.topics_chat_ids, + ) = _resolve_topics_scope(cfg) else: - resolved_topics_scope = None - topics_chat_ids = frozenset() + state.resolved_topics_scope = None + state.topics_chat_ids = frozenset() def refresh_commands() -> None: - nonlocal command_ids, reserved_commands allowlist = cfg.runtime.allowlist - command_ids = { + state.command_ids = { command_id.lower() for command_id in list_command_ids(allowlist=allowlist) } - reserved_commands = _reserved_commands(cfg.runtime) + state.reserved_commands = get_reserved_commands(cfg.runtime) try: config_path = cfg.runtime.config_path if config_path is not None: - chat_prefs = ChatPrefsStore(resolve_prefs_path(config_path)) + state.chat_prefs = ChatPrefsStore(resolve_prefs_path(config_path)) logger.info( "chat_prefs.enabled", state_path=str(resolve_prefs_path(config_path)), @@ -527,7 +935,9 @@ async def run_main_loop( raise ConfigError( "session_mode=chat but config path is not set; cannot locate state file." ) - chat_session_store = ChatSessionStore(resolve_sessions_path(config_path)) + state.chat_session_store = ChatSessionStore( + resolve_sessions_path(config_path) + ) logger.info( "chat_sessions.enabled", state_path=str(resolve_sessions_path(config_path)), @@ -537,16 +947,16 @@ async def run_main_loop( raise ConfigError( "topics enabled but config path is not set; cannot locate state file." ) - topic_store = TopicStateStore(resolve_state_path(config_path)) + state.topic_store = TopicStateStore(resolve_state_path(config_path)) await _validate_topics_setup(cfg) refresh_topics_scope() logger.info( "topics.enabled", scope=cfg.topics.scope, - resolved_scope=resolved_topics_scope, + resolved_scope=state.resolved_topics_scope, state_path=str(resolve_state_path(config_path)), ) - await _set_command_menu(cfg) + await set_command_menu(cfg) try: me = await cfg.bot.get_me() except Exception as exc: # noqa: BLE001 @@ -557,7 +967,7 @@ async def run_main_loop( ) me = None if me is not None and me.username: - bot_username = me.username.lower() + state.bot_username = me.username.lower() else: logger.info("trigger_mode.bot_username.unavailable") async with anyio.create_task_group() as tg: @@ -565,13 +975,12 @@ async def run_main_loop( watch_enabled = bool(watch_config) and config_path is not None async def handle_reload(reload: ConfigReload) -> None: - nonlocal transport_snapshot, transport_id refresh_commands() refresh_topics_scope() - await _set_command_menu(cfg) - if transport_snapshot is not None: + await set_command_menu(cfg) + if state.transport_snapshot is not None: new_snapshot = reload.settings.transports.telegram.model_dump() - changed = _diff_keys(transport_snapshot, new_snapshot) + changed = _diff_keys(state.transport_snapshot, new_snapshot) if changed: logger.warning( "config.reload.transport_config_changed", @@ -579,18 +988,18 @@ async def run_main_loop( keys=changed, restart_required=True, ) - transport_snapshot = new_snapshot + state.transport_snapshot = new_snapshot if ( - transport_id is not None - and reload.settings.transport != transport_id + state.transport_id is not None + and reload.settings.transport != state.transport_id ): logger.warning( "config.reload.transport_changed", - old=transport_id, + old=state.transport_id, new=reload.settings.transport, restart_required=True, ) - transport_id = reload.settings.transport + state.transport_id = reload.settings.transport if watch_enabled and config_path is not None: @@ -615,12 +1024,15 @@ async def run_main_loop( async def _wrapped(token: ResumeToken, done: anyio.Event) -> None: if base_cb is not None: await base_cb(token, done) - if topic_store is not None and topic_key is not None: - await topic_store.set_session_resume( + if state.topic_store is not None and topic_key is not None: + await state.topic_store.set_session_resume( topic_key[0], topic_key[1], token ) - if chat_session_store is not None and chat_session_key is not None: - await chat_session_store.set_session_resume( + if ( + state.chat_session_store is not None + and chat_session_key is not None + ): + await state.chat_session_store.set_session_resume( chat_session_key[0], chat_session_key[1], token ) @@ -642,15 +1054,15 @@ async def run_main_loop( ) -> None: topic_key = ( (chat_id, thread_id) - if topic_store is not None + if state.topic_store is not None and thread_id is not None and _topics_chat_allowed( - cfg, chat_id, scope_chat_ids=topics_chat_ids + cfg, chat_id, scope_chat_ids=state.topics_chat_ids ) else None ) stateful_mode = topic_key is not None or chat_session_key is not None - show_resume_line = _should_show_resume_line( + show_resume_line = should_show_resume_line( show_resume_line=cfg.show_resume_line, stateful_mode=stateful_mode, context=context, @@ -670,13 +1082,13 @@ async def run_main_loop( chat_id, overrides_thread_id, engine_for_overrides, - chat_prefs=chat_prefs, - topic_store=topic_store, + chat_prefs=state.chat_prefs, + topic_store=state.topic_store, ) - await _run_engine( + await run_engine( exec_cfg=cfg.exec_cfg, runtime=cfg.runtime, - running_tasks=running_tasks, + running_tasks=state.running_tasks, chat_id=chat_id, user_msg_id=user_msg_id, text=text, @@ -710,6 +1122,13 @@ async def run_main_loop( scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job) + def resolve_topic_key( + msg: TelegramIncomingMessage, + ) -> tuple[int, int] | None: + if state.topic_store is None: + return None + return _topic_key(msg, cfg, scope_chat_ids=state.topics_chat_ids) + def _build_upload_prompt(base: str, annotation: str) -> str: if base and base.strip(): return f"{base}\n\n{annotation}" @@ -731,29 +1150,25 @@ async def run_main_loop( except DirectiveError as exc: await reply(text=f"error:\n{exc}") return None - topic_key = ( - _topic_key(msg, cfg, scope_chat_ids=topics_chat_ids) - if topic_store is not None - else None - ) + topic_key = resolve_topic_key(msg) effective_context = ambient_context if ( - topic_store is not None + state.topic_store is not None and topic_key is not None and resolved.context is not None and resolved.context_source == "directives" ): - await topic_store.set_context(*topic_key, resolved.context) + await state.topic_store.set_context(*topic_key, resolved.context) await _maybe_rename_topic( cfg, - topic_store, + state.topic_store, chat_id=topic_key[0], thread_id=topic_key[1], context=resolved.context, ) effective_context = resolved.context if ( - topic_store is not None + state.topic_store is not None and topic_key is not None and effective_context is None and resolved.context_source not in {"directives", "reply_ctx"} @@ -784,10 +1199,19 @@ async def run_main_loop( explicit_engine=explicit_engine, chat_id=chat_id, topic_key=topic_key, - topic_store=topic_store, - chat_prefs=chat_prefs, + topic_store=state.topic_store, + chat_prefs=state.chat_prefs, ) + resume_resolver = ResumeResolver( + cfg=cfg, + task_group=tg, + running_tasks=state.running_tasks, + enqueue_resume=scheduler.enqueue_resume, + topic_store=state.topic_store, + chat_session_store=state.chat_session_store, + ) + async def run_prompt_from_upload( msg: TelegramIncomingMessage, prompt_text: str, @@ -807,12 +1231,10 @@ async def run_main_loop( ) resume_token = resolved.resume_token context = resolved.context - chat_session_key = _chat_session_key(msg, store=chat_session_store) - topic_key = ( - _topic_key(msg, cfg, scope_chat_ids=topics_chat_ids) - if topic_store is not None - else None + chat_session_key = _chat_session_key( + msg, store=state.chat_session_store ) + topic_key = resolve_topic_key(msg) engine_resolution = await resolve_engine_defaults( explicit_engine=resolved.engine_override, context=context, @@ -820,47 +1242,20 @@ async def run_main_loop( topic_key=topic_key, ) engine_override = engine_resolution.engine - if resume_token is None and reply_id is not None: - running_task = running_tasks.get( - MessageRef(channel_id=chat_id, message_id=reply_id) - ) - if running_task is not None: - tg.start_soon( - send_with_resume, - cfg, - scheduler.enqueue_resume, - running_task, - chat_id, - user_msg_id, - msg.thread_id, - chat_session_key, - prompt_text, - ) - return - if ( - resume_token is None - and topic_store is not None - and topic_key is not None - ): - engine_for_session = engine_resolution.engine - stored = await topic_store.get_session_resume( - topic_key[0], topic_key[1], engine_for_session - ) - if stored is not None: - resume_token = stored - if ( - resume_token is None - and chat_session_store is not None - and chat_session_key is not None - ): - engine_for_session = engine_resolution.engine - stored = await chat_session_store.get_session_resume( - chat_session_key[0], - chat_session_key[1], - engine_for_session, - ) - if stored is not None: - resume_token = stored + resume_decision = await resume_resolver.resolve( + resume_token=resume_token, + reply_id=reply_id, + chat_id=chat_id, + user_msg_id=user_msg_id, + thread_id=msg.thread_id, + chat_session_key=chat_session_key, + topic_key=topic_key, + engine_for_session=engine_resolution.engine, + prompt_text=prompt_text, + ) + if resume_decision.handled_by_running_task: + return + resume_token = resume_decision.resume_token if resume_token is None: await run_job( chat_id, @@ -943,22 +1338,24 @@ async def run_main_loop( engine_override = engine_resolution.engine effective_context = pending.ambient_context if ( - topic_store is not None + state.topic_store is not None and pending.topic_key is not None and resolved.context is not None and resolved.context_source == "directives" ): - await topic_store.set_context(*pending.topic_key, resolved.context) + await state.topic_store.set_context( + *pending.topic_key, resolved.context + ) await _maybe_rename_topic( cfg, - topic_store, + state.topic_store, chat_id=pending.topic_key[0], thread_id=pending.topic_key[1], context=resolved.context, ) effective_context = resolved.context if ( - topic_store is not None + state.topic_store is not None and pending.topic_key is not None and effective_context is None and resolved.context_source not in {"directives", "reply_ctx"} @@ -969,49 +1366,20 @@ async def run_main_loop( f"{_usage_topic(chat_project=pending.chat_project)}", ) return - if resume_token is None and pending.reply_id is not None: - running_task = running_tasks.get( - MessageRef(channel_id=chat_id, message_id=pending.reply_id) - ) - if running_task is not None: - tg.start_soon( - send_with_resume, - cfg, - scheduler.enqueue_resume, - running_task, - chat_id, - user_msg_id, - msg.thread_id, - pending.chat_session_key, - prompt_text, - ) - return - if ( - resume_token is None - and topic_store is not None - and pending.topic_key is not None - ): - engine_for_session = engine_resolution.engine - stored = await topic_store.get_session_resume( - pending.topic_key[0], - pending.topic_key[1], - engine_for_session, - ) - if stored is not None: - resume_token = stored - if ( - resume_token is None - and chat_session_store is not None - and pending.chat_session_key is not None - ): - engine_for_session = engine_resolution.engine - stored = await chat_session_store.get_session_resume( - pending.chat_session_key[0], - pending.chat_session_key[1], - engine_for_session, - ) - if stored is not None: - resume_token = stored + resume_decision = await resume_resolver.resolve( + resume_token=resume_token, + reply_id=pending.reply_id, + chat_id=chat_id, + user_msg_id=user_msg_id, + thread_id=msg.thread_id, + chat_session_key=pending.chat_session_key, + topic_key=pending.topic_key, + engine_for_session=engine_resolution.engine, + prompt_text=prompt_text, + ) + if resume_decision.handled_by_running_task: + return + resume_token = resume_decision.resume_token if resume_token is None: tg.start_soon( @@ -1047,151 +1415,12 @@ async def run_main_loop( progress_ref, ) - async def _debounce_prompt_run( - key: ForwardKey, pending: _PendingPrompt - ) -> None: - try: - with anyio.CancelScope() as scope: - pending.cancel_scope = scope - await anyio.sleep(forward_coalesce_s) - except anyio.get_cancelled_exc_class(): - return - if pending_prompts.get(key) is not pending: - return - pending_prompts.pop(key, None) - logger.debug( - "forward.prompt.run", - chat_id=pending.msg.chat_id, - thread_id=pending.msg.thread_id, - sender_id=pending.msg.sender_id, - message_id=pending.msg.message_id, - forward_count=len(pending.forwards), - debounce_s=forward_coalesce_s, - ) - await _dispatch_pending_prompt(pending) - - def _reschedule_prompt(key: ForwardKey, pending: _PendingPrompt) -> None: - if pending.cancel_scope is not None: - pending.cancel_scope.cancel() - pending.cancel_scope = None - tg.start_soon(_debounce_prompt_run, key, pending) - - def _cancel_pending_prompt(key: ForwardKey) -> None: - pending = pending_prompts.pop(key, None) - if pending is None: - return - if pending.cancel_scope is not None: - pending.cancel_scope.cancel() - logger.debug( - "forward.prompt.cancelled", - chat_id=pending.msg.chat_id, - thread_id=pending.msg.thread_id, - sender_id=pending.msg.sender_id, - message_id=pending.msg.message_id, - forward_count=len(pending.forwards), - ) - - def _schedule_prompt( - pending: _PendingPrompt, - ) -> None: - if pending.msg.sender_id is None: - logger.debug( - "forward.prompt.bypass", - chat_id=pending.msg.chat_id, - thread_id=pending.msg.thread_id, - sender_id=pending.msg.sender_id, - message_id=pending.msg.message_id, - reason="missing_sender", - ) - tg.start_soon(_dispatch_pending_prompt, pending) - return - if forward_coalesce_s <= 0: - logger.debug( - "forward.prompt.bypass", - chat_id=pending.msg.chat_id, - thread_id=pending.msg.thread_id, - sender_id=pending.msg.sender_id, - message_id=pending.msg.message_id, - reason="disabled", - ) - tg.start_soon(_dispatch_pending_prompt, pending) - return - key = _forward_key(pending.msg) - existing = pending_prompts.get(key) - if existing is not None: - if existing.cancel_scope is not None: - existing.cancel_scope.cancel() - if existing.forwards: - pending.forwards = list(existing.forwards) - logger.debug( - "forward.prompt.replace", - chat_id=pending.msg.chat_id, - thread_id=pending.msg.thread_id, - sender_id=pending.msg.sender_id, - old_message_id=existing.msg.message_id, - new_message_id=pending.msg.message_id, - forward_count=len(pending.forwards), - ) - pending_prompts[key] = pending - logger.debug( - "forward.prompt.schedule", - chat_id=pending.msg.chat_id, - thread_id=pending.msg.thread_id, - sender_id=pending.msg.sender_id, - message_id=pending.msg.message_id, - debounce_s=forward_coalesce_s, - ) - _reschedule_prompt(key, pending) - - def _attach_forward(msg: TelegramIncomingMessage) -> None: - if msg.sender_id is None: - logger.debug( - "forward.message.ignored", - chat_id=msg.chat_id, - thread_id=msg.thread_id, - sender_id=msg.sender_id, - message_id=msg.message_id, - reason="missing_sender", - ) - return - key = _forward_key(msg) - pending = pending_prompts.get(key) - if pending is None: - logger.debug( - "forward.message.ignored", - chat_id=msg.chat_id, - thread_id=msg.thread_id, - sender_id=msg.sender_id, - message_id=msg.message_id, - reason="no_pending_prompt", - ) - return - text = msg.text - if not text.strip(): - logger.debug( - "forward.message.ignored", - chat_id=msg.chat_id, - thread_id=msg.thread_id, - sender_id=msg.sender_id, - message_id=msg.message_id, - reason="empty_text", - ) - return - pending.forwards.append((msg.message_id, text)) - logger.debug( - "forward.message.attached", - chat_id=msg.chat_id, - thread_id=msg.thread_id, - sender_id=msg.sender_id, - message_id=msg.message_id, - prompt_message_id=pending.msg.message_id, - forward_count=len(pending.forwards), - forward_fields=_forward_fields_present(msg.raw), - forward_date=msg.raw.get("forward_date") if msg.raw else None, - message_date=msg.raw.get("date") if msg.raw else None, - text_len=len(text), - ) - _reschedule_prompt(key, pending) + forward_coalescer = ForwardCoalescer( + task_group=tg, + debounce_s=state.forward_coalesce_s, + dispatch=_dispatch_pending_prompt, + pending=state.pending_prompts, + ) async def handle_prompt_upload( msg: TelegramIncomingMessage, @@ -1206,7 +1435,7 @@ async def run_main_loop( ) if resolved is None: return - saved = await _save_file_put( + saved = await save_file_put( cfg, msg, "", @@ -1219,60 +1448,23 @@ async def run_main_loop( prompt = _build_upload_prompt(resolved.prompt, annotation) await run_prompt_from_upload(msg, prompt, resolved) - async def flush_media_group(key: tuple[int, str]) -> None: - while True: - state = media_groups.get(key) - if state is None: - return - token = state.token - await anyio.sleep(media_group_debounce_s) - state = media_groups.get(key) - if state is None: - return - if state.token != token: - continue - messages = list(state.messages) - del media_groups[key] - if not messages: - return - trigger_mode = await resolve_trigger_mode( - chat_id=messages[0].chat_id, - thread_id=messages[0].thread_id, - chat_prefs=chat_prefs, - topic_store=topic_store, - ) - if trigger_mode == "mentions" and not any( - should_trigger_run( - msg, - bot_username=bot_username, - runtime=cfg.runtime, - command_ids=command_ids, - reserved_chat_commands=reserved_chat_commands, - ) - for msg in messages - ): - return - await _handle_media_group( - cfg, - messages, - topic_store, - run_prompt_from_upload, - resolve_prompt_message, - ) - return + media_group_buffer = MediaGroupBuffer( + task_group=tg, + debounce_s=state.media_group_debounce_s, + cfg=cfg, + chat_prefs=state.chat_prefs, + topic_store=state.topic_store, + bot_username=state.bot_username, + command_ids=lambda: state.command_ids, + reserved_chat_commands=state.reserved_chat_commands, + groups=state.media_groups, + run_prompt_from_upload=run_prompt_from_upload, + resolve_prompt_message=resolve_prompt_message, + ) - async for msg in poller(cfg): - if isinstance(msg, TelegramCallbackQuery): - if msg.data == CANCEL_CALLBACK_DATA: - tg.start_soon( - handle_callback_cancel, cfg, msg, running_tasks, scheduler - ) - else: - tg.start_soon( - cfg.bot.answer_callback_query, - msg.callback_query_id, - ) - continue + async def build_message_context( + msg: TelegramIncomingMessage, + ) -> TelegramMsgContext: chat_id = msg.chat_id reply_id = msg.reply_to_message_id reply_ref = ( @@ -1280,6 +1472,35 @@ async def run_main_loop( if reply_id is not None else None ) + topic_key = resolve_topic_key(msg) + chat_session_key = _chat_session_key( + msg, store=state.chat_session_store + ) + stateful_mode = topic_key is not None or chat_session_key is not None + chat_project = ( + _topics_chat_project(cfg, chat_id) if cfg.topics.enabled else None + ) + bound_context = ( + await state.topic_store.get_context(*topic_key) + if state.topic_store is not None and topic_key is not None + else None + ) + ambient_context = _merge_topic_context( + chat_project=chat_project, bound=bound_context + ) + return TelegramMsgContext( + chat_id=chat_id, + thread_id=msg.thread_id, + reply_id=reply_id, + reply_ref=reply_ref, + topic_key=topic_key, + chat_session_key=chat_session_key, + stateful_mode=stateful_mode, + chat_project=chat_project, + ambient_context=ambient_context, + ) + + async def route_message(msg: TelegramIncomingMessage) -> None: reply = make_reply(cfg, msg) text = msg.text is_voice_transcribed = False @@ -1290,111 +1511,99 @@ async def run_main_loop( and msg.media_group_id is None ) if is_forward_candidate: - _attach_forward(msg) - continue + forward_coalescer.attach_forward(msg) + return forward_key = _forward_key(msg) if ( cfg.files.enabled and msg.document is not None and msg.media_group_id is not None ): - key = (chat_id, msg.media_group_id) - state = media_groups.get(key) - if state is None: - state = _MediaGroupState(messages=[]) - media_groups[key] = state - tg.start_soon(flush_media_group, key) - state.messages.append(msg) - state.token += 1 - continue - topic_key = ( - _topic_key(msg, cfg, scope_chat_ids=topics_chat_ids) - if topic_store is not None - else None - ) - chat_session_key = _chat_session_key(msg, store=chat_session_store) - stateful_mode = topic_key is not None or chat_session_key is not None - chat_project = ( - _topics_chat_project(cfg, chat_id) if cfg.topics.enabled else None - ) - bound_context = ( - await topic_store.get_context(*topic_key) - if topic_store is not None and topic_key is not None - else None - ) - ambient_context = _merge_topic_context( - chat_project=chat_project, bound=bound_context - ) + media_group_buffer.add(msg) + return + ctx = await build_message_context(msg) + chat_id = ctx.chat_id + reply_id = ctx.reply_id + reply_ref = ctx.reply_ref + topic_key = ctx.topic_key + chat_session_key = ctx.chat_session_key + stateful_mode = ctx.stateful_mode + chat_project = ctx.chat_project + ambient_context = ctx.ambient_context if is_cancel_command(text): - tg.start_soon(handle_cancel, cfg, msg, running_tasks, scheduler) - continue + tg.start_soon( + handle_cancel, cfg, msg, state.running_tasks, scheduler + ) + return - command_id, args_text = _parse_slash_command(text) + command_id, args_text = parse_slash_command(text) if command_id == "new": - _cancel_pending_prompt(forward_key) - if topic_store is not None and topic_key is not None: + forward_coalescer.cancel(forward_key) + if state.topic_store is not None and topic_key is not None: tg.start_soon( partial( - _handle_new_command, + handle_new_command, cfg, msg, - topic_store, - resolved_scope=resolved_topics_scope, - scope_chat_ids=topics_chat_ids, + state.topic_store, + resolved_scope=state.resolved_topics_scope, + scope_chat_ids=state.topics_chat_ids, ) ) - continue - if chat_session_store is not None: + return + if state.chat_session_store is not None: tg.start_soon( - _handle_chat_new_command, + handle_chat_new_command, cfg, msg, - chat_session_store, + state.chat_session_store, chat_session_key, ) - continue - if topic_store is not None: + return + if state.topic_store is not None: tg.start_soon( partial( - _handle_new_command, + handle_new_command, cfg, msg, - topic_store, - resolved_scope=resolved_topics_scope, - scope_chat_ids=topics_chat_ids, + state.topic_store, + resolved_scope=state.resolved_topics_scope, + scope_chat_ids=state.topics_chat_ids, ) ) - continue + return if command_id is not None and _dispatch_builtin_command( - cfg=cfg, - msg=msg, + ctx=TelegramCommandContext( + cfg=cfg, + msg=msg, + args_text=args_text, + ambient_context=ambient_context, + topic_store=state.topic_store, + chat_prefs=state.chat_prefs, + resolved_scope=state.resolved_topics_scope, + scope_chat_ids=state.topics_chat_ids, + reply=reply, + task_group=tg, + ), command_id=command_id, - args_text=args_text, - ambient_context=ambient_context, - topic_store=topic_store, - chat_prefs=chat_prefs, - resolved_scope=resolved_topics_scope, - scope_chat_ids=topics_chat_ids, - reply=reply, - task_group=tg, ): - continue + return trigger_mode = await resolve_trigger_mode( chat_id=chat_id, thread_id=msg.thread_id, - chat_prefs=chat_prefs, - topic_store=topic_store, + chat_prefs=state.chat_prefs, + topic_store=state.topic_store, ) if trigger_mode == "mentions" and not should_trigger_run( msg, - bot_username=bot_username, + bot_username=state.bot_username, runtime=cfg.runtime, - command_ids=command_ids, - reserved_chat_commands=reserved_chat_commands, + command_ids=state.command_ids, + reserved_chat_commands=state.reserved_chat_commands, ): - continue + return if msg.voice is not None: text = await transcribe_voice( @@ -1406,7 +1615,7 @@ async def run_main_loop( reply=reply, ) if text is None: - continue + return is_voice_transcribed = True if msg.document is not None: if cfg.files.enabled and cfg.files.auto_put: @@ -1417,15 +1626,15 @@ async def run_main_loop( msg, caption_text, ambient_context, - topic_store, + state.topic_store, ) elif not caption_text: tg.start_soon( - _handle_file_put_default, + handle_file_put_default, cfg, msg, ambient_context, - topic_store, + state.topic_store, ) else: tg.start_soon( @@ -1435,11 +1644,11 @@ async def run_main_loop( tg.start_soon( partial(reply, text=FILE_PUT_USAGE), ) - continue - if command_id is not None and command_id not in reserved_commands: - if command_id not in command_ids: + return + if command_id is not None and command_id not in state.reserved_commands: + if command_id not in state.command_ids: refresh_commands() - if command_id in command_ids: + if command_id in state.command_ids: engine_resolution = await resolve_engine_defaults( explicit_engine=None, context=ambient_context, @@ -1459,17 +1668,17 @@ async def run_main_loop( _resolve_engine_run_options, chat_id, overrides_thread_id, - chat_prefs=chat_prefs, - topic_store=topic_store, + chat_prefs=state.chat_prefs, + topic_store=state.topic_store, ) tg.start_soon( - _dispatch_command, + dispatch_command, cfg, msg, text, command_id, args_text, - running_tasks, + state.running_tasks, scheduler, wrap_on_thread_known( scheduler.note_thread_known, @@ -1480,7 +1689,7 @@ async def run_main_loop( default_engine_override, engine_overrides_resolver, ) - continue + return pending = _PendingPrompt( msg=msg, @@ -1494,7 +1703,7 @@ async def run_main_loop( is_voice_transcribed=is_voice_transcribed, forwards=[], ) - if reply_id is not None and running_tasks.get( + if reply_id is not None and state.running_tasks.get( MessageRef(channel_id=chat_id, message_id=reply_id) ): logger.debug( @@ -1506,7 +1715,28 @@ async def run_main_loop( reason="reply_resume", ) tg.start_soon(_dispatch_pending_prompt, pending) - continue - _schedule_prompt(pending) + return + forward_coalescer.schedule(pending) + + async def route_update(update: TelegramIncomingUpdate) -> None: + if isinstance(update, TelegramCallbackQuery): + if update.data == CANCEL_CALLBACK_DATA: + tg.start_soon( + handle_callback_cancel, + cfg, + update, + state.running_tasks, + scheduler, + ) + else: + tg.start_soon( + cfg.bot.answer_callback_query, + update.callback_query_id, + ) + return + await route_message(update) + + async for update in poller(cfg): + await route_update(update) finally: await cfg.exec_cfg.transport.close() diff --git a/src/takopi/telegram/types.py b/src/takopi/telegram/types.py index 1034118..193c3de 100644 --- a/src/takopi/telegram/types.py +++ b/src/takopi/telegram/types.py @@ -42,6 +42,12 @@ class TelegramIncomingMessage: document: TelegramDocument | None = None raw: dict[str, Any] | None = None + @property + def is_private(self) -> bool: + if self.chat_type is not None: + return self.chat_type == "private" + return self.chat_id > 0 + @dataclass(frozen=True, slots=True) class TelegramCallbackQuery: