diff --git a/docs/adding-a-runner.md b/docs/adding-a-runner.md index 1eee034..7117555 100644 --- a/docs/adding-a-runner.md +++ b/docs/adding-a-runner.md @@ -408,7 +408,7 @@ from ..schemas import acme as acme_schema logger = logging.getLogger(__name__) -ENGINE: EngineId = EngineId("acme") +ENGINE: EngineId = "acme" _RESUME_RE = re.compile( r"(?im)^\s*`?acme\s+--resume\s+(?P[^`\s]+)`?\s*$" ) diff --git a/pyproject.toml b/pyproject.toml index 09cd049..eb1fbb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ addopts = ["--cov=takopi", "--cov-report=term-missing", "--cov-fail-under=75"] testpaths = ["tests"] [tool.ruff.lint] -extend-select = ["B904", "BLE001", "S110", "RUF043"] +extend-select = ["B", "BLE001", "C4", "PERF", "RUF043", "S110", "SIM", "UP"] [tool.ty.src] include = ["src", "tests"] diff --git a/src/takopi/cli.py b/src/takopi/cli.py index 38e5393..6b6d8a0 100644 --- a/src/takopi/cli.py +++ b/src/takopi/cli.py @@ -256,7 +256,7 @@ def _run_auto_router( ) lock_token = transport_backend.lock_token( transport_config=transport_config, - config_path=config_path, + _config_path=config_path, ) lock_handle = acquire_config_lock(config_path, lock_token) runtime = spec.to_runtime(config_path=config_path) @@ -301,10 +301,7 @@ def _default_alias_from_path(path: Path) -> str | None: def _ensure_projects_table(config: dict, config_path: Path) -> dict: - projects = config.get("projects") - if projects is None: - projects = {} - config["projects"] = projects + projects = config.setdefault("projects", {}) if not isinstance(projects, dict): raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.") return projects diff --git a/src/takopi/config_watch.py b/src/takopi/config_watch.py index 88acbec..aa69847 100644 --- a/src/takopi/config_watch.py +++ b/src/takopi/config_watch.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path -from typing import Awaitable, Callable, Iterable +from collections.abc import Awaitable, Callable, Iterable from watchfiles import awatch diff --git a/src/takopi/engines.py b/src/takopi/engines.py index 244fa45..6edf206 100644 --- a/src/takopi/engines.py +++ b/src/takopi/engines.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Iterable +from collections.abc import Iterable from .backends import EngineBackend from .config import ConfigError diff --git a/src/takopi/lockfile.py b/src/takopi/lockfile.py index 049b6ba..8c6f006 100644 --- a/src/takopi/lockfile.py +++ b/src/takopi/lockfile.py @@ -11,7 +11,7 @@ from .logging import get_logger logger = get_logger(__name__) -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class LockInfo: pid: int | None token_fingerprint: str | None @@ -29,7 +29,7 @@ class LockError(RuntimeError): super().__init__(_format_lock_message(path, state)) -@dataclass +@dataclass(slots=True) class LockHandle: path: Path @@ -44,7 +44,7 @@ class LockHandle: error_type=exc.__class__.__name__, ) - def __enter__(self) -> "LockHandle": + def __enter__(self) -> LockHandle: return self def __exit__(self, exc_type, exc, tb) -> None: diff --git a/src/takopi/logging.py b/src/takopi/logging.py index 3775c70..a1f87d6 100644 --- a/src/takopi/logging.py +++ b/src/takopi/logging.py @@ -107,9 +107,8 @@ def _redact_value(value: Any, memo: dict[int, Any]) -> Any: def _redact_event_dict( - logger: Any, method_name: str, event_dict: dict[str, Any] + _logger: Any, _method_name: str, event_dict: dict[str, Any] ) -> dict[str, Any]: - _ = logger, method_name return _redact_value(event_dict, memo={}) @@ -222,10 +221,7 @@ def setup_logging( format_value = os.environ.get("TAKOPI_LOG_FORMAT", "console").strip().lower() color_override = os.environ.get("TAKOPI_LOG_COLOR") - if color_override is None: - is_tty = sys.stdout.isatty() - else: - is_tty = _truthy(color_override) + is_tty = sys.stdout.isatty() if color_override is None else _truthy(color_override) if format_value == "json": renderer: Any = structlog.processors.JSONRenderer(default=str) else: @@ -242,7 +238,9 @@ def setup_logging( _log_file_handle = None if log_file: try: - _log_file_handle = open(log_file, "a", encoding="utf-8") + _log_file_handle = open( # noqa: SIM115 + log_file, "a", encoding="utf-8" + ) except OSError: _log_file_handle = None diff --git a/src/takopi/markdown.py b/src/takopi/markdown.py index dea09f9..84dc4d9 100644 --- a/src/takopi/markdown.py +++ b/src/takopi/markdown.py @@ -120,9 +120,12 @@ def format_file_change_title(action: Action, *, command_width: int | None) -> st was_relativized = relativized != fallback if was_relativized: fallback = relativized - if fallback and not (fallback.startswith("`") and fallback.endswith("`")): - if was_relativized or os.sep in fallback or "/" in fallback: - fallback = f"`{fallback}`" + if ( + fallback + and not (fallback.startswith("`") and fallback.endswith("`")) + and (was_relativized or os.sep in fallback or "/" in fallback) + ): + fallback = f"`{fallback}`" return f"files: {shorten(fallback, command_width)}" @@ -247,10 +250,7 @@ class MarkdownFormatter: def _format_actions(self, state: ProgressState) -> list[str]: actions = list(state.actions) - if self.max_actions == 0: - actions = [] - else: - actions = actions[-self.max_actions :] + actions = [] if self.max_actions == 0 else actions[-self.max_actions :] return [ format_action_line( action_state.action, diff --git a/src/takopi/model.py b/src/takopi/model.py index b19de09..d8e30b7 100644 --- a/src/takopi/model.py +++ b/src/takopi/model.py @@ -3,11 +3,11 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Literal, TypeAlias +from typing import Any, Literal -EngineId: TypeAlias = str +type EngineId = str -ActionKind: TypeAlias = Literal[ +type ActionKind = Literal[ "command", "tool", "file_change", @@ -19,14 +19,14 @@ ActionKind: TypeAlias = Literal[ "telemetry", ] -TakopiEventType: TypeAlias = Literal[ +type TakopiEventType = Literal[ "started", "action", "completed", ] -ActionPhase: TypeAlias = Literal["started", "updated", "completed"] -ActionLevel: TypeAlias = Literal["debug", "info", "warning", "error"] +type ActionPhase = Literal["started", "updated", "completed"] +type ActionLevel = Literal["debug", "info", "warning", "error"] @dataclass(frozen=True, slots=True) @@ -74,4 +74,4 @@ class CompletedEvent: usage: dict[str, Any] | None = None -TakopiEvent: TypeAlias = StartedEvent | ActionEvent | CompletedEvent +type TakopiEvent = StartedEvent | ActionEvent | CompletedEvent diff --git a/src/takopi/plugins.py b/src/takopi/plugins.py index 813851b..9fc679f 100644 --- a/src/takopi/plugins.py +++ b/src/takopi/plugins.py @@ -1,10 +1,11 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping +from collections.abc import Iterable from dataclasses import dataclass from importlib.metadata import EntryPoint, entry_points import re -from typing import Any, Callable +from typing import Any +from collections.abc import Callable from .ids import ID_PATTERN, is_valid_id @@ -80,12 +81,7 @@ def reset_plugin_state() -> None: def _select_entrypoints(group: str) -> list[EntryPoint]: - eps = entry_points() - if hasattr(eps, "select"): - return list(eps.select(group=group)) - if isinstance(eps, Mapping): - return list(eps.get(group, [])) - return [] + return list(entry_points().select(group=group)) def entrypoint_distribution_name(ep: EntryPoint) -> str | None: diff --git a/src/takopi/progress.py b/src/takopi/progress.py index 763f576..9201197 100644 --- a/src/takopi/progress.py +++ b/src/takopi/progress.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Callable +from collections.abc import Callable from .model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent diff --git a/src/takopi/router.py b/src/takopi/router.py index c94a4a0..5e8c246 100644 --- a/src/takopi/router.py +++ b/src/takopi/router.py @@ -1,7 +1,8 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Iterable, Literal, TypeAlias +from typing import Literal +from collections.abc import Iterable from .model import EngineId, ResumeToken from .runner import Runner @@ -17,7 +18,7 @@ class RunnerUnavailableError(RuntimeError): self.issue = issue -EngineStatus: TypeAlias = Literal["ok", "missing_cli", "bad_config", "load_error"] +type EngineStatus = Literal["ok", "missing_cli", "bad_config", "load_error"] @dataclass(frozen=True, slots=True) diff --git a/src/takopi/runner_bridge.py b/src/takopi/runner_bridge.py index 6b95a24..14759a8 100644 --- a/src/takopi/runner_bridge.py +++ b/src/takopi/runner_bridge.py @@ -37,12 +37,9 @@ def _log_runner_event(evt: TakopiEvent) -> None: def _strip_resume_lines(text: str, *, is_resume_line: Callable[[str], bool]) -> str: - stripped_lines: list[str] = [] - for line in text.splitlines(): - if is_resume_line(line): - continue - stripped_lines.append(line) - prompt = "\n".join(stripped_lines).strip() + prompt = "\n".join( + line for line in text.splitlines() if not is_resume_line(line) + ).strip() return prompt or "continue" @@ -83,14 +80,14 @@ class IncomingMessage: thread_id: ThreadId | None = None -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class ExecBridgeConfig: transport: Transport presenter: Presenter final_notify: bool -@dataclass +@dataclass(slots=True) class RunningTask: resume: ResumeToken | None = None resume_ready: anyio.Event = field(default_factory=anyio.Event) diff --git a/src/takopi/runners/claude.py b/src/takopi/runners/claude.py index c3f3e93..7602178 100644 --- a/src/takopi/runners/claude.py +++ b/src/takopi/runners/claude.py @@ -14,11 +14,11 @@ from ..logging import get_logger from ..model import Action, ActionKind, EngineId, ResumeToken, TakopiEvent from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner from ..schemas import claude as claude_schema -from ..utils.paths import relativize_command, relativize_path +from .tool_actions import tool_input_path, tool_kind_and_title logger = get_logger(__name__) -ENGINE: EngineId = EngineId("claude") +ENGINE: EngineId = "claude" DEFAULT_ALLOWED_TOOLS = ["Bash", "Read", "Edit", "Write"] _RESUME_RE = re.compile( @@ -67,55 +67,10 @@ def _coerce_comma_list(value: Any) -> str | None: return text or None -def _tool_input_path(tool_input: dict[str, Any]) -> str | None: - for key in ("file_path", "path"): - value = tool_input.get(key) - if isinstance(value, str) and value: - return value - return None - - def _tool_kind_and_title( name: str, tool_input: dict[str, Any] ) -> tuple[ActionKind, str]: - if name in {"Bash", "Shell", "KillShell"}: - command = tool_input.get("command") - display = relativize_command(str(command or name)) - return "command", display - if name in {"Edit", "Write", "NotebookEdit", "MultiEdit"}: - path = _tool_input_path(tool_input) - if path: - return "file_change", relativize_path(str(path)) - return "file_change", str(name) - if name == "Read": - path = _tool_input_path(tool_input) - if path: - return "tool", f"read: `{relativize_path(str(path))}`" - return "tool", "read" - if name == "Glob": - pattern = tool_input.get("pattern") - if pattern: - return "tool", f"glob: `{pattern}`" - return "tool", "glob" - if name == "Grep": - pattern = tool_input.get("pattern") - if pattern: - return "tool", f"grep: {pattern}" - return "tool", "grep" - if name == "WebSearch": - query = tool_input.get("query") - return "web_search", str(query or "search") - if name == "WebFetch": - url = tool_input.get("url") - return "web_search", str(url or "fetch") - if name in {"TodoWrite", "TodoRead"}: - return "note", "update todos" if name == "TodoWrite" else "read todos" - if name == "AskUserQuestion": - return "note", "ask user" - if name in {"Task", "Agent"}: - desc = tool_input.get("description") or tool_input.get("prompt") - return "subagent", str(desc or name) - return "tool", name + return tool_kind_and_title(name, tool_input, path_keys=("file_path", "path")) def _tool_action( @@ -137,7 +92,7 @@ def _tool_action( detail["parent_tool_use_id"] = parent_tool_use_id if kind == "file_change": - path = _tool_input_path(tool_input) + path = tool_input_path(tool_input, path_keys=("file_path", "path")) if path: detail["changes"] = [{"path": path, "kind": "update"}] @@ -321,7 +276,7 @@ def translate_claude_event( return [] -@dataclass +@dataclass(slots=True) class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner): engine: EngineId = ENGINE resume_re: re.Pattern[str] = _RESUME_RE diff --git a/src/takopi/runners/codex.py b/src/takopi/runners/codex.py index 2c81b6a..503044d 100644 --- a/src/takopi/runners/codex.py +++ b/src/takopi/runners/codex.py @@ -18,7 +18,7 @@ from ..utils.paths import relativize_command logger = get_logger(__name__) -ENGINE: EngineId = EngineId("codex") +ENGINE: EngineId = "codex" __all__ = [ "ENGINE", diff --git a/src/takopi/runners/mock.py b/src/takopi/runners/mock.py index c9e0cb4..6d70280 100644 --- a/src/takopi/runners/mock.py +++ b/src/takopi/runners/mock.py @@ -4,7 +4,6 @@ import re import uuid from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from dataclasses import dataclass, replace -from typing import TypeAlias import anyio @@ -18,7 +17,7 @@ from ..model import ( ) from ..runner import ResumeTokenMixin, Runner, SessionLockMixin -ENGINE: EngineId = EngineId("mock") +ENGINE: EngineId = "mock" @dataclass(frozen=True, slots=True) @@ -52,7 +51,7 @@ class Raise: error: Exception -ScriptStep: TypeAlias = Emit | Advance | Sleep | Wait | Return | Raise +type ScriptStep = Emit | Advance | Sleep | Wait | Return | Raise def _resume_token(engine: EngineId, value: str | None) -> ResumeToken: @@ -108,9 +107,9 @@ class MockRunner(SessionLockMixin, ResumeTokenMixin, Runner): if ( isinstance(event_out, ActionEvent) and event_out.phase == "completed" + and event_out.ok is None ): - if event_out.ok is None: - event_out = replace(event_out, ok=True) + event_out = replace(event_out, ok=True) yield event_out await anyio.sleep(0) @@ -187,9 +186,9 @@ class ScriptRunner(MockRunner): if ( isinstance(event_out, ActionEvent) and event_out.phase == "completed" + and event_out.ok is None ): - if event_out.ok is None: - event_out = replace(event_out, ok=True) + event_out = replace(event_out, ok=True) yield event_out await anyio.sleep(0) continue diff --git a/src/takopi/runners/opencode.py b/src/takopi/runners/opencode.py index a28bb17..13ee49b 100644 --- a/src/takopi/runners/opencode.py +++ b/src/takopi/runners/opencode.py @@ -35,11 +35,12 @@ from ..model import ( ) from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner from ..schemas import opencode as opencode_schema -from ..utils.paths import relativize_command, relativize_path +from ..utils.paths import relativize_path +from .tool_actions import tool_input_path, tool_kind_and_title logger = get_logger(__name__) -ENGINE: EngineId = EngineId("opencode") +ENGINE: EngineId = "opencode" _RESUME_RE = re.compile( r"(?im)^\s*`?opencode(?:\s+run)?\s+(?:--session|-s)\s+(?Pses_[A-Za-z0-9]+)`?\s*$" @@ -79,54 +80,12 @@ def _action_event( def _tool_kind_and_title( tool_name: str, tool_input: dict[str, Any] ) -> tuple[ActionKind, str]: - """Map OpenCode tool names to Takopi action kinds and titles.""" - name_lower = tool_name.lower() - - if name_lower in {"bash", "shell"}: - command = tool_input.get("command") - display = relativize_command(str(command or tool_name)) - return "command", display - - if name_lower in {"edit", "write", "multiedit"}: - path = tool_input.get("file_path") or tool_input.get("filePath") - if path: - return "file_change", relativize_path(str(path)) - return "file_change", str(tool_name) - - if name_lower == "read": - path = tool_input.get("file_path") or tool_input.get("filePath") - if path: - return "tool", f"read: `{relativize_path(str(path))}`" - return "tool", "read" - - if name_lower == "glob": - pattern = tool_input.get("pattern") - if pattern: - return "tool", f"glob: `{pattern}`" - return "tool", "glob" - - if name_lower == "grep": - pattern = tool_input.get("pattern") - if pattern: - return "tool", f"grep: {pattern}" - return "tool", "grep" - - if name_lower in {"websearch", "web_search"}: - query = tool_input.get("query") - return "web_search", str(query or "search") - - if name_lower in {"webfetch", "web_fetch"}: - url = tool_input.get("url") - return "web_search", str(url or "fetch") - - if name_lower in {"todowrite", "todoread"}: - return "note", "update todos" if "write" in name_lower else "read todos" - - if name_lower == "task": - desc = tool_input.get("description") or tool_input.get("prompt") - return "tool", str(desc or tool_name) - - return "tool", tool_name + return tool_kind_and_title( + tool_name, + tool_input, + path_keys=("file_path", "filePath"), + task_kind="tool", + ) def _normalize_tool_title( @@ -137,10 +96,10 @@ def _normalize_tool_title( if "`" in title: return title - path = tool_input.get("file_path") or tool_input.get("filePath") + path = tool_input_path(tool_input, path_keys=("file_path", "filePath")) if isinstance(path, str) and path: rel_path = relativize_path(path) - if title == path or title == rel_path: + if title in (path, rel_path): return f"`{rel_path}`" return title @@ -190,9 +149,8 @@ def translate_opencode_event( """Translate an OpenCode JSON event into Takopi events.""" session_id = event.sessionID - if isinstance(session_id, str) and session_id: - if state.session_id is None: - state.session_id = session_id + if isinstance(session_id, str) and session_id and state.session_id is None: + state.session_id = session_id match event: case opencode_schema.StepStart(): @@ -340,7 +298,7 @@ def translate_opencode_event( return [] -@dataclass +@dataclass(slots=True) class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner): """Runner for OpenCode CLI.""" diff --git a/src/takopi/runners/pi.py b/src/takopi/runners/pi.py index ef64fb0..a5ae786 100644 --- a/src/takopi/runners/pi.py +++ b/src/takopi/runners/pi.py @@ -3,7 +3,7 @@ from __future__ import annotations import os import re from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime, UTC from pathlib import Path, PurePath from typing import Any from uuid import uuid4 @@ -27,11 +27,12 @@ from ..model import ( ) from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner from ..schemas import pi as pi_schema -from ..utils.paths import get_run_base_dir, relativize_command, relativize_path +from ..utils.paths import get_run_base_dir +from .tool_actions import tool_kind_and_title logger = get_logger(__name__) -ENGINE: EngineId = EngineId("pi") +ENGINE: EngineId = "pi" _RESUME_RE = re.compile(r"(?im)^\s*`?pi\s+--session\s+(?P.+?)`?\s*$") @@ -96,32 +97,7 @@ def _tool_kind_and_title( name: str, args: dict[str, Any], ) -> tuple[ActionKind, str]: - tool = name.lower() - if tool == "bash": - command = args.get("command") - return "command", relativize_command(str(command or "bash")) - if tool in {"edit", "write"}: - path = args.get("path") - if path: - return "file_change", relativize_path(str(path)) - return "file_change", tool - if tool == "read": - path = args.get("path") - if path: - return "tool", f"read: `{relativize_path(str(path))}`" - return "tool", "read" - if tool == "grep": - pattern = args.get("pattern") - return "tool", f"grep: {pattern}" if pattern else "grep" - if tool == "find": - pattern = args.get("pattern") - return "tool", f"find: {pattern}" if pattern else "find" - if tool == "ls": - path = args.get("path") - if path: - return "tool", f"ls: `{relativize_path(str(path))}`" - return "tool", "ls" - return "tool", name + return tool_kind_and_title(name, args, path_keys=("path",)) def _last_assistant_message(messages: Any) -> dict[str, Any] | None: @@ -418,7 +394,7 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner): cwd = get_run_base_dir() or Path.cwd() session_dir = _default_session_dir(cwd) session_dir.mkdir(parents=True, exist_ok=True) - timestamp = datetime.now(timezone.utc).isoformat() + timestamp = datetime.now(UTC).isoformat() safe_timestamp = timestamp.replace(":", "-").replace(".", "-") token = uuid4().hex filename = f"{safe_timestamp}_{token}.jsonl" @@ -442,7 +418,9 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner): def _default_session_dir(cwd: PurePath) -> Path: agent_dir = os.environ.get("PI_CODING_AGENT_DIR") base = Path(agent_dir).expanduser() if agent_dir else Path.home() / ".pi" / "agent" - safe_path = f"--{str(cwd).lstrip('/\\\\').replace('/', '-').replace('\\', '-').replace(':', '-')}--" + cwd_str = str(cwd).lstrip("/\\") + safe_path_part = cwd_str.translate(str.maketrans({"/": "-", "\\": "-", ":": "-"})) + safe_path = f"--{safe_path_part}--" return base / "sessions" / safe_path diff --git a/src/takopi/runners/tool_actions.py b/src/takopi/runners/tool_actions.py new file mode 100644 index 0000000..e1c63a9 --- /dev/null +++ b/src/takopi/runners/tool_actions.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Any + +from ..model import ActionKind +from ..utils.paths import relativize_command, relativize_path + + +def tool_input_path( + tool_input: Mapping[str, Any], + *, + path_keys: Sequence[str], +) -> str | None: + for key in path_keys: + value = tool_input.get(key) + if isinstance(value, str) and value: + return value + return None + + +def tool_kind_and_title( + tool_name: str, + tool_input: Mapping[str, Any], + *, + path_keys: Sequence[str], + task_kind: ActionKind = "subagent", +) -> tuple[ActionKind, str]: + name_lower = tool_name.lower() + + if name_lower in {"bash", "shell", "killshell"}: + command = tool_input.get("command") + display = relativize_command(str(command or tool_name)) + return "command", display + + if name_lower in {"edit", "write", "notebookedit", "multiedit"}: + path = tool_input_path(tool_input, path_keys=path_keys) + if path: + return "file_change", relativize_path(str(path)) + return "file_change", str(tool_name) + + if name_lower == "read": + path = tool_input_path(tool_input, path_keys=path_keys) + if path: + return "tool", f"read: `{relativize_path(str(path))}`" + return "tool", "read" + + if name_lower == "glob": + pattern = tool_input.get("pattern") + if pattern: + return "tool", f"glob: `{pattern}`" + return "tool", "glob" + + if name_lower == "grep": + pattern = tool_input.get("pattern") + if pattern: + return "tool", f"grep: {pattern}" + return "tool", "grep" + + if name_lower == "find": + pattern = tool_input.get("pattern") + if pattern: + return "tool", f"find: {pattern}" + return "tool", "find" + + if name_lower == "ls": + path = tool_input_path(tool_input, path_keys=path_keys) + if path: + return "tool", f"ls: `{relativize_path(str(path))}`" + return "tool", "ls" + + if name_lower in {"websearch", "web_search"}: + query = tool_input.get("query") + return "web_search", str(query or "search") + + if name_lower in {"webfetch", "web_fetch"}: + url = tool_input.get("url") + return "web_search", str(url or "fetch") + + if name_lower in {"todowrite", "todoread"}: + return "note", "update todos" if "write" in name_lower else "read todos" + + if name_lower == "askuserquestion": + return "note", "ask user" + + if name_lower in {"task", "agent"}: + desc = tool_input.get("description") or tool_input.get("prompt") + return task_kind, str(desc or tool_name) + + return "tool", tool_name diff --git a/src/takopi/runtime_loader.py b/src/takopi/runtime_loader.py index bf6cae6..314ac4a 100644 --- a/src/takopi/runtime_loader.py +++ b/src/takopi/runtime_loader.py @@ -3,7 +3,8 @@ from __future__ import annotations import shutil from dataclasses import dataclass from pathlib import Path -from typing import Any, Iterable, Mapping +from typing import Any +from collections.abc import Iterable, Mapping from .backends import EngineBackend from .config import ConfigError, ProjectsConfig diff --git a/src/takopi/scheduler.py b/src/takopi/scheduler.py index bf45c74..8daf3d0 100644 --- a/src/takopi/scheduler.py +++ b/src/takopi/scheduler.py @@ -2,14 +2,18 @@ from __future__ import annotations from collections import deque from dataclasses import dataclass -from typing import Any, Awaitable, Callable, Protocol +from typing import Any, Protocol +from collections.abc import Awaitable, Callable import anyio from .context import RunContext +from .logging import get_logger from .model import ResumeToken from .transport import ChannelId, MessageId, ThreadId +logger = get_logger(__name__) + @dataclass(frozen=True, slots=True) class ThreadJob: @@ -108,7 +112,18 @@ class ThreadScheduler: if done is not None and not done.is_set(): await done.wait() - await self._run_job(job) + try: + await self._run_job(job) + except Exception as exc: # noqa: BLE001 + logger.exception( + "scheduler.job_failed", + key=key, + tag=job.resume_token.engine, + chat_id=job.chat_id, + user_msg_id=job.user_msg_id, + error=str(exc), + error_type=exc.__class__.__name__, + ) finally: async with self._lock: self._active_threads.discard(key) diff --git a/src/takopi/schemas/claude.py b/src/takopi/schemas/claude.py index 76b3812..e11efcf 100644 --- a/src/takopi/schemas/claude.py +++ b/src/takopi/schemas/claude.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Literal, TypeAlias +from typing import Any, Literal import msgspec @@ -36,7 +36,7 @@ class StreamToolResultBlock( is_error: bool | None = None -StreamContentBlock: TypeAlias = ( +type StreamContentBlock = ( StreamTextBlock | StreamThinkingBlock | StreamToolUseBlock | StreamToolResultBlock ) @@ -164,7 +164,7 @@ class ControlRewindFilesRequest( user_message_id: str -ControlRequest: TypeAlias = ( +type ControlRequest = ( ControlInterruptRequest | ControlCanUseToolRequest | ControlInitializeRequest @@ -196,7 +196,7 @@ class ControlErrorResponse( error: str -ControlResponse: TypeAlias = ControlSuccessResponse | ControlErrorResponse +type ControlResponse = ControlSuccessResponse | ControlErrorResponse class StreamControlResponse( @@ -217,7 +217,7 @@ class StreamControlCancelRequest( request_id: str | None = None -StreamJsonMessage: TypeAlias = ( +type StreamJsonMessage = ( StreamUserMessage | StreamAssistantMessage | StreamSystemMessage diff --git a/src/takopi/schemas/codex.py b/src/takopi/schemas/codex.py index b17e0e7..e3f10ce 100644 --- a/src/takopi/schemas/codex.py +++ b/src/takopi/schemas/codex.py @@ -2,27 +2,27 @@ from __future__ import annotations # Headless JSONL schema derived from tag rust-v0.77.0 (git 112f40e91c12af0f7146d7e03f20283516a8af0b). -from typing import Any, Literal, TypeAlias +from typing import Any, Literal import msgspec -CommandExecutionStatus: TypeAlias = Literal[ +type CommandExecutionStatus = Literal[ "in_progress", "completed", "failed", "declined", ] -PatchApplyStatus: TypeAlias = Literal[ +type PatchApplyStatus = Literal[ "in_progress", "completed", "failed", ] -PatchChangeKind: TypeAlias = Literal[ +type PatchChangeKind = Literal[ "add", "delete", "update", ] -McpToolCallStatus: TypeAlias = Literal[ +type McpToolCallStatus = Literal[ "in_progress", "completed", "failed", @@ -127,7 +127,7 @@ class TodoListItem(msgspec.Struct, tag="todo_list", kw_only=True): items: list[TodoItem] -ThreadItem: TypeAlias = ( +type ThreadItem = ( AgentMessageItem | ReasoningItem | CommandExecutionItem @@ -151,7 +151,7 @@ class ItemCompleted(msgspec.Struct, tag="item.completed", kw_only=True): item: ThreadItem -ThreadEvent: TypeAlias = ( +type ThreadEvent = ( ThreadStarted | TurnStarted | TurnCompleted diff --git a/src/takopi/schemas/opencode.py b/src/takopi/schemas/opencode.py index 777a066..e0038a1 100644 --- a/src/takopi/schemas/opencode.py +++ b/src/takopi/schemas/opencode.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, TypeAlias +from typing import Any import msgspec @@ -42,7 +42,7 @@ class Error(_Event, tag="error"): message: Any = None -OpenCodeEvent: TypeAlias = StepStart | StepFinish | ToolUse | Text | Error +type OpenCodeEvent = StepStart | StepFinish | ToolUse | Text | Error _DECODER = msgspec.json.Decoder(OpenCodeEvent) diff --git a/src/takopi/schemas/pi.py b/src/takopi/schemas/pi.py index 52eb6fe..4018a33 100644 --- a/src/takopi/schemas/pi.py +++ b/src/takopi/schemas/pi.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, TypeAlias +from typing import Any import msgspec @@ -84,7 +84,7 @@ class AutoRetryEnd(_Event, tag="auto_retry_end"): finalError: str | None = None -PiEvent: TypeAlias = ( +type PiEvent = ( AgentStart | AgentEnd | MessageStart diff --git a/src/takopi/settings.py b/src/takopi/settings.py index e24f837..cbce6ba 100644 --- a/src/takopi/settings.py +++ b/src/takopi/settings.py @@ -1,7 +1,8 @@ from __future__ import annotations from pathlib import Path -from typing import Annotated, Any, ClassVar, Iterable, Literal +from typing import Annotated, Any, ClassVar, Literal +from collections.abc import Iterable from pydantic import ( BaseModel, diff --git a/src/takopi/telegram/backend.py b/src/takopi/telegram/backend.py index de15b7f..25b1f3f 100644 --- a/src/takopi/telegram/backend.py +++ b/src/takopi/telegram/backend.py @@ -50,9 +50,7 @@ def _build_startup_message( notes.append(f"failed to load: {', '.join(failed_engines)}") if notes: engine_list = f"{engine_list} ({'; '.join(notes)})" - project_aliases = sorted( - {alias for alias in runtime.project_aliases()}, key=str.lower - ) + project_aliases = sorted(set(runtime.project_aliases()), key=str.lower) project_list = ", ".join(project_aliases) if project_aliases else "none" return ( f"\N{OCTOPUS} **takopi is ready**\n\n" @@ -78,8 +76,7 @@ class TelegramBackend(TransportBackend): def interactive_setup(self, *, force: bool) -> bool: return interactive_setup(force=force) - def lock_token(self, *, transport_config: object, config_path: Path) -> str | None: - _ = config_path + def lock_token(self, *, transport_config: object, _config_path: Path) -> str | None: settings = _expect_transport_settings(transport_config) return settings.bot_token diff --git a/src/takopi/telegram/bridge.py b/src/takopi/telegram/bridge.py index d3095c1..d9d7e22 100644 --- a/src/takopi/telegram/bridge.py +++ b/src/takopi/telegram/bridge.py @@ -111,7 +111,7 @@ def _is_cancelled_label(label: str) -> bool: return stripped.lower() == "cancelled" -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class TelegramBridgeConfig: bot: BotClient runtime: TransportRuntime diff --git a/src/takopi/telegram/client.py b/src/takopi/telegram/client.py index f3b2f4a..21b4224 100644 --- a/src/takopi/telegram/client.py +++ b/src/takopi/telegram/client.py @@ -2,589 +2,40 @@ from __future__ import annotations import itertools import time -from dataclasses import dataclass, field -from typing import ( - Any, - AsyncIterator, - Awaitable, - Callable, - Hashable, - Iterable, - Protocol, - TYPE_CHECKING, - TypeVar, -) - -import msgspec -import httpx +from typing import Any +from collections.abc import Awaitable, Callable, Hashable import anyio +import httpx from ..logging import get_logger from .api_models import Chat, ChatMember, File, ForumTopic, Message, Update, User -from .types import ( - TelegramCallbackQuery, - TelegramDocument, - TelegramIncomingMessage, - TelegramIncomingUpdate, - TelegramVoice, +from .client_api import BotClient, HttpBotClient, TelegramRetryAfter +from .outbox import ( + DELETE_PRIORITY, + EDIT_PRIORITY, + SEND_PRIORITY, + OutboxOp, + TelegramOutbox, ) +from .parsing import parse_incoming_update, poll_incoming logger = get_logger(__name__) -T = TypeVar("T") - - -SEND_PRIORITY = 0 -DELETE_PRIORITY = 1 -EDIT_PRIORITY = 2 - - -class RetryAfter(Exception): - def __init__(self, retry_after: float, description: str | None = None) -> None: - super().__init__(description or f"retry after {retry_after}") - self.retry_after = float(retry_after) - self.description = description - - -class TelegramRetryAfter(RetryAfter): - pass +__all__ = [ + "BotClient", + "TelegramClient", + "TelegramRetryAfter", + "is_group_chat_id", + "parse_incoming_update", + "poll_incoming", +] def is_group_chat_id(chat_id: int) -> bool: return chat_id < 0 -def parse_incoming_update( - update: Update | dict[str, Any], - *, - chat_id: int | None = None, - chat_ids: set[int] | None = None, -) -> TelegramIncomingUpdate | None: - if isinstance(update, Update): - msg = update.message - callback_query = update.callback_query - else: - msg = update.get("message") - callback_query = update.get("callback_query") - - if isinstance(msg, dict): - return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids) - if isinstance(callback_query, dict): - return _parse_callback_query( - callback_query, - chat_id=chat_id, - chat_ids=chat_ids, - ) - return None - - -def _parse_incoming_message( - msg: dict[str, Any], - *, - chat_id: int | None = None, - chat_ids: set[int] | None = None, -) -> TelegramIncomingMessage | None: - def _parse_document_payload(payload: dict[str, Any]) -> TelegramDocument | None: - file_id = payload.get("file_id") - if not isinstance(file_id, str) or not file_id: - return None - return TelegramDocument( - file_id=file_id, - file_name=payload.get("file_name") - if isinstance(payload.get("file_name"), str) - else None, - mime_type=payload.get("mime_type") - if isinstance(payload.get("mime_type"), str) - else None, - file_size=payload.get("file_size") - if isinstance(payload.get("file_size"), int) - and not isinstance(payload.get("file_size"), bool) - else None, - raw=payload, - ) - - raw_text = msg.get("text") - text = raw_text if isinstance(raw_text, str) else None - caption = msg.get("caption") - if text is None and isinstance(caption, str): - text = caption - if text is None: - text = "" - file_command = False - if isinstance(text, str): - stripped = text.lstrip() - if stripped.startswith("/"): - token = stripped.split(maxsplit=1)[0] - file_command = token.startswith("/file") - voice_payload: TelegramVoice | None = None - voice = msg.get("voice") - if isinstance(voice, dict): - file_id = voice.get("file_id") - if not isinstance(file_id, str) or not file_id: - file_id = None - if file_id is not None: - voice_payload = TelegramVoice( - file_id=file_id, - mime_type=voice.get("mime_type") - if isinstance(voice.get("mime_type"), str) - else None, - file_size=voice.get("file_size") - if isinstance(voice.get("file_size"), int) - and not isinstance(voice.get("file_size"), bool) - else None, - duration=voice.get("duration") - if isinstance(voice.get("duration"), int) - and not isinstance(voice.get("duration"), bool) - else None, - raw=voice, - ) - if not isinstance(raw_text, str) and not isinstance(caption, str): - text = "" - document_payload: TelegramDocument | None = None - document = msg.get("document") - if isinstance(document, dict): - document_payload = _parse_document_payload(document) - if document_payload is None: - video = msg.get("video") - if isinstance(video, dict): - document_payload = _parse_document_payload(video) - if document_payload is None: - photo = msg.get("photo") - if isinstance(photo, list): - best: dict[str, Any] | None = None - best_score = -1 - for item in photo: - if not isinstance(item, dict): - continue - file_id = item.get("file_id") - if not isinstance(file_id, str) or not file_id: - continue - size = item.get("file_size") - if isinstance(size, int) and not isinstance(size, bool): - score = size - else: - width = item.get("width") - height = item.get("height") - if isinstance(width, int) and isinstance(height, int): - score = width * height - else: - score = 0 - if score > best_score: - best_score = score - best = item - if best is not None: - document_payload = _parse_document_payload(best) - if document_payload is None and file_command: - sticker = msg.get("sticker") - if isinstance(sticker, dict): - document_payload = _parse_document_payload(sticker) - has_text = isinstance(raw_text, str) or isinstance(caption, str) - if not has_text and voice_payload is None and document_payload is None: - return None - chat = msg.get("chat") - if not isinstance(chat, dict): - return None - msg_chat_id = chat.get("id") - if not isinstance(msg_chat_id, int): - return None - chat_type = chat.get("type") if isinstance(chat.get("type"), str) else None - is_forum = chat.get("is_forum") - if not isinstance(is_forum, bool): - is_forum = None - allowed = chat_ids - if allowed is None and chat_id is not None: - allowed = {chat_id} - if allowed is not None and msg_chat_id not in allowed: - return None - message_id = msg.get("message_id") - if not isinstance(message_id, int): - return None - reply = msg.get("reply_to_message") - reply_to_message_id = None - reply_to_text = None - if isinstance(reply, dict): - reply_to_message_id = ( - reply.get("message_id") - if isinstance(reply.get("message_id"), int) - else None - ) - reply_to_text = ( - reply.get("text") if isinstance(reply.get("text"), str) else None - ) - sender = msg.get("from") - sender_id = ( - sender.get("id") - if isinstance(sender, dict) and isinstance(sender.get("id"), int) - else None - ) - media_group_id = msg.get("media_group_id") - if not isinstance(media_group_id, str): - media_group_id = None - thread_id = msg.get("message_thread_id") - if isinstance(thread_id, bool) or not isinstance(thread_id, int): - thread_id = None - is_topic_message = msg.get("is_topic_message") - if not isinstance(is_topic_message, bool): - is_topic_message = None - return TelegramIncomingMessage( - transport="telegram", - chat_id=msg_chat_id, - message_id=message_id, - text=text, - reply_to_message_id=reply_to_message_id, - reply_to_text=reply_to_text, - sender_id=sender_id, - media_group_id=media_group_id, - thread_id=thread_id, - is_topic_message=is_topic_message, - chat_type=chat_type, - is_forum=is_forum, - voice=voice_payload, - document=document_payload, - raw=msg, - ) - - -def _parse_callback_query( - query: dict[str, Any], - *, - chat_id: int | None = None, - chat_ids: set[int] | None = None, -) -> TelegramCallbackQuery | None: - callback_id = query.get("id") - if not isinstance(callback_id, str) or not callback_id: - return None - msg = query.get("message") - if not isinstance(msg, dict): - return None - chat = msg.get("chat") - if not isinstance(chat, dict): - return None - msg_chat_id = chat.get("id") - if not isinstance(msg_chat_id, int): - return None - allowed = chat_ids - if allowed is None and chat_id is not None: - allowed = {chat_id} - if allowed is not None and msg_chat_id not in allowed: - return None - message_id = msg.get("message_id") - if not isinstance(message_id, int): - return None - data = query.get("data") if isinstance(query.get("data"), str) else None - sender = query.get("from") - sender_id = ( - sender.get("id") - if isinstance(sender, dict) and isinstance(sender.get("id"), int) - else None - ) - return TelegramCallbackQuery( - transport="telegram", - chat_id=msg_chat_id, - message_id=message_id, - callback_query_id=callback_id, - data=data, - sender_id=sender_id, - raw=query, - ) - - -async def poll_incoming( - bot: BotClient, - *, - chat_id: int | None = None, - chat_ids: Iterable[int] | Callable[[], Iterable[int]] | None = None, - offset: int | None = None, -) -> AsyncIterator[TelegramIncomingUpdate]: - while True: - updates = await bot.get_updates( - offset=offset, - timeout_s=50, - allowed_updates=["message", "callback_query"], - ) - if updates is None: - logger.info("loop.get_updates.failed") - await anyio.sleep(2) - continue - logger.debug("loop.updates", updates=updates) - resolved_chat_ids = chat_ids() if callable(chat_ids) else chat_ids - allowed = set(resolved_chat_ids) if resolved_chat_ids is not None else None - if allowed is None and chat_id is not None: - allowed = {chat_id} - for upd in updates: - offset = upd.update_id + 1 - msg = parse_incoming_update(upd, chat_ids=allowed) - if msg is not None: - yield msg - - -class BotClient(Protocol): - async def close(self) -> None: ... - - async def get_updates( - self, - offset: int | None, - timeout_s: int = 50, - allowed_updates: list[str] | None = None, - ) -> list[Update] | None: ... - - async def get_file(self, file_id: str) -> File | None: ... - - async def download_file(self, file_path: str) -> bytes | None: ... - - async def send_message( - self, - chat_id: int, - text: str, - reply_to_message_id: int | None = None, - disable_notification: bool | None = False, - message_thread_id: int | None = None, - entities: list[dict] | None = None, - parse_mode: str | None = None, - reply_markup: dict[str, Any] | None = None, - *, - replace_message_id: int | None = None, - ) -> Message | None: ... - - async def send_document( - self, - chat_id: int, - filename: str, - content: bytes, - reply_to_message_id: int | None = None, - message_thread_id: int | None = None, - disable_notification: bool | None = False, - caption: str | None = None, - ) -> Message | None: ... - - async def edit_message_text( - self, - chat_id: int, - message_id: int, - text: str, - entities: list[dict] | None = None, - parse_mode: str | None = None, - reply_markup: dict[str, Any] | None = None, - *, - wait: bool = True, - ) -> Message | None: ... - - async def delete_message( - self, - chat_id: int, - message_id: int, - ) -> bool: ... - - async def set_my_commands( - self, - commands: list[dict[str, Any]], - *, - scope: dict[str, Any] | None = None, - language_code: str | None = None, - ) -> bool: ... - - async def get_me(self) -> User | None: ... - - async def answer_callback_query( - self, - callback_query_id: str, - text: str | None = None, - show_alert: bool | None = None, - ) -> bool: ... - - async def get_chat(self, chat_id: int) -> Chat | None: ... - - async def get_chat_member( - self, chat_id: int, user_id: int - ) -> ChatMember | None: ... - - async def create_forum_topic( - self, - chat_id: int, - name: str, - ) -> ForumTopic | None: ... - - async def edit_forum_topic( - self, - chat_id: int, - message_thread_id: int, - name: str, - ) -> bool: ... - - -if TYPE_CHECKING: - from anyio.abc import TaskGroup -else: - TaskGroup = object - - -@dataclass(slots=True) -class OutboxOp: - execute: Callable[[], Awaitable[Any]] - priority: int - queued_at: float - chat_id: int | None - label: str | None = None - done: anyio.Event = field(default_factory=anyio.Event) - result: Any = None - - def set_result(self, result: Any) -> None: - if self.done.is_set(): - return - self.result = result - self.done.set() - - -class TelegramOutbox: - def __init__( - self, - *, - interval_for_chat: Callable[[int | None], float], - clock: Callable[[], float] = time.monotonic, - sleep: Callable[[float], Awaitable[None]] = anyio.sleep, - on_error: Callable[[OutboxOp, Exception], None] | None = None, - on_outbox_error: Callable[[Exception], None] | None = None, - ) -> None: - self._interval_for_chat = interval_for_chat - self._clock = clock - self._sleep = sleep - self._on_error = on_error - self._on_outbox_error = on_outbox_error - self._pending: dict[Hashable, OutboxOp] = {} - self._cond = anyio.Condition() - self._start_lock = anyio.Lock() - self._closed = False - self._tg: TaskGroup | None = None - self.next_at = 0.0 - self.retry_at = 0.0 - - async def ensure_worker(self) -> None: - async with self._start_lock: - if self._tg is not None or self._closed: - return - self._tg = await anyio.create_task_group().__aenter__() - self._tg.start_soon(self.run) - - async def enqueue(self, *, key: Hashable, op: OutboxOp, wait: bool = True) -> Any: - await self.ensure_worker() - async with self._cond: - if self._closed: - op.set_result(None) - return op.result - previous = self._pending.get(key) - if previous is not None: - op.queued_at = previous.queued_at - previous.set_result(None) - self._pending[key] = op - self._cond.notify() - if not wait: - return None - await op.done.wait() - return op.result - - async def drop_pending(self, *, key: Hashable) -> None: - async with self._cond: - pending = self._pending.pop(key, None) - if pending is not None: - pending.set_result(None) - self._cond.notify() - - async def close(self) -> None: - async with self._cond: - self._closed = True - self.fail_pending() - self._cond.notify_all() - if self._tg is not None: - await self._tg.__aexit__(None, None, None) - self._tg = None - - def fail_pending(self) -> None: - for pending in list(self._pending.values()): - pending.set_result(None) - self._pending.clear() - - def pick_locked(self) -> tuple[Hashable, OutboxOp] | None: - if not self._pending: - return None - return min( - self._pending.items(), - key=lambda item: (item[1].priority, item[1].queued_at), - ) - - async def execute_op(self, op: OutboxOp) -> Any: - try: - return await op.execute() - except Exception as exc: # noqa: BLE001 - if isinstance(exc, RetryAfter): - raise - if self._on_error is not None: - self._on_error(op, exc) - return None - - async def sleep_until(self, deadline: float) -> None: - delay = deadline - self._clock() - if delay > 0: - await self._sleep(delay) - - async def run(self) -> None: - cancel_exc = anyio.get_cancelled_exc_class() - try: - while True: - async with self._cond: - while not self._pending and not self._closed: - await self._cond.wait() - if self._closed and not self._pending: - return - blocked_until = max(self.next_at, self.retry_at) - if self._clock() < blocked_until: - await self.sleep_until(blocked_until) - continue - async with self._cond: - if self._closed and not self._pending: - return - picked = self.pick_locked() - if picked is None: - continue - key, op = picked - self._pending.pop(key, None) - started_at = self._clock() - try: - result = await self.execute_op(op) - except RetryAfter as exc: - self.retry_at = max(self.retry_at, self._clock() + exc.retry_after) - async with self._cond: - if self._closed: - op.set_result(None) - elif key not in self._pending: - self._pending[key] = op - self._cond.notify() - else: - op.set_result(None) - continue - self.next_at = started_at + self._interval_for_chat(op.chat_id) - op.set_result(result) - except cancel_exc: - return - except Exception as exc: # noqa: BLE001 - async with self._cond: - self._closed = True - self.fail_pending() - self._cond.notify_all() - if self._on_outbox_error is not None: - self._on_outbox_error(exc) - return - - -def retry_after_from_payload(payload: dict[str, Any]) -> float | None: - params = payload.get("parameters") - if isinstance(params, dict): - retry_after = params.get("retry_after") - if isinstance(retry_after, (int, float)): - return float(retry_after) - return None - - class TelegramClient: def __init__( self, @@ -601,19 +52,15 @@ class TelegramClient: if client is not None: if token is not None or http_client is not None: raise ValueError("Provide either token or client, not both.") - self._client_override = client - self._base = None - self._file_base = None - self._http_client = None - self._owns_http_client = False + self._client = client else: if token is None or not token: raise ValueError("Telegram token is empty") - self._client_override = None - self._base = f"https://api.telegram.org/bot{token}" - self._file_base = f"https://api.telegram.org/file/bot{token}" - self._http_client = http_client or httpx.AsyncClient(timeout=timeout_s) - self._owns_http_client = http_client is None + self._client = HttpBotClient( + token, + timeout_s=timeout_s, + http_client=http_client, + ) self._clock = clock self._sleep = sleep self._private_interval = ( @@ -678,173 +125,18 @@ class TelegramClient: async def close(self) -> None: await self._outbox.close() - if self._client_override is not None: - await self._client_override.close() - return - if self._owns_http_client and self._http_client is not None: - await self._http_client.aclose() - - def _parse_telegram_envelope( - self, - *, - method: str, - resp: httpx.Response, - payload: Any, - ) -> Any | None: - if not isinstance(payload, dict): - logger.error( - "telegram.invalid_payload", - method=method, - url=str(resp.request.url), - payload=payload, - ) - return None - - if not payload.get("ok"): - if payload.get("error_code") == 429: - retry_after = retry_after_from_payload(payload) - retry_after = 5.0 if retry_after is None else retry_after - logger.warning( - "telegram.rate_limited", - method=method, - url=str(resp.request.url), - retry_after=retry_after, - ) - raise TelegramRetryAfter(retry_after) - logger.error( - "telegram.api_error", - method=method, - url=str(resp.request.url), - payload=payload, - ) - return None - - logger.debug("telegram.response", method=method, payload=payload) - return payload.get("result") - - async def _request( - self, - method: str, - *, - json: dict[str, Any] | None = None, - data: dict[str, Any] | None = None, - files: dict[str, Any] | None = None, - ) -> Any | None: - if self._http_client is None or self._base is None: - raise RuntimeError("TelegramClient is configured without an HTTP client.") - request_payload = json if json is not None else data - logger.debug("telegram.request", method=method, payload=request_payload) - try: - if json is not None: - resp = await self._http_client.post(f"{self._base}/{method}", json=json) - else: - resp = await self._http_client.post( - f"{self._base}/{method}", data=data, files=files - ) - except httpx.HTTPError as exc: - url = getattr(exc.request, "url", None) - logger.error( - "telegram.network_error", - method=method, - url=str(url) if url is not None else None, - error=str(exc), - error_type=exc.__class__.__name__, - ) - return None - - try: - resp.raise_for_status() - except httpx.HTTPStatusError as exc: - if resp.status_code == 429: - retry_after: float | None = None - try: - response_payload = resp.json() - except Exception: # noqa: BLE001 - response_payload = None - if isinstance(response_payload, dict): - retry_after = retry_after_from_payload(response_payload) - retry_after = 5.0 if retry_after is None else retry_after - logger.warning( - "telegram.rate_limited", - method=method, - status=resp.status_code, - url=str(resp.request.url), - retry_after=retry_after, - ) - raise TelegramRetryAfter(retry_after) from exc - body = resp.text - logger.error( - "telegram.http_error", - method=method, - status=resp.status_code, - url=str(resp.request.url), - error=str(exc), - body=body, - ) - return None - - try: - response_payload = resp.json() - except Exception as exc: # noqa: BLE001 - body = resp.text - logger.error( - "telegram.bad_response", - method=method, - status=resp.status_code, - url=str(resp.request.url), - error=str(exc), - error_type=exc.__class__.__name__, - body=body, - ) - return None - - return self._parse_telegram_envelope( - method=method, - resp=resp, - payload=response_payload, - ) - - def _decode_result( - self, - *, - method: str, - payload: Any, - model: type[T], - ) -> T | None: - if payload is None: - return None - try: - return msgspec.convert(payload, type=model) - except Exception as exc: # noqa: BLE001 - logger.error( - "telegram.decode_error", - method=method, - error=str(exc), - error_type=exc.__class__.__name__, - ) - return None + await self._client.close() async def _call_with_retry_after( self, - fn: Callable[[], Awaitable[T]], - ) -> T: + fn: Callable[[], Awaitable[Any]], + ) -> Any: while True: try: return await fn() except TelegramRetryAfter as exc: await self._sleep(exc.retry_after) - async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None: - return await self._request(method, json=json_data) - - async def _post_form( - self, - method: str, - data: dict[str, Any], - files: dict[str, Any], - ) -> Any | None: - return await self._request(method, data=data, files=files) - async def get_updates( self, offset: int | None, @@ -852,105 +144,23 @@ class TelegramClient: allowed_updates: list[str] | None = None, ) -> list[Update] | None: async def execute() -> list[Update] | None: - if self._client_override is not None: - raw = await self._client_override.get_updates( - offset=offset, - timeout_s=timeout_s, - allowed_updates=allowed_updates, - ) - if raw is None: - return None - try: - return msgspec.convert(raw, type=list[Update]) - except Exception as exc: # noqa: BLE001 - logger.error( - "telegram.decode_error", - method="getUpdates", - error=str(exc), - error_type=exc.__class__.__name__, - ) - return None - - params: dict[str, Any] = {"timeout": timeout_s} - if offset is not None: - params["offset"] = offset - if allowed_updates is not None: - params["allowed_updates"] = allowed_updates - result = await self._post("getUpdates", params) - if result is None or not isinstance(result, list): - return None - try: - return msgspec.convert(result, type=list[Update]) - except Exception as exc: # noqa: BLE001 - logger.error( - "telegram.decode_error", - method="getUpdates", - error=str(exc), - error_type=exc.__class__.__name__, - ) - return None + return await self._client.get_updates( + offset=offset, + timeout_s=timeout_s, + allowed_updates=allowed_updates, + ) return await self._call_with_retry_after(execute) async def get_file(self, file_id: str) -> File | None: async def execute() -> File | None: - if self._client_override is not None: - return await self._client_override.get_file(file_id) - result = await self._post("getFile", {"file_id": file_id}) - return self._decode_result(method="getFile", payload=result, model=File) + return await self._client.get_file(file_id) return await self._call_with_retry_after(execute) async def download_file(self, file_path: str) -> bytes | None: async def execute() -> bytes | None: - if self._client_override is not None: - return await self._client_override.download_file(file_path) - if self._http_client is None or self._file_base is None: - raise RuntimeError( - "TelegramClient is configured without an HTTP client." - ) - url = f"{self._file_base}/{file_path}" - try: - resp = await self._http_client.get(url) - except httpx.HTTPError as exc: - request_url = getattr(exc.request, "url", None) - logger.error( - "telegram.file_network_error", - url=str(request_url) if request_url is not None else None, - error=str(exc), - error_type=exc.__class__.__name__, - ) - return None - try: - resp.raise_for_status() - except httpx.HTTPStatusError as exc: - if resp.status_code == 429: - retry_after: float | None = None - try: - response_payload = resp.json() - except Exception: # noqa: BLE001 - response_payload = None - if isinstance(response_payload, dict): - retry_after = retry_after_from_payload(response_payload) - retry_after = 5.0 if retry_after is None else retry_after - logger.warning( - "telegram.rate_limited", - method="download_file", - status=resp.status_code, - url=str(resp.request.url), - retry_after=retry_after, - ) - raise TelegramRetryAfter(retry_after) from exc - - logger.error( - "telegram.file_http_error", - status=resp.status_code, - url=str(resp.request.url), - error=str(exc), - body=resp.text, - ) - return None - return resp.content + return await self._client.download_file(file_path) return await self._call_with_retry_after(execute) @@ -968,36 +178,16 @@ class TelegramClient: replace_message_id: int | None = None, ) -> Message | None: async def execute() -> Message | None: - if self._client_override is not None: - return await self._client_override.send_message( - chat_id=chat_id, - text=text, - reply_to_message_id=reply_to_message_id, - disable_notification=disable_notification, - message_thread_id=message_thread_id, - entities=entities, - parse_mode=parse_mode, - reply_markup=reply_markup, - replace_message_id=replace_message_id, - ) - params: dict[str, Any] = {"chat_id": chat_id, "text": text} - if disable_notification is not None: - params["disable_notification"] = disable_notification - if reply_to_message_id is not None: - params["reply_to_message_id"] = reply_to_message_id - if message_thread_id is not None: - params["message_thread_id"] = message_thread_id - if entities is not None: - params["entities"] = entities - if parse_mode is not None: - params["parse_mode"] = parse_mode - if reply_markup is not None: - params["reply_markup"] = reply_markup - result = await self._post("sendMessage", params) - return self._decode_result( - method="sendMessage", - payload=result, - model=Message, + return await self._client.send_message( + chat_id=chat_id, + text=text, + reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, + message_thread_id=message_thread_id, + entities=entities, + parse_mode=parse_mode, + reply_markup=reply_markup, + replace_message_id=replace_message_id, ) if replace_message_id is not None: @@ -1028,34 +218,14 @@ class TelegramClient: caption: str | None = None, ) -> Message | None: async def execute() -> Message | None: - if self._client_override is not None: - return await self._client_override.send_document( - chat_id=chat_id, - filename=filename, - content=content, - reply_to_message_id=reply_to_message_id, - message_thread_id=message_thread_id, - disable_notification=disable_notification, - caption=caption, - ) - params: dict[str, Any] = {"chat_id": chat_id} - if disable_notification is not None: - params["disable_notification"] = disable_notification - if reply_to_message_id is not None: - params["reply_to_message_id"] = reply_to_message_id - if message_thread_id is not None: - params["message_thread_id"] = message_thread_id - if caption is not None: - params["caption"] = caption - result = await self._post_form( - "sendDocument", - params, - files={"document": (filename, content)}, - ) - return self._decode_result( - method="sendDocument", - payload=result, - model=Message, + return await self._client.send_document( + chat_id=chat_id, + filename=filename, + content=content, + reply_to_message_id=reply_to_message_id, + message_thread_id=message_thread_id, + disable_notification=disable_notification, + caption=caption, ) return await self.enqueue_op( @@ -1078,32 +248,14 @@ class TelegramClient: wait: bool = True, ) -> Message | None: async def execute() -> Message | None: - if self._client_override is not None: - return await self._client_override.edit_message_text( - chat_id=chat_id, - message_id=message_id, - text=text, - entities=entities, - parse_mode=parse_mode, - reply_markup=reply_markup, - wait=wait, - ) - params: dict[str, Any] = { - "chat_id": chat_id, - "message_id": message_id, - "text": text, - } - if entities is not None: - params["entities"] = entities - if parse_mode is not None: - params["parse_mode"] = parse_mode - if reply_markup is not None: - params["reply_markup"] = reply_markup - result = await self._post("editMessageText", params) - return self._decode_result( - method="editMessageText", - payload=result, - model=Message, + return await self._client.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text=text, + entities=entities, + parse_mode=parse_mode, + reply_markup=reply_markup, + wait=wait, ) return await self.enqueue_op( @@ -1123,16 +275,10 @@ class TelegramClient: await self.drop_pending_edits(chat_id=chat_id, message_id=message_id) async def execute() -> bool: - if self._client_override is not None: - return await self._client_override.delete_message( - chat_id=chat_id, - message_id=message_id, - ) - result = await self._post( - "deleteMessage", - {"chat_id": chat_id, "message_id": message_id}, + return await self._client.delete_message( + chat_id=chat_id, + message_id=message_id, ) - return bool(result) return bool( await self.enqueue_op( @@ -1152,19 +298,11 @@ class TelegramClient: language_code: str | None = None, ) -> bool: async def execute() -> bool: - if self._client_override is not None: - return await self._client_override.set_my_commands( - commands, - scope=scope, - language_code=language_code, - ) - params: dict[str, Any] = {"commands": commands} - if scope is not None: - params["scope"] = scope - if language_code is not None: - params["language_code"] = language_code - result = await self._post("setMyCommands", params) - return bool(result) + return await self._client.set_my_commands( + commands, + scope=scope, + language_code=language_code, + ) return bool( await self.enqueue_op( @@ -1178,10 +316,7 @@ class TelegramClient: async def get_me(self) -> User | None: async def execute() -> User | None: - if self._client_override is not None: - return await self._client_override.get_me() - result = await self._post("getMe", {}) - return self._decode_result(method="getMe", payload=result, model=User) + return await self._client.get_me() return await self.enqueue_op( key=self.unique_key("get_me"), @@ -1198,19 +333,11 @@ class TelegramClient: show_alert: bool | None = None, ) -> bool: async def execute() -> bool: - if self._client_override is not None: - return await self._client_override.answer_callback_query( - callback_query_id=callback_query_id, - text=text, - show_alert=show_alert, - ) - params: dict[str, Any] = {"callback_query_id": callback_query_id} - if text is not None: - params["text"] = text - if show_alert is not None: - params["show_alert"] = show_alert - result = await self._post("answerCallbackQuery", params) - return bool(result) + return await self._client.answer_callback_query( + callback_query_id=callback_query_id, + text=text, + show_alert=show_alert, + ) return bool( await self.enqueue_op( @@ -1224,10 +351,7 @@ class TelegramClient: async def get_chat(self, chat_id: int) -> Chat | None: async def execute() -> Chat | None: - if self._client_override is not None: - return await self._client_override.get_chat(chat_id) - result = await self._post("getChat", {"chat_id": chat_id}) - return self._decode_result(method="getChat", payload=result, model=Chat) + return await self._client.get_chat(chat_id) return await self.enqueue_op( key=self.unique_key("get_chat"), @@ -1239,16 +363,7 @@ class TelegramClient: async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None: async def execute() -> ChatMember | None: - if self._client_override is not None: - return await self._client_override.get_chat_member(chat_id, user_id) - result = await self._post( - "getChatMember", {"chat_id": chat_id, "user_id": user_id} - ) - return self._decode_result( - method="getChatMember", - payload=result, - model=ChatMember, - ) + return await self._client.get_chat_member(chat_id, user_id) return await self.enqueue_op( key=self.unique_key("get_chat_member"), @@ -1260,16 +375,7 @@ class TelegramClient: async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None: async def execute() -> ForumTopic | None: - if self._client_override is not None: - return await self._client_override.create_forum_topic(chat_id, name) - result = await self._post( - "createForumTopic", {"chat_id": chat_id, "name": name} - ) - return self._decode_result( - method="createForumTopic", - payload=result, - model=ForumTopic, - ) + return await self._client.create_forum_topic(chat_id, name) return await self.enqueue_op( key=self.unique_key("create_forum_topic"), @@ -1280,22 +386,17 @@ class TelegramClient: ) async def edit_forum_topic( - self, chat_id: int, message_thread_id: int, name: str + self, + chat_id: int, + message_thread_id: int, + name: str, ) -> bool: async def execute() -> bool: - if self._client_override is not None: - return await self._client_override.edit_forum_topic( - chat_id, message_thread_id, name - ) - result = await self._post( - "editForumTopic", - { - "chat_id": chat_id, - "message_thread_id": message_thread_id, - "name": name, - }, + return await self._client.edit_forum_topic( + chat_id, + message_thread_id, + name, ) - return bool(result) return bool( await self.enqueue_op( diff --git a/src/takopi/telegram/client_api.py b/src/takopi/telegram/client_api.py new file mode 100644 index 0000000..392ef1c --- /dev/null +++ b/src/takopi/telegram/client_api.py @@ -0,0 +1,537 @@ +from __future__ import annotations + +from typing import Any, Protocol, TypeVar + +import httpx +import msgspec + +from ..logging import get_logger +from .api_models import Chat, ChatMember, File, ForumTopic, Message, Update, User + +logger = get_logger(__name__) + +T = TypeVar("T") + + +class RetryAfter(Exception): + def __init__(self, retry_after: float, description: str | None = None) -> None: + super().__init__(description or f"retry after {retry_after}") + self.retry_after = float(retry_after) + self.description = description + + +class TelegramRetryAfter(RetryAfter): + pass + + +def retry_after_from_payload(payload: dict[str, Any]) -> float | None: + params = payload.get("parameters") + if isinstance(params, dict): + retry_after = params.get("retry_after") + if isinstance(retry_after, (int, float)): + return float(retry_after) + return None + + +class BotClient(Protocol): + async def close(self) -> None: ... + + async def get_updates( + self, + offset: int | None, + timeout_s: int = 50, + allowed_updates: list[str] | None = None, + ) -> list[Update] | None: ... + + async def get_file(self, file_id: str) -> File | None: ... + + async def download_file(self, file_path: str) -> bytes | None: ... + + async def send_message( + self, + chat_id: int, + text: str, + reply_to_message_id: int | None = None, + disable_notification: bool | None = False, + message_thread_id: int | None = None, + entities: list[dict] | None = None, + parse_mode: str | None = None, + reply_markup: dict[str, Any] | None = None, + *, + replace_message_id: int | None = None, + ) -> Message | None: ... + + async def send_document( + self, + chat_id: int, + filename: str, + content: bytes, + reply_to_message_id: int | None = None, + message_thread_id: int | None = None, + disable_notification: bool | None = False, + caption: str | None = None, + ) -> Message | None: ... + + async def edit_message_text( + self, + chat_id: int, + message_id: int, + text: str, + entities: list[dict] | None = None, + parse_mode: str | None = None, + reply_markup: dict[str, Any] | None = None, + *, + wait: bool = True, + ) -> Message | None: ... + + async def delete_message( + self, + chat_id: int, + message_id: int, + ) -> bool: ... + + async def set_my_commands( + self, + commands: list[dict[str, Any]], + *, + scope: dict[str, Any] | None = None, + language_code: str | None = None, + ) -> bool: ... + + async def get_me(self) -> User | None: ... + + async def answer_callback_query( + self, + callback_query_id: str, + text: str | None = None, + show_alert: bool | None = None, + ) -> bool: ... + + async def get_chat(self, chat_id: int) -> Chat | None: ... + + async def get_chat_member( + self, chat_id: int, user_id: int + ) -> ChatMember | None: ... + + async def create_forum_topic( + self, + chat_id: int, + name: str, + ) -> ForumTopic | None: ... + + async def edit_forum_topic( + self, + chat_id: int, + message_thread_id: int, + name: str, + ) -> bool: ... + + +class HttpBotClient: + def __init__( + self, + token: str, + *, + timeout_s: float = 120, + http_client: httpx.AsyncClient | None = None, + ) -> None: + if not token: + raise ValueError("Telegram token is empty") + self._base = f"https://api.telegram.org/bot{token}" + self._file_base = f"https://api.telegram.org/file/bot{token}" + self._http_client = http_client or httpx.AsyncClient(timeout=timeout_s) + self._owns_http_client = http_client is None + + async def close(self) -> None: + if self._owns_http_client: + await self._http_client.aclose() + + def _parse_telegram_envelope( + self, + *, + method: str, + resp: httpx.Response, + payload: Any, + ) -> Any | None: + if not isinstance(payload, dict): + logger.error( + "telegram.invalid_payload", + method=method, + url=str(resp.request.url), + payload=payload, + ) + return None + + if not payload.get("ok"): + if payload.get("error_code") == 429: + retry_after = retry_after_from_payload(payload) + retry_after = 5.0 if retry_after is None else retry_after + logger.warning( + "telegram.rate_limited", + method=method, + url=str(resp.request.url), + retry_after=retry_after, + ) + raise TelegramRetryAfter(retry_after) + logger.error( + "telegram.api_error", + method=method, + url=str(resp.request.url), + payload=payload, + ) + return None + + logger.debug("telegram.response", method=method, payload=payload) + return payload.get("result") + + async def _request( + self, + method: str, + *, + json: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, + files: dict[str, Any] | None = None, + ) -> Any | None: + request_payload = json if json is not None else data + logger.debug("telegram.request", method=method, payload=request_payload) + try: + if json is not None: + resp = await self._http_client.post(f"{self._base}/{method}", json=json) + else: + resp = await self._http_client.post( + f"{self._base}/{method}", data=data, files=files + ) + except httpx.HTTPError as exc: + url = getattr(exc.request, "url", None) + logger.error( + "telegram.network_error", + method=method, + url=str(url) if url is not None else None, + error=str(exc), + error_type=exc.__class__.__name__, + ) + return None + + try: + resp.raise_for_status() + except httpx.HTTPStatusError as exc: + if resp.status_code == 429: + retry_after: float | None = None + try: + response_payload = resp.json() + except Exception: # noqa: BLE001 + response_payload = None + if isinstance(response_payload, dict): + retry_after = retry_after_from_payload(response_payload) + retry_after = 5.0 if retry_after is None else retry_after + logger.warning( + "telegram.rate_limited", + method=method, + status=resp.status_code, + url=str(resp.request.url), + retry_after=retry_after, + ) + raise TelegramRetryAfter(retry_after) from exc + body = resp.text + logger.error( + "telegram.http_error", + method=method, + status=resp.status_code, + url=str(resp.request.url), + error=str(exc), + body=body, + ) + return None + + try: + response_payload = resp.json() + except Exception as exc: # noqa: BLE001 + body = resp.text + logger.error( + "telegram.bad_response", + method=method, + status=resp.status_code, + url=str(resp.request.url), + error=str(exc), + error_type=exc.__class__.__name__, + body=body, + ) + return None + + return self._parse_telegram_envelope( + method=method, + resp=resp, + payload=response_payload, + ) + + def _decode_result( + self, + *, + method: str, + payload: Any, + model: type[T], + ) -> T | None: + if payload is None: + return None + try: + return msgspec.convert(payload, type=model) + except Exception as exc: # noqa: BLE001 + logger.error( + "telegram.decode_error", + method=method, + error=str(exc), + error_type=exc.__class__.__name__, + ) + return None + + async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None: + return await self._request(method, json=json_data) + + async def _post_form( + self, + method: str, + data: dict[str, Any], + files: dict[str, Any], + ) -> Any | None: + return await self._request(method, data=data, files=files) + + async def get_updates( + self, + offset: int | None, + timeout_s: int = 50, + allowed_updates: list[str] | None = None, + ) -> list[Update] | None: + params: dict[str, Any] = {"timeout": timeout_s} + if offset is not None: + params["offset"] = offset + if allowed_updates is not None: + params["allowed_updates"] = allowed_updates + result = await self._post("getUpdates", params) + if result is None or not isinstance(result, list): + return None + try: + return msgspec.convert(result, type=list[Update]) + except Exception as exc: # noqa: BLE001 + logger.error( + "telegram.decode_error", + method="getUpdates", + error=str(exc), + error_type=exc.__class__.__name__, + ) + return None + + async def get_file(self, file_id: str) -> File | None: + result = await self._post("getFile", {"file_id": file_id}) + return self._decode_result(method="getFile", payload=result, model=File) + + async def download_file(self, file_path: str) -> bytes | None: + url = f"{self._file_base}/{file_path}" + try: + resp = await self._http_client.get(url) + except httpx.HTTPError as exc: + request_url = getattr(exc.request, "url", None) + logger.error( + "telegram.file_network_error", + url=str(request_url) if request_url is not None else None, + error=str(exc), + error_type=exc.__class__.__name__, + ) + return None + try: + resp.raise_for_status() + except httpx.HTTPStatusError as exc: + if resp.status_code == 429: + retry_after: float | None = None + try: + response_payload = resp.json() + except Exception: # noqa: BLE001 + response_payload = None + if isinstance(response_payload, dict): + retry_after = retry_after_from_payload(response_payload) + retry_after = 5.0 if retry_after is None else retry_after + logger.warning( + "telegram.rate_limited", + method="download_file", + status=resp.status_code, + url=str(resp.request.url), + retry_after=retry_after, + ) + raise TelegramRetryAfter(retry_after) from exc + + logger.error( + "telegram.file_http_error", + status=resp.status_code, + url=str(resp.request.url), + error=str(exc), + body=resp.text, + ) + return None + return resp.content + + async def send_message( + self, + chat_id: int, + text: str, + reply_to_message_id: int | None = None, + disable_notification: bool | None = False, + message_thread_id: int | None = None, + entities: list[dict] | None = None, + parse_mode: str | None = None, + reply_markup: dict[str, Any] | None = None, + *, + replace_message_id: int | None = None, + ) -> Message | None: + params: dict[str, Any] = {"chat_id": chat_id, "text": text} + if disable_notification is not None: + params["disable_notification"] = disable_notification + if reply_to_message_id is not None: + params["reply_to_message_id"] = reply_to_message_id + if message_thread_id is not None: + params["message_thread_id"] = message_thread_id + if entities is not None: + params["entities"] = entities + if parse_mode is not None: + params["parse_mode"] = parse_mode + if reply_markup is not None: + params["reply_markup"] = reply_markup + result = await self._post("sendMessage", params) + return self._decode_result(method="sendMessage", payload=result, model=Message) + + async def send_document( + self, + chat_id: int, + filename: str, + content: bytes, + reply_to_message_id: int | None = None, + message_thread_id: int | None = None, + disable_notification: bool | None = False, + caption: str | None = None, + ) -> Message | None: + params: dict[str, Any] = {"chat_id": chat_id} + if disable_notification is not None: + params["disable_notification"] = disable_notification + if reply_to_message_id is not None: + params["reply_to_message_id"] = reply_to_message_id + if message_thread_id is not None: + params["message_thread_id"] = message_thread_id + if caption is not None: + params["caption"] = caption + result = await self._post_form( + "sendDocument", + params, + files={"document": (filename, content)}, + ) + return self._decode_result(method="sendDocument", payload=result, model=Message) + + async def edit_message_text( + self, + chat_id: int, + message_id: int, + text: str, + entities: list[dict] | None = None, + parse_mode: str | None = None, + reply_markup: dict[str, Any] | None = None, + *, + wait: bool = True, + ) -> Message | None: + params: dict[str, Any] = { + "chat_id": chat_id, + "message_id": message_id, + "text": text, + } + if entities is not None: + params["entities"] = entities + if parse_mode is not None: + params["parse_mode"] = parse_mode + if reply_markup is not None: + params["reply_markup"] = reply_markup + result = await self._post("editMessageText", params) + return self._decode_result( + method="editMessageText", + payload=result, + model=Message, + ) + + async def delete_message( + self, + chat_id: int, + message_id: int, + ) -> bool: + result = await self._post( + "deleteMessage", + {"chat_id": chat_id, "message_id": message_id}, + ) + return bool(result) + + async def set_my_commands( + self, + commands: list[dict[str, Any]], + *, + scope: dict[str, Any] | None = None, + language_code: str | None = None, + ) -> bool: + params: dict[str, Any] = {"commands": commands} + if scope is not None: + params["scope"] = scope + if language_code is not None: + params["language_code"] = language_code + result = await self._post("setMyCommands", params) + return bool(result) + + async def get_me(self) -> User | None: + result = await self._post("getMe", {}) + return self._decode_result(method="getMe", payload=result, model=User) + + async def answer_callback_query( + self, + callback_query_id: str, + text: str | None = None, + show_alert: bool | None = None, + ) -> bool: + params: dict[str, Any] = {"callback_query_id": callback_query_id} + if text is not None: + params["text"] = text + if show_alert is not None: + params["show_alert"] = show_alert + result = await self._post("answerCallbackQuery", params) + return bool(result) + + async def get_chat(self, chat_id: int) -> Chat | None: + result = await self._post("getChat", {"chat_id": chat_id}) + return self._decode_result(method="getChat", payload=result, model=Chat) + + async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None: + result = await self._post( + "getChatMember", {"chat_id": chat_id, "user_id": user_id} + ) + return self._decode_result( + method="getChatMember", + payload=result, + model=ChatMember, + ) + + async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None: + result = await self._post( + "createForumTopic", {"chat_id": chat_id, "name": name} + ) + return self._decode_result( + method="createForumTopic", + payload=result, + model=ForumTopic, + ) + + async def edit_forum_topic( + self, + chat_id: int, + message_thread_id: int, + name: str, + ) -> bool: + result = await self._post( + "editForumTopic", + { + "chat_id": chat_id, + "message_thread_id": message_thread_id, + "name": name, + }, + ) + return bool(result) diff --git a/src/takopi/telegram/commands.py b/src/takopi/telegram/commands.py deleted file mode 100644 index 649c1ff..0000000 --- a/src/takopi/telegram/commands.py +++ /dev/null @@ -1,1744 +0,0 @@ -from __future__ import annotations - -from collections.abc import AsyncIterator, Awaitable, Callable, Sequence -from dataclasses import dataclass -from functools import partial -from pathlib import Path -from typing import cast - -import anyio - -from ..commands import ( - CommandContext, - CommandExecutor, - RunMode, - RunRequest, - RunResult, - get_command, -) -from ..context import RunContext -from ..config import ConfigError -from ..directives import DirectiveError -from ..ids import RESERVED_COMMAND_IDS, is_valid_id -from ..logging import bind_run_context, clear_context, get_logger -from ..markdown import MarkdownParts -from ..model import EngineId, ResumeToken, TakopiEvent -from ..plugins import COMMAND_GROUP, list_entrypoints -from ..progress import ProgressTracker -from ..router import RunnerUnavailableError -from ..runner import Runner -from ..runner_bridge import ( - ExecBridgeConfig, - IncomingMessage as RunnerIncomingMessage, - RunningTasks, - handle_message, -) -from ..scheduler import ThreadScheduler -from ..transport import MessageRef, RenderedMessage, SendOptions -from ..transport_runtime import ResolvedMessage, TransportRuntime -from ..utils.paths import reset_run_base_dir, set_run_base_dir -from .bridge import TelegramBridgeConfig, send_plain -from .chat_prefs import ChatPrefsStore -from .chat_sessions import ChatSessionStore -from .context import ( - _format_context, - _format_ctx_status, - _merge_topic_context, - _parse_project_branch_args, - _usage_ctx_set, - _usage_topic, -) -from .engine_defaults import resolve_engine_for_message -from .files import ( - default_upload_name, - default_upload_path, - deny_reason, - format_bytes, - normalize_relative_path, - parse_file_command, - parse_file_prompt, - resolve_path_within_root, - split_command_args, - write_bytes_atomic, - ZipTooLargeError, - zip_directory, -) -from .render import prepare_telegram -from .topic_state import TopicStateStore -from .topics import ( - _maybe_rename_topic, - _maybe_update_topic_context, - _topic_key, - _topic_title, - _topics_chat_project, - _topics_command_error, -) -from .types import TelegramCallbackQuery, TelegramDocument, TelegramIncomingMessage - -logger = get_logger(__name__) - -__all__ = [ - "FILE_GET_USAGE", - "FILE_PUT_USAGE", - "_dispatch_command", - "_handle_agent_command", - "_handle_chat_new_command", - "_handle_file_command", - "_handle_file_get", - "_handle_file_put", - "_handle_file_put_default", - "_handle_media_group", - "_parse_slash_command", - "_reserved_commands", - "_set_command_menu", - "build_bot_commands", - "handle_callback_cancel", - "handle_cancel", - "is_cancel_command", -] - -_MAX_BOT_COMMANDS = 100 -FILE_PUT_USAGE = "usage: `/file put `" -FILE_GET_USAGE = "usage: `/file get `" -AGENT_USAGE = "usage: `/agent`, `/agent set `, or `/agent clear`" - - -def is_cancel_command(text: str) -> bool: - stripped = text.strip() - if not stripped: - return False - command = stripped.split(maxsplit=1)[0] - return command == "/cancel" or command.startswith("/cancel@") - - -def _parse_slash_command(text: str) -> tuple[str | None, str]: - stripped = text.lstrip() - if not stripped.startswith("/"): - return None, text - lines = stripped.splitlines() - if not lines: - return None, text - first_line = lines[0] - token, _, rest = first_line.partition(" ") - command = token[1:] - if not command: - return None, text - if "@" in command: - command = command.split("@", 1)[0] - args_text = rest - if len(lines) > 1: - tail = "\n".join(lines[1:]) - args_text = f"{args_text}\n{tail}" if args_text else tail - return command.lower(), args_text - - -def build_bot_commands( - runtime: TransportRuntime, *, include_file: bool = True -) -> list[dict[str, str]]: - commands: list[dict[str, str]] = [] - seen: set[str] = set() - for engine_id in runtime.available_engine_ids(): - cmd = engine_id.lower() - if cmd in seen: - continue - commands.append({"command": cmd, "description": f"use agent: {cmd}"}) - seen.add(cmd) - for alias in runtime.project_aliases(): - cmd = alias.lower() - if cmd in seen: - continue - if not is_valid_id(cmd): - logger.debug( - "startup.command_menu.skip_project", - alias=alias, - ) - continue - commands.append({"command": cmd, "description": f"work on: {cmd}"}) - seen.add(cmd) - allowlist = runtime.allowlist - for ep in list_entrypoints( - COMMAND_GROUP, - allowlist=allowlist, - reserved_ids=RESERVED_COMMAND_IDS, - ): - try: - backend = get_command(ep.name, allowlist=allowlist) - except ConfigError as exc: - logger.info( - "startup.command_menu.skip_command", - command=ep.name, - error=str(exc), - ) - continue - cmd = backend.id.lower() - if cmd in seen: - continue - if not is_valid_id(cmd): - logger.debug( - "startup.command_menu.skip_command_id", - command=cmd, - ) - continue - description = backend.description or f"command: {cmd}" - commands.append({"command": cmd, "description": description}) - seen.add(cmd) - if include_file and "file" not in seen: - commands.append({"command": "file", "description": "upload or fetch files"}) - seen.add("file") - if "cancel" not in seen: - commands.append({"command": "cancel", "description": "cancel run"}) - if len(commands) > _MAX_BOT_COMMANDS: - logger.warning( - "startup.command_menu.too_many", - count=len(commands), - limit=_MAX_BOT_COMMANDS, - ) - commands = commands[:_MAX_BOT_COMMANDS] - if not any(cmd["command"] == "cancel" for cmd in commands): - commands[-1] = {"command": "cancel", "description": "cancel run"} - return commands - - -def _reserved_commands(runtime: TransportRuntime) -> set[str]: - return { - *{engine.lower() for engine in runtime.engine_ids}, - *{alias.lower() for alias in runtime.project_aliases()}, - *RESERVED_COMMAND_IDS, - } - - -def _reply_sender( - cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage -) -> Callable[..., Awaitable[None]]: - return partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) - - -async def _set_command_menu(cfg: TelegramBridgeConfig) -> None: - commands = build_bot_commands(cfg.runtime, include_file=cfg.files.enabled) - if not commands: - return - try: - ok = await cfg.bot.set_my_commands(commands) - except Exception as exc: # noqa: BLE001 - logger.info( - "startup.command_menu.failed", - error=str(exc), - error_type=exc.__class__.__name__, - ) - return - if not ok: - logger.info("startup.command_menu.rejected") - return - logger.info( - "startup.command_menu.updated", - commands=[cmd["command"] for cmd in commands], - ) - - -@dataclass(slots=True) -class _FilePutPlan: - resolved: ResolvedMessage - run_root: Path - path_value: str | None - force: bool - - -@dataclass(slots=True) -class _FilePutResult: - name: str - rel_path: Path | None - size: int | None - error: str | None - - -@dataclass(slots=True) -class _SavedFilePut: - context: RunContext | None - rel_path: Path - size: int - - -@dataclass(slots=True) -class _SavedFilePutGroup: - context: RunContext | None - base_dir: Path | None - saved: list[_FilePutResult] - failed: list[_FilePutResult] - - -@dataclass(slots=True) -class _ResumeLineProxy: - runner: Runner - - @property - def engine(self) -> str: - return self.runner.engine - - def is_resume_line(self, line: str) -> bool: - return self.runner.is_resume_line(line) - - def format_resume(self, _: ResumeToken) -> str: - return "" - - def extract_resume(self, text: str | None) -> ResumeToken | None: - return self.runner.extract_resume(text) - - def run( - self, prompt: str, resume: ResumeToken | None - ) -> AsyncIterator[TakopiEvent]: - return self.runner.run(prompt, resume) - - -def _should_show_resume_line( - *, - show_resume_line: bool, - stateful_mode: bool, - context: RunContext | None, -) -> bool: - if show_resume_line: - return True - if not stateful_mode: - return True - if context is None or context.project is None: - return True - return False - - -def resolve_file_put_paths( - plan: _FilePutPlan, - *, - cfg: TelegramBridgeConfig, - require_dir: bool, -) -> tuple[Path | None, Path | None, str | None]: - path_value = plan.path_value - if not path_value: - return None, None, None - if require_dir or path_value.endswith("/"): - base_dir = normalize_relative_path(path_value) - if base_dir is None: - return None, None, "invalid upload path." - deny_rule = deny_reason(base_dir, cfg.files.deny_globs) - if deny_rule is not None: - return None, None, f"path denied by rule: {deny_rule}" - base_target = resolve_path_within_root(plan.run_root, base_dir) - if base_target is None: - return None, None, "upload path escapes the repo root." - if base_target.exists() and not base_target.is_dir(): - return None, None, "upload path is a file." - return base_dir, None, None - rel_path = normalize_relative_path(path_value) - if rel_path is None: - return None, None, "invalid upload path." - return None, rel_path, None - - -async def _check_file_permissions( - cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage -) -> bool: - reply = _reply_sender(cfg, msg) - sender_id = msg.sender_id - if sender_id is None: - await reply(text="cannot verify sender for file transfer.") - return False - if cfg.files.allowed_user_ids: - if sender_id not in cfg.files.allowed_user_ids: - await reply(text="file transfer is not allowed for this user.") - return False - return True - is_private = msg.chat_type == "private" - if msg.chat_type is None: - is_private = msg.chat_id > 0 - if is_private: - return True - member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) - if member is None: - await reply(text="failed to verify file transfer permissions.") - return False - if member.status in {"creator", "administrator"}: - return True - await reply(text="file transfer is restricted to group admins.") - return False - - -async def _check_agent_permissions( - cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage -) -> bool: - reply = _reply_sender(cfg, msg) - sender_id = msg.sender_id - if sender_id is None: - await reply(text="cannot verify sender for agent defaults.") - return False - is_private = msg.chat_type == "private" - if msg.chat_type is None: - is_private = msg.chat_id > 0 - if is_private: - return True - member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) - if member is None: - await reply(text="failed to verify agent permissions.") - return False - if member.status in {"creator", "administrator"}: - return True - await reply(text="changing default agents is restricted to group admins.") - return False - - -async def _prepare_file_put_plan( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - args_text: str, - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, -) -> _FilePutPlan | None: - reply = _reply_sender(cfg, msg) - if not await _check_file_permissions(cfg, msg): - return None - try: - resolved = cfg.runtime.resolve_message( - text=args_text, - reply_text=msg.reply_to_text, - ambient_context=ambient_context, - chat_id=msg.chat_id, - ) - except DirectiveError as exc: - await reply(text=f"error:\n{exc}") - return None - topic_key = _topic_key(msg, cfg) if topic_store is not None else None - await _maybe_update_topic_context( - cfg=cfg, - topic_store=topic_store, - topic_key=topic_key, - context=resolved.context, - context_source=resolved.context_source, - ) - if resolved.context is None or resolved.context.project is None: - await reply(text="no project context available for file upload.") - return None - try: - run_root = cfg.runtime.resolve_run_cwd(resolved.context) - except ConfigError as exc: - await reply(text=f"error:\n{exc}") - return None - if run_root is None: - await reply(text="no project context available for file upload.") - return None - path_value, force, error = parse_file_prompt(resolved.prompt, allow_empty=True) - if error is not None: - await reply(text=error) - return None - return _FilePutPlan( - resolved=resolved, - run_root=run_root, - path_value=path_value, - force=force, - ) - - -def _format_file_put_failures(failed: Sequence[_FilePutResult]) -> str | None: - if not failed: - return None - errors = ", ".join( - f"`{item.name}` ({item.error})" for item in failed if item.error is not None - ) - if not errors: - return None - return f"failed: {errors}" - - -async def _save_document_payload( - cfg: TelegramBridgeConfig, - *, - document: TelegramDocument, - run_root: Path, - rel_path: Path | None, - base_dir: Path | None, - force: bool, -) -> _FilePutResult: - name = default_upload_name(document.file_name, None) - if ( - document.file_size is not None - and document.file_size > cfg.files.max_upload_bytes - ): - return _FilePutResult( - name=name, - rel_path=None, - size=None, - error="file is too large to upload.", - ) - file_info = await cfg.bot.get_file(document.file_id) - if file_info is None: - return _FilePutResult( - name=name, - rel_path=None, - size=None, - error="failed to fetch file metadata.", - ) - file_path = file_info.file_path - name = default_upload_name(document.file_name, file_path) - resolved_path = rel_path - if resolved_path is None: - if base_dir is None: - resolved_path = default_upload_path( - cfg.files.uploads_dir, document.file_name, file_path - ) - else: - resolved_path = base_dir / name - deny_rule = deny_reason(resolved_path, cfg.files.deny_globs) - if deny_rule is not None: - return _FilePutResult( - name=name, - rel_path=None, - size=None, - error=f"path denied by rule: {deny_rule}", - ) - target = resolve_path_within_root(run_root, resolved_path) - if target is None: - return _FilePutResult( - name=name, - rel_path=None, - size=None, - error="upload path escapes the repo root.", - ) - if target.exists(): - if target.is_dir(): - return _FilePutResult( - name=name, - rel_path=None, - size=None, - error="upload target is a directory.", - ) - if not force: - return _FilePutResult( - name=name, - rel_path=None, - size=None, - error="file already exists; use --force to overwrite.", - ) - payload = await cfg.bot.download_file(file_path) - if payload is None: - return _FilePutResult( - name=name, - rel_path=None, - size=None, - error="failed to download file.", - ) - if len(payload) > cfg.files.max_upload_bytes: - return _FilePutResult( - name=name, - rel_path=None, - size=None, - error="file is too large to upload.", - ) - try: - write_bytes_atomic(target, payload) - except OSError as exc: - return _FilePutResult( - name=name, - rel_path=None, - size=None, - error=f"failed to write file: {exc}", - ) - return _FilePutResult( - name=name, - rel_path=resolved_path, - size=len(payload), - error=None, - ) - - -async def _handle_file_command( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - args_text: str, - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, -) -> None: - reply = _reply_sender(cfg, msg) - command, rest, error = parse_file_command(args_text) - if error is not None: - await reply(text=error) - return - if command == "put": - await _handle_file_put(cfg, msg, rest, ambient_context, topic_store) - else: - await _handle_file_get(cfg, msg, rest, ambient_context, topic_store) - - -async def _handle_file_put_default( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, -) -> None: - await _handle_file_put(cfg, msg, "", ambient_context, topic_store) - - -async def _save_file_put( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - args_text: str, - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, -) -> _SavedFilePut | None: - reply = _reply_sender(cfg, msg) - document = msg.document - if document is None: - await reply(text=FILE_PUT_USAGE) - return None - plan = await _prepare_file_put_plan( - cfg, - msg, - args_text, - ambient_context, - topic_store, - ) - if plan is None: - return None - base_dir, rel_path, error = resolve_file_put_paths( - plan, - cfg=cfg, - require_dir=False, - ) - if error is not None: - await reply(text=error) - return None - result = await _save_document_payload( - cfg, - document=document, - run_root=plan.run_root, - rel_path=rel_path, - base_dir=base_dir, - force=plan.force, - ) - if result.error is not None: - await reply(text=result.error) - return None - if result.rel_path is None or result.size is None: - await reply(text="failed to save file.") - return None - return _SavedFilePut( - context=plan.resolved.context, - rel_path=result.rel_path, - size=result.size, - ) - - -async def _handle_file_put( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - args_text: str, - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, -) -> None: - reply = _reply_sender(cfg, msg) - saved = await _save_file_put( - cfg, - msg, - args_text, - ambient_context, - topic_store, - ) - if saved is None: - return - context_label = _format_context(cfg.runtime, saved.context) - await reply( - text=( - f"saved `{saved.rel_path.as_posix()}` " - f"in `{context_label}` ({format_bytes(saved.size)})" - ), - ) - - -async def _handle_file_put_group( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - args_text: str, - messages: Sequence[TelegramIncomingMessage], - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, -) -> None: - reply = _reply_sender(cfg, msg) - saved_group = await _save_file_put_group( - cfg, - msg, - args_text, - messages, - ambient_context, - topic_store, - ) - if saved_group is None: - return - context_label = _format_context(cfg.runtime, saved_group.context) - total_bytes = sum(item.size or 0 for item in saved_group.saved) - dir_label: Path | None = saved_group.base_dir - if dir_label is None and saved_group.saved: - first_path = saved_group.saved[0].rel_path - if first_path is not None: - dir_label = first_path.parent - if saved_group.saved: - saved_names = ", ".join(f"`{item.name}`" for item in saved_group.saved) - if dir_label is not None: - dir_text = dir_label.as_posix() - if not dir_text.endswith("/"): - dir_text = f"{dir_text}/" - text = ( - f"saved {saved_names} to `{dir_text}` " - f"in `{context_label}` ({format_bytes(total_bytes)})" - ) - else: - text = ( - f"saved {saved_names} in `{context_label}` " - f"({format_bytes(total_bytes)})" - ) - else: - text = "failed to upload files." - failure_text = _format_file_put_failures(saved_group.failed) - if failure_text is not None: - text = f"{text}\n\n{failure_text}" - await reply(text=text) - - -async def _save_file_put_group( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - args_text: str, - messages: Sequence[TelegramIncomingMessage], - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, -) -> _SavedFilePutGroup | None: - reply = _reply_sender(cfg, msg) - documents = [item.document for item in messages if item.document is not None] - if not documents: - await reply(text=FILE_PUT_USAGE) - return None - plan = await _prepare_file_put_plan( - cfg, - msg, - args_text, - ambient_context, - topic_store, - ) - if plan is None: - return None - base_dir, _, error = resolve_file_put_paths( - plan, - cfg=cfg, - require_dir=True, - ) - if error is not None: - await reply(text=error) - return None - saved: list[_FilePutResult] = [] - failed: list[_FilePutResult] = [] - for document in documents: - result = await _save_document_payload( - cfg, - document=document, - run_root=plan.run_root, - rel_path=None, - base_dir=base_dir, - force=plan.force, - ) - if result.error is None: - saved.append(result) - else: - failed.append(result) - return _SavedFilePutGroup( - context=plan.resolved.context, - base_dir=base_dir, - saved=saved, - failed=failed, - ) - - -async def _handle_media_group( - cfg: TelegramBridgeConfig, - messages: Sequence[TelegramIncomingMessage], - topic_store: TopicStateStore | None, - run_prompt: Callable[ - [TelegramIncomingMessage, str, ResolvedMessage], Awaitable[None] - ] - | None = None, - resolve_prompt: Callable[ - [TelegramIncomingMessage, str, RunContext | None], - Awaitable[ResolvedMessage | None], - ] - | None = None, -) -> None: - if not messages: - return - ordered = sorted(messages, key=lambda item: item.message_id) - command_msg = next( - (item for item in ordered if item.text.strip()), - ordered[0], - ) - reply = _reply_sender(cfg, command_msg) - topic_key = _topic_key(command_msg, cfg) if topic_store is not None else None - chat_project = _topics_chat_project(cfg, command_msg.chat_id) - bound_context = ( - await topic_store.get_context(*topic_key) - if topic_store is not None and topic_key is not None - else None - ) - ambient_context = _merge_topic_context( - chat_project=chat_project, bound=bound_context - ) - command_id, args_text = _parse_slash_command(command_msg.text) - if command_id == "file": - command, rest, error = parse_file_command(args_text) - if error is not None: - await reply(text=error) - return - if command == "put": - await _handle_file_put_group( - cfg, - command_msg, - rest, - ordered, - ambient_context, - topic_store, - ) - return - if cfg.files.enabled and cfg.files.auto_put: - caption_text = command_msg.text.strip() - if cfg.files.auto_put_mode == "prompt" and caption_text: - if resolve_prompt is None: - try: - resolved = cfg.runtime.resolve_message( - text=caption_text, - reply_text=command_msg.reply_to_text, - ambient_context=ambient_context, - chat_id=command_msg.chat_id, - ) - except DirectiveError as exc: - await reply(text=f"error:\n{exc}") - return - else: - resolved = await resolve_prompt( - command_msg, caption_text, ambient_context - ) - if resolved is None: - return - saved_group = await _save_file_put_group( - cfg, - command_msg, - "", - ordered, - resolved.context, - topic_store, - ) - if saved_group is None: - return - if not saved_group.saved: - failure_text = _format_file_put_failures(saved_group.failed) - text = "failed to upload files." - if failure_text is not None: - text = f"{text}\n\n{failure_text}" - await reply(text=text) - return - if saved_group.failed: - failure_text = _format_file_put_failures(saved_group.failed) - if failure_text is not None: - await reply(text=f"some files failed to upload.\n\n{failure_text}") - if run_prompt is None: - await reply(text=FILE_PUT_USAGE) - return - paths = [ - item.rel_path.as_posix() - for item in saved_group.saved - if item.rel_path is not None - ] - files_text = "\n".join(f"- {path}" for path in paths) - prompt_base = resolved.prompt - annotation = f"[uploaded files]\n{files_text}" - if prompt_base and prompt_base.strip(): - prompt = f"{prompt_base}\n\n{annotation}" - else: - prompt = annotation - await run_prompt(command_msg, prompt, resolved) - return - if not caption_text: - await _handle_file_put_group( - cfg, - command_msg, - "", - ordered, - ambient_context, - topic_store, - ) - return - await reply(text=FILE_PUT_USAGE) - - -async def _handle_file_get( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - args_text: str, - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, -) -> None: - reply = _reply_sender(cfg, msg) - if not await _check_file_permissions(cfg, msg): - return - try: - resolved = cfg.runtime.resolve_message( - text=args_text, - reply_text=msg.reply_to_text, - ambient_context=ambient_context, - chat_id=msg.chat_id, - ) - except DirectiveError as exc: - await reply(text=f"error:\n{exc}") - return - topic_key = _topic_key(msg, cfg) if topic_store is not None else None - await _maybe_update_topic_context( - cfg=cfg, - topic_store=topic_store, - topic_key=topic_key, - context=resolved.context, - context_source=resolved.context_source, - ) - if resolved.context is None or resolved.context.project is None: - await reply(text="no project context available for file download.") - return - try: - run_root = cfg.runtime.resolve_run_cwd(resolved.context) - except ConfigError as exc: - await reply(text=f"error:\n{exc}") - return - if run_root is None: - await reply(text="no project context available for file download.") - return - path_value = resolved.prompt - if not path_value.strip(): - await reply(text=FILE_GET_USAGE) - return - rel_path = normalize_relative_path(path_value) - if rel_path is None: - await reply(text="invalid download path.") - return - deny_rule = deny_reason(rel_path, cfg.files.deny_globs) - if deny_rule is not None: - await reply(text=f"path denied by rule: {deny_rule}") - return - target = resolve_path_within_root(run_root, rel_path) - if target is None: - await reply(text="download path escapes the repo root.") - return - if not target.exists(): - await reply(text="file does not exist.") - return - if target.is_dir(): - try: - payload = zip_directory( - run_root, - rel_path, - cfg.files.deny_globs, - max_bytes=cfg.files.max_download_bytes, - ) - except ZipTooLargeError: - await reply(text="file is too large to send.") - return - except OSError as exc: - await reply(text=f"failed to read directory: {exc}") - return - filename = f"{rel_path.name or 'archive'}.zip" - else: - try: - size = target.stat().st_size - if size > cfg.files.max_download_bytes: - await reply(text="file is too large to send.") - return - payload = target.read_bytes() - except OSError as exc: - await reply(text=f"failed to read file: {exc}") - return - filename = target.name - if len(payload) > cfg.files.max_download_bytes: - await reply(text="file is too large to send.") - return - sent = await cfg.bot.send_document( - chat_id=msg.chat_id, - filename=filename, - content=payload, - reply_to_message_id=msg.message_id, - message_thread_id=msg.thread_id, - ) - if sent is None: - await reply(text="failed to send file.") - return - - -async def _handle_ctx_command( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - args_text: str, - store: TopicStateStore, - *, - resolved_scope: str | None = None, - scope_chat_ids: frozenset[int] | None = None, -) -> None: - reply = _reply_sender(cfg, msg) - error = _topics_command_error( - cfg, - msg.chat_id, - resolved_scope=resolved_scope, - scope_chat_ids=scope_chat_ids, - ) - if error is not None: - await reply(text=error) - return - chat_project = _topics_chat_project(cfg, msg.chat_id) - tkey = _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids) - if tkey is None: - await reply(text="this command only works inside a topic.") - return - tokens = split_command_args(args_text) - action = tokens[0].lower() if tokens else "show" - if action in {"show", ""}: - snapshot = await store.get_thread(*tkey) - bound = snapshot.context if snapshot is not None else None - ambient = _merge_topic_context(chat_project=chat_project, bound=bound) - resolved = cfg.runtime.resolve_message( - text="", - reply_text=msg.reply_to_text, - chat_id=msg.chat_id, - ambient_context=ambient, - ) - text = _format_ctx_status( - cfg=cfg, - runtime=cfg.runtime, - bound=bound, - resolved=resolved.context, - context_source=resolved.context_source, - snapshot=snapshot, - chat_project=chat_project, - ) - await reply(text=text) - return - if action == "set": - rest = " ".join(tokens[1:]) - context, error = _parse_project_branch_args( - rest, - runtime=cfg.runtime, - require_branch=False, - chat_project=chat_project, - ) - if error is not None: - await reply( - text=f"error:\n{error}\n{_usage_ctx_set(chat_project=chat_project)}", - ) - return - if context is None: - await reply( - text=f"error:\n{_usage_ctx_set(chat_project=chat_project)}", - ) - return - await store.set_context(*tkey, context) - await _maybe_rename_topic( - cfg, - store, - chat_id=tkey[0], - thread_id=tkey[1], - context=context, - ) - await reply( - text=f"topic bound to `{_format_context(cfg.runtime, context)}`", - ) - return - if action == "clear": - await store.clear_context(*tkey) - await reply(text="topic binding cleared.") - return - await reply( - text="unknown `/ctx` command. use `/ctx`, `/ctx set`, or `/ctx clear`.", - ) - - -async def _handle_agent_command( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - args_text: str, - ambient_context: RunContext | None, - topic_store: TopicStateStore | None, - chat_prefs: ChatPrefsStore | None, - *, - resolved_scope: str | None = None, - scope_chat_ids: frozenset[int] | None = None, -) -> None: - reply = _reply_sender(cfg, msg) - tkey = ( - _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids) - if topic_store is not None - else None - ) - tokens = split_command_args(args_text) - action = tokens[0].lower() if tokens else "show" - - if action in {"show", ""}: - try: - resolved = cfg.runtime.resolve_message( - text="", - reply_text=msg.reply_to_text, - ambient_context=ambient_context, - chat_id=msg.chat_id, - ) - except DirectiveError as exc: - await reply(text=f"error:\n{exc}") - return - selection = await resolve_engine_for_message( - runtime=cfg.runtime, - context=resolved.context, - explicit_engine=None, - chat_id=msg.chat_id, - topic_key=tkey, - topic_store=topic_store, - chat_prefs=chat_prefs, - ) - source_labels = { - "directive": "directive", - "topic_default": "topic default", - "chat_default": "chat default", - "project_default": "project default", - "global_default": "global default", - } - agent_line = f"agent: {selection.engine} ({source_labels[selection.source]})" - topic_default = selection.topic_default or "none" - if tkey is None: - topic_default = "none" - if chat_prefs is None: - chat_default = "unavailable" - else: - chat_default = selection.chat_default or "none" - project_default = ( - selection.project_default - if selection.project_default is not None - else "none" - ) - defaults_line = ( - "defaults: " - f"topic: {topic_default}, " - f"chat: {chat_default}, " - f"project: {project_default}, " - f"global: {cfg.runtime.default_engine}" - ) - available = ", ".join(cfg.runtime.engine_ids) - available_line = f"available: {available}" - await reply(text="\n\n".join([agent_line, defaults_line, available_line])) - return - - if action == "set": - if len(tokens) < 2: - await reply(text=AGENT_USAGE) - return - if not await _check_agent_permissions(cfg, msg): - return - engine = tokens[1].strip().lower() - if engine not in cfg.runtime.engine_ids: - available = ", ".join(cfg.runtime.engine_ids) - await reply( - text=f"unknown engine `{engine}`.\navailable agents: `{available}`", - ) - return - if tkey is not None: - if topic_store is None: - await reply(text="topic defaults are unavailable.") - return - await topic_store.set_default_engine(tkey[0], tkey[1], engine) - await reply(text=f"topic default agent set to `{engine}`") - return - if chat_prefs is None: - await reply(text="chat defaults are unavailable (no config path).") - return - await chat_prefs.set_default_engine(msg.chat_id, engine) - await reply(text=f"chat default agent set to `{engine}`") - return - - if action == "clear": - if not await _check_agent_permissions(cfg, msg): - return - if tkey is not None: - if topic_store is None: - await reply(text="topic defaults are unavailable.") - return - await topic_store.clear_default_engine(tkey[0], tkey[1]) - await reply(text="topic default agent cleared.") - return - if chat_prefs is None: - await reply(text="chat defaults are unavailable (no config path).") - return - await chat_prefs.clear_default_engine(msg.chat_id) - await reply(text="chat default agent cleared.") - return - - await reply(text=AGENT_USAGE) - - -async def _handle_new_command( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - store: TopicStateStore, - *, - resolved_scope: str | None = None, - scope_chat_ids: frozenset[int] | None = None, -) -> None: - reply = _reply_sender(cfg, msg) - error = _topics_command_error( - cfg, - msg.chat_id, - resolved_scope=resolved_scope, - scope_chat_ids=scope_chat_ids, - ) - if error is not None: - await reply(text=error) - return - tkey = _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids) - if tkey is None: - await reply(text="this command only works inside a topic.") - return - await store.clear_sessions(*tkey) - await reply(text="cleared stored sessions for this topic.") - - -async def _handle_chat_new_command( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - store: ChatSessionStore, - session_key: tuple[int, int | None] | None, -) -> None: - reply = _reply_sender(cfg, msg) - if session_key is None: - await reply(text="no stored sessions to clear for this chat.") - return - await store.clear_sessions(session_key[0], session_key[1]) - if msg.chat_type == "private": - text = "cleared stored sessions for this chat." - else: - text = "cleared stored sessions for you in this chat." - await reply(text=text) - - -async def _handle_topic_command( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - args_text: str, - store: TopicStateStore, - *, - resolved_scope: str | None = None, - scope_chat_ids: frozenset[int] | None = None, -) -> None: - reply = _reply_sender(cfg, msg) - error = _topics_command_error( - cfg, - msg.chat_id, - resolved_scope=resolved_scope, - scope_chat_ids=scope_chat_ids, - ) - if error is not None: - await reply(text=error) - return - chat_project = _topics_chat_project(cfg, msg.chat_id) - context, error = _parse_project_branch_args( - args_text, - runtime=cfg.runtime, - require_branch=True, - chat_project=chat_project, - ) - if error is not None or context is None: - usage = _usage_topic(chat_project=chat_project) - text = f"error:\n{error}\n{usage}" if error else usage - await reply(text=text) - return - existing = await store.find_thread_for_context(msg.chat_id, context) - if existing is not None: - await reply( - text=f"topic already exists for {_format_context(cfg.runtime, context)} " - "in this chat.", - ) - return - title = _topic_title(runtime=cfg.runtime, context=context) - created = await cfg.bot.create_forum_topic(msg.chat_id, title) - if created is None: - await reply(text="failed to create topic.") - return - thread_id = created.message_thread_id - await store.set_context( - msg.chat_id, - thread_id, - context, - topic_title=title, - ) - await reply(text=f"created topic `{title}`.") - bound_text = f"topic bound to `{_format_context(cfg.runtime, context)}`" - rendered_text, entities = prepare_telegram(MarkdownParts(header=bound_text)) - await cfg.exec_cfg.transport.send( - channel_id=msg.chat_id, - message=RenderedMessage(text=rendered_text, extra={"entities": entities}), - options=SendOptions(thread_id=thread_id), - ) - - -async def handle_cancel( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - running_tasks: RunningTasks, -) -> None: - reply = _reply_sender(cfg, msg) - chat_id = msg.chat_id - reply_id = msg.reply_to_message_id - - if reply_id is None: - if msg.reply_to_text: - await reply(text="nothing is currently running for that message.") - return - await reply(text="reply to the progress message to cancel.") - return - - progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id) - running_task = running_tasks.get(progress_ref) - if running_task is None: - await reply(text="nothing is currently running for that message.") - return - - logger.info( - "cancel.requested", - chat_id=chat_id, - progress_message_id=reply_id, - ) - running_task.cancel_requested.set() - - -async def handle_callback_cancel( - cfg: TelegramBridgeConfig, - query: TelegramCallbackQuery, - running_tasks: RunningTasks, -) -> None: - progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id) - running_task = running_tasks.get(progress_ref) - if running_task is None: - await cfg.bot.answer_callback_query( - callback_query_id=query.callback_query_id, - text="nothing is currently running for that message.", - ) - return - logger.info( - "cancel.requested", - chat_id=query.chat_id, - progress_message_id=query.message_id, - ) - running_task.cancel_requested.set() - await cfg.bot.answer_callback_query( - callback_query_id=query.callback_query_id, - text="cancelling...", - ) - - -async def _send_runner_unavailable( - exec_cfg: ExecBridgeConfig, - *, - chat_id: int, - user_msg_id: int, - resume_token: ResumeToken | None, - runner: Runner, - reason: str, - thread_id: int | None = None, -) -> None: - tracker = ProgressTracker(engine=runner.engine) - tracker.set_resume(resume_token) - state = tracker.snapshot(resume_formatter=runner.format_resume) - message = exec_cfg.presenter.render_final( - state, - elapsed_s=0.0, - status="error", - answer=f"error:\n{reason}", - ) - reply_to = MessageRef(channel_id=chat_id, message_id=user_msg_id) - await exec_cfg.transport.send( - channel_id=chat_id, - message=message, - options=SendOptions(reply_to=reply_to, notify=True, thread_id=thread_id), - ) - - -async def _run_engine( - *, - exec_cfg: ExecBridgeConfig, - runtime: TransportRuntime, - running_tasks: RunningTasks | None, - chat_id: int, - user_msg_id: int, - text: str, - resume_token: ResumeToken | None, - context: RunContext | None, - reply_ref: MessageRef | None = None, - on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] - | None = None, - engine_override: EngineId | None = None, - thread_id: int | None = None, - show_resume_line: bool = True, -) -> None: - reply = partial( - send_plain, - exec_cfg.transport, - chat_id=chat_id, - user_msg_id=user_msg_id, - thread_id=thread_id, - ) - try: - try: - entry = runtime.resolve_runner( - resume_token=resume_token, - engine_override=engine_override, - ) - except RunnerUnavailableError as exc: - await reply(text=f"error:\n{exc}") - return - runner: Runner = entry.runner - if not show_resume_line: - runner = cast(Runner, _ResumeLineProxy(runner)) - if not entry.available: - reason = entry.issue or "engine unavailable" - await _send_runner_unavailable( - exec_cfg, - chat_id=chat_id, - user_msg_id=user_msg_id, - resume_token=resume_token, - runner=runner, - reason=reason, - thread_id=thread_id, - ) - return - try: - cwd = runtime.resolve_run_cwd(context) - except ConfigError as exc: - await reply(text=f"error:\n{exc}") - return - run_base_token = set_run_base_dir(cwd) - try: - run_fields = { - "chat_id": chat_id, - "user_msg_id": user_msg_id, - "engine": runner.engine, - "resume": resume_token.value if resume_token else None, - } - if context is not None: - run_fields["project"] = context.project - run_fields["branch"] = context.branch - if cwd is not None: - run_fields["cwd"] = str(cwd) - bind_run_context(**run_fields) - context_line = runtime.format_context_line(context) - incoming = RunnerIncomingMessage( - channel_id=chat_id, - message_id=user_msg_id, - text=text, - reply_to=reply_ref, - thread_id=thread_id, - ) - await handle_message( - exec_cfg, - runner=runner, - incoming=incoming, - resume_token=resume_token, - context=context, - context_line=context_line, - strip_resume_line=runtime.is_resume_line, - running_tasks=running_tasks, - on_thread_known=on_thread_known, - ) - finally: - reset_run_base_dir(run_base_token) - except Exception as exc: - logger.exception( - "handle.worker_failed", - error=str(exc), - error_type=exc.__class__.__name__, - ) - finally: - clear_context() - - -class _CaptureTransport: - def __init__(self) -> None: - self._next_id = 1 - self.last_message: RenderedMessage | None = None - - async def send( - self, - *, - channel_id: int | str, - message: RenderedMessage, - options: SendOptions | None = None, - ) -> MessageRef: - thread_id = options.thread_id if options is not None else None - ref = MessageRef(channel_id=channel_id, message_id=self._next_id) - self._next_id += 1 - self.last_message = message - return MessageRef( - channel_id=ref.channel_id, - message_id=ref.message_id, - thread_id=thread_id, - ) - - async def edit( - self, *, ref: MessageRef, message: RenderedMessage, wait: bool = True - ) -> MessageRef: - _ = ref, wait - self.last_message = message - return ref - - async def delete(self, *, ref: MessageRef) -> bool: - _ = ref - return True - - async def close(self) -> None: - return None - - -class _TelegramCommandExecutor(CommandExecutor): - def __init__( - self, - *, - exec_cfg: ExecBridgeConfig, - runtime: TransportRuntime, - running_tasks: RunningTasks, - scheduler: ThreadScheduler, - on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None, - chat_id: int, - user_msg_id: int, - thread_id: int | None, - show_resume_line: bool, - stateful_mode: bool, - default_engine_override: EngineId | None, - ) -> None: - self._exec_cfg = exec_cfg - self._runtime = runtime - self._running_tasks = running_tasks - self._scheduler = scheduler - self._on_thread_known = on_thread_known - self._chat_id = chat_id - self._user_msg_id = user_msg_id - self._thread_id = thread_id - self._show_resume_line = show_resume_line - self._stateful_mode = stateful_mode - self._default_engine_override = default_engine_override - self._reply_ref = MessageRef( - channel_id=chat_id, - message_id=user_msg_id, - thread_id=thread_id, - ) - - def _apply_default_context(self, request: RunRequest) -> RunRequest: - if request.context is not None: - return request - context = self._runtime.default_context_for_chat(self._chat_id) - if context is None: - return request - return RunRequest( - prompt=request.prompt, - engine=request.engine, - context=context, - ) - - def _apply_default_engine(self, request: RunRequest) -> RunRequest: - if request.engine is not None or self._default_engine_override is None: - return request - return RunRequest( - prompt=request.prompt, - engine=self._default_engine_override, - context=request.context, - ) - - async def send( - self, - message: RenderedMessage | str, - *, - reply_to: MessageRef | None = None, - notify: bool = True, - ) -> MessageRef | None: - rendered = ( - message - if isinstance(message, RenderedMessage) - else RenderedMessage(text=message) - ) - reply_ref = self._reply_ref if reply_to is None else reply_to - return await self._exec_cfg.transport.send( - channel_id=self._chat_id, - message=rendered, - options=SendOptions( - reply_to=reply_ref, - notify=notify, - thread_id=self._thread_id, - ), - ) - - async def run_one( - self, request: RunRequest, *, mode: RunMode = "emit" - ) -> RunResult: - request = self._apply_default_context(request) - request = self._apply_default_engine(request) - effective_show_resume_line = _should_show_resume_line( - show_resume_line=self._show_resume_line, - stateful_mode=self._stateful_mode, - context=request.context, - ) - engine = self._runtime.resolve_engine( - engine_override=request.engine, - context=request.context, - ) - on_thread_known = ( - self._scheduler.note_thread_known - if self._on_thread_known is None - else self._on_thread_known - ) - if mode == "capture": - capture = _CaptureTransport() - exec_cfg = ExecBridgeConfig( - transport=capture, - presenter=self._exec_cfg.presenter, - final_notify=False, - ) - await _run_engine( - exec_cfg=exec_cfg, - runtime=self._runtime, - running_tasks={}, - chat_id=self._chat_id, - user_msg_id=self._user_msg_id, - text=request.prompt, - resume_token=None, - context=request.context, - reply_ref=self._reply_ref, - on_thread_known=on_thread_known, - engine_override=engine, - thread_id=self._thread_id, - show_resume_line=effective_show_resume_line, - ) - return RunResult(engine=engine, message=capture.last_message) - await _run_engine( - exec_cfg=self._exec_cfg, - runtime=self._runtime, - running_tasks=self._running_tasks, - chat_id=self._chat_id, - user_msg_id=self._user_msg_id, - text=request.prompt, - resume_token=None, - context=request.context, - reply_ref=self._reply_ref, - on_thread_known=on_thread_known, - engine_override=engine, - thread_id=self._thread_id, - show_resume_line=effective_show_resume_line, - ) - return RunResult(engine=engine, message=None) - - async def run_many( - self, - requests: Sequence[RunRequest], - *, - mode: RunMode = "emit", - parallel: bool = False, - ) -> list[RunResult]: - if not parallel: - return [await self.run_one(request, mode=mode) for request in requests] - results: list[RunResult | None] = [None] * len(requests) - - async with anyio.create_task_group() as tg: - - async def run_idx(idx: int, request: RunRequest) -> None: - results[idx] = await self.run_one(request, mode=mode) - - for idx, request in enumerate(requests): - tg.start_soon(run_idx, idx, request) - - return [result for result in results if result is not None] - - -async def _dispatch_command( - cfg: TelegramBridgeConfig, - msg: TelegramIncomingMessage, - text: str, - command_id: str, - args_text: str, - running_tasks: RunningTasks, - scheduler: ThreadScheduler, - on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None, - stateful_mode: bool, - default_engine_override: EngineId | None, -) -> None: - allowlist = cfg.runtime.allowlist - chat_id = msg.chat_id - user_msg_id = msg.message_id - reply_ref = ( - MessageRef( - channel_id=chat_id, - message_id=msg.reply_to_message_id, - thread_id=msg.thread_id, - ) - if msg.reply_to_message_id is not None - else None - ) - executor = _TelegramCommandExecutor( - exec_cfg=cfg.exec_cfg, - runtime=cfg.runtime, - running_tasks=running_tasks, - scheduler=scheduler, - on_thread_known=on_thread_known, - chat_id=chat_id, - user_msg_id=user_msg_id, - thread_id=msg.thread_id, - show_resume_line=cfg.show_resume_line, - stateful_mode=stateful_mode, - default_engine_override=default_engine_override, - ) - message_ref = MessageRef( - channel_id=chat_id, - message_id=user_msg_id, - thread_id=msg.thread_id, - sender_id=msg.sender_id, - raw=msg.raw, - ) - try: - backend = get_command(command_id, allowlist=allowlist, required=False) - except ConfigError as exc: - await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True) - return - if backend is None: - return - try: - plugin_config = cfg.runtime.plugin_config(command_id) - except ConfigError as exc: - await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True) - return - ctx = CommandContext( - command=command_id, - text=text, - args_text=args_text, - args=split_command_args(args_text), - message=message_ref, - reply_to=reply_ref, - reply_text=msg.reply_to_text, - config_path=cfg.runtime.config_path, - plugin_config=plugin_config, - runtime=cfg.runtime, - executor=executor, - ) - try: - result = await backend.handle(ctx) - except Exception as exc: - logger.exception( - "command.failed", - command=command_id, - error=str(exc), - error_type=exc.__class__.__name__, - ) - await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True) - return - if result is not None: - reply_to = message_ref if result.reply_to is None else result.reply_to - await executor.send(result.text, reply_to=reply_to, notify=result.notify) diff --git a/src/takopi/telegram/commands/__init__.py b/src/takopi/telegram/commands/__init__.py new file mode 100644 index 0000000..3ba0aac --- /dev/null +++ b/src/takopi/telegram/commands/__init__.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from .cancel import handle_callback_cancel, handle_cancel +from .menu import build_bot_commands +from .parse import is_cancel_command + +__all__ = [ + "build_bot_commands", + "handle_callback_cancel", + "handle_cancel", + "is_cancel_command", +] diff --git a/src/takopi/telegram/commands/agent.py b/src/takopi/telegram/commands/agent.py new file mode 100644 index 0000000..fec3672 --- /dev/null +++ b/src/takopi/telegram/commands/agent.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ...context import RunContext +from ...directives import DirectiveError +from ..chat_prefs import ChatPrefsStore +from ..engine_defaults import resolve_engine_for_message +from ..files import split_command_args +from ..topic_state import TopicStateStore +from ..topics import _topic_key +from ..types import TelegramIncomingMessage +from .reply import make_reply + +if TYPE_CHECKING: + from ..bridge import TelegramBridgeConfig + +AGENT_USAGE = "usage: `/agent`, `/agent set `, or `/agent clear`" + + +async def _check_agent_permissions( + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage +) -> bool: + reply = make_reply(cfg, msg) + sender_id = msg.sender_id + if sender_id is None: + await reply(text="cannot verify sender for agent defaults.") + return False + is_private = msg.chat_type == "private" + if msg.chat_type is None: + is_private = msg.chat_id > 0 + if is_private: + return True + member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) + if member is None: + await reply(text="failed to verify agent permissions.") + return False + if member.status in {"creator", "administrator"}: + return True + await reply(text="changing default agents is restricted to group admins.") + return False + + +async def _handle_agent_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + ambient_context: RunContext | None, + topic_store: TopicStateStore | None, + chat_prefs: ChatPrefsStore | None, + *, + resolved_scope: str | None = None, + scope_chat_ids: frozenset[int] | None = None, +) -> None: + reply = make_reply(cfg, msg) + tkey = ( + _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids) + if topic_store is not None + else None + ) + tokens = split_command_args(args_text) + action = tokens[0].lower() if tokens else "show" + + if action in {"show", ""}: + try: + resolved = cfg.runtime.resolve_message( + text="", + reply_text=msg.reply_to_text, + ambient_context=ambient_context, + chat_id=msg.chat_id, + ) + except DirectiveError as exc: + await reply(text=f"error:\n{exc}") + return + selection = await resolve_engine_for_message( + runtime=cfg.runtime, + context=resolved.context, + explicit_engine=None, + chat_id=msg.chat_id, + topic_key=tkey, + topic_store=topic_store, + chat_prefs=chat_prefs, + ) + source_labels = { + "directive": "directive", + "topic_default": "topic default", + "chat_default": "chat default", + "project_default": "project default", + "global_default": "global default", + } + agent_line = f"agent: {selection.engine} ({source_labels[selection.source]})" + topic_default = selection.topic_default or "none" + if tkey is None: + topic_default = "none" + if chat_prefs is None: + chat_default = "unavailable" + else: + chat_default = selection.chat_default or "none" + project_default = ( + selection.project_default + if selection.project_default is not None + else "none" + ) + defaults_line = ( + "defaults: " + f"topic: {topic_default}, " + f"chat: {chat_default}, " + f"project: {project_default}, " + f"global: {cfg.runtime.default_engine}" + ) + available = ", ".join(cfg.runtime.engine_ids) + available_line = f"available: {available}" + await reply(text="\n\n".join([agent_line, defaults_line, available_line])) + return + + if action == "set": + if len(tokens) < 2: + await reply(text=AGENT_USAGE) + return + if not await _check_agent_permissions(cfg, msg): + return + engine = tokens[1].strip().lower() + if engine not in cfg.runtime.engine_ids: + available = ", ".join(cfg.runtime.engine_ids) + await reply( + text=f"unknown engine `{engine}`.\navailable agents: `{available}`", + ) + return + if tkey is not None: + if topic_store is None: + await reply(text="topic defaults are unavailable.") + return + await topic_store.set_default_engine(tkey[0], tkey[1], engine) + await reply(text=f"topic default agent set to `{engine}`") + return + if chat_prefs is None: + await reply(text="chat defaults are unavailable (no config path).") + return + await chat_prefs.set_default_engine(msg.chat_id, engine) + await reply(text=f"chat default agent set to `{engine}`") + return + + if action == "clear": + if not await _check_agent_permissions(cfg, msg): + return + if tkey is not None: + if topic_store is None: + await reply(text="topic defaults are unavailable.") + return + await topic_store.clear_default_engine(tkey[0], tkey[1]) + await reply(text="topic default agent cleared.") + return + if chat_prefs is None: + await reply(text="chat defaults are unavailable (no config path).") + return + await chat_prefs.clear_default_engine(msg.chat_id) + await reply(text="chat default agent cleared.") + return + + await reply(text=AGENT_USAGE) diff --git a/src/takopi/telegram/commands/cancel.py b/src/takopi/telegram/commands/cancel.py new file mode 100644 index 0000000..cb0959e --- /dev/null +++ b/src/takopi/telegram/commands/cancel.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ...logging import get_logger +from ...runner_bridge import RunningTasks +from ...transport import MessageRef +from ..types import TelegramCallbackQuery, TelegramIncomingMessage +from .reply import make_reply + +if TYPE_CHECKING: + from ..bridge import TelegramBridgeConfig + +logger = get_logger(__name__) + + +async def handle_cancel( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + running_tasks: RunningTasks, +) -> None: + reply = make_reply(cfg, msg) + chat_id = msg.chat_id + reply_id = msg.reply_to_message_id + + if reply_id is None: + if msg.reply_to_text: + await reply(text="nothing is currently running for that message.") + return + await reply(text="reply to the progress message to cancel.") + return + + progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id) + running_task = running_tasks.get(progress_ref) + if running_task is None: + await reply(text="nothing is currently running for that message.") + return + + logger.info( + "cancel.requested", + chat_id=chat_id, + progress_message_id=reply_id, + ) + running_task.cancel_requested.set() + + +async def handle_callback_cancel( + cfg: TelegramBridgeConfig, + query: TelegramCallbackQuery, + running_tasks: RunningTasks, +) -> None: + progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id) + running_task = running_tasks.get(progress_ref) + if running_task is None: + await cfg.bot.answer_callback_query( + callback_query_id=query.callback_query_id, + text="nothing is currently running for that message.", + ) + return + logger.info( + "cancel.requested", + chat_id=query.chat_id, + progress_message_id=query.message_id, + ) + running_task.cancel_requested.set() + await cfg.bot.answer_callback_query( + callback_query_id=query.callback_query_id, + text="cancelling...", + ) diff --git a/src/takopi/telegram/commands/dispatch.py b/src/takopi/telegram/commands/dispatch.py new file mode 100644 index 0000000..78f8ad3 --- /dev/null +++ b/src/takopi/telegram/commands/dispatch.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING + +import anyio + +from ...commands import CommandContext, get_command +from ...config import ConfigError +from ...logging import get_logger +from ...model import EngineId, ResumeToken +from ...runner_bridge import RunningTasks +from ...scheduler import ThreadScheduler +from ...transport import MessageRef +from ..files import split_command_args +from ..types import TelegramIncomingMessage +from .executor import _TelegramCommandExecutor + +if TYPE_CHECKING: + from ..bridge import TelegramBridgeConfig + +logger = get_logger(__name__) + + +async def _dispatch_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + text: str, + command_id: str, + args_text: str, + running_tasks: RunningTasks, + scheduler: ThreadScheduler, + on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None, + stateful_mode: bool, + default_engine_override: EngineId | None, +) -> None: + allowlist = cfg.runtime.allowlist + chat_id = msg.chat_id + user_msg_id = msg.message_id + reply_ref = ( + MessageRef( + channel_id=chat_id, + message_id=msg.reply_to_message_id, + thread_id=msg.thread_id, + ) + if msg.reply_to_message_id is not None + else None + ) + executor = _TelegramCommandExecutor( + exec_cfg=cfg.exec_cfg, + runtime=cfg.runtime, + running_tasks=running_tasks, + scheduler=scheduler, + on_thread_known=on_thread_known, + chat_id=chat_id, + user_msg_id=user_msg_id, + thread_id=msg.thread_id, + show_resume_line=cfg.show_resume_line, + stateful_mode=stateful_mode, + default_engine_override=default_engine_override, + ) + message_ref = MessageRef( + channel_id=chat_id, + message_id=user_msg_id, + thread_id=msg.thread_id, + sender_id=msg.sender_id, + raw=msg.raw, + ) + try: + backend = get_command(command_id, allowlist=allowlist, required=False) + except ConfigError as exc: + await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True) + return + if backend is None: + return + try: + plugin_config = cfg.runtime.plugin_config(command_id) + except ConfigError as exc: + await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True) + return + ctx = CommandContext( + command=command_id, + text=text, + args_text=args_text, + args=split_command_args(args_text), + message=message_ref, + reply_to=reply_ref, + reply_text=msg.reply_to_text, + config_path=cfg.runtime.config_path, + plugin_config=plugin_config, + runtime=cfg.runtime, + executor=executor, + ) + try: + result = await backend.handle(ctx) + except Exception as exc: + logger.exception( + "command.failed", + command=command_id, + error=str(exc), + error_type=exc.__class__.__name__, + ) + await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True) + return + if result is not None: + reply_to = message_ref if result.reply_to is None else result.reply_to + await executor.send(result.text, reply_to=reply_to, notify=result.notify) diff --git a/src/takopi/telegram/commands/executor.py b/src/takopi/telegram/commands/executor.py new file mode 100644 index 0000000..842ac2b --- /dev/null +++ b/src/takopi/telegram/commands/executor.py @@ -0,0 +1,382 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence +from dataclasses import dataclass +from functools import partial +from typing import cast + +import anyio + +from ...commands import CommandExecutor, RunMode, RunRequest, RunResult +from ...config import ConfigError +from ...context import RunContext +from ...logging import bind_run_context, clear_context, get_logger +from ...model import EngineId, ResumeToken, TakopiEvent +from ...progress import ProgressTracker +from ...router import RunnerUnavailableError +from ...runner import Runner +from ...runner_bridge import ( + ExecBridgeConfig, + IncomingMessage as RunnerIncomingMessage, + RunningTasks, + handle_message, +) +from ...scheduler import ThreadScheduler +from ...transport import MessageRef, RenderedMessage, SendOptions +from ...transport_runtime import TransportRuntime +from ...utils.paths import reset_run_base_dir, set_run_base_dir +from ..bridge import send_plain + +logger = get_logger(__name__) + + +@dataclass(slots=True) +class _ResumeLineProxy: + runner: Runner + + @property + def engine(self) -> str: + return self.runner.engine + + def is_resume_line(self, line: str) -> bool: + return self.runner.is_resume_line(line) + + def format_resume(self, _: ResumeToken) -> str: + return "" + + def extract_resume(self, text: str | None) -> ResumeToken | None: + return self.runner.extract_resume(text) + + def run( + self, prompt: str, resume: ResumeToken | None + ) -> AsyncIterator[TakopiEvent]: + return self.runner.run(prompt, resume) + + +def _should_show_resume_line( + *, + show_resume_line: bool, + stateful_mode: bool, + context: RunContext | None, +) -> bool: + if show_resume_line or not stateful_mode: + return True + return context is None or context.project is None + + +async def _send_runner_unavailable( + exec_cfg: ExecBridgeConfig, + *, + chat_id: int, + user_msg_id: int, + resume_token: ResumeToken | None, + runner: Runner, + reason: str, + thread_id: int | None = None, +) -> None: + tracker = ProgressTracker(engine=runner.engine) + tracker.set_resume(resume_token) + state = tracker.snapshot(resume_formatter=runner.format_resume) + message = exec_cfg.presenter.render_final( + state, + elapsed_s=0.0, + status="error", + answer=f"error:\n{reason}", + ) + reply_to = MessageRef(channel_id=chat_id, message_id=user_msg_id) + await exec_cfg.transport.send( + channel_id=chat_id, + message=message, + options=SendOptions(reply_to=reply_to, notify=True, thread_id=thread_id), + ) + + +async def _run_engine( + *, + exec_cfg: ExecBridgeConfig, + runtime: TransportRuntime, + running_tasks: RunningTasks | None, + chat_id: int, + user_msg_id: int, + text: str, + resume_token: ResumeToken | None, + context: RunContext | None, + reply_ref: MessageRef | None = None, + on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] + | None = None, + engine_override: EngineId | None = None, + thread_id: int | None = None, + show_resume_line: bool = True, +) -> None: + reply = partial( + send_plain, + exec_cfg.transport, + chat_id=chat_id, + user_msg_id=user_msg_id, + thread_id=thread_id, + ) + try: + try: + entry = runtime.resolve_runner( + resume_token=resume_token, + engine_override=engine_override, + ) + except RunnerUnavailableError as exc: + await reply(text=f"error:\n{exc}") + return + runner: Runner = entry.runner + if not show_resume_line: + runner = cast(Runner, _ResumeLineProxy(runner)) + if not entry.available: + reason = entry.issue or "engine unavailable" + await _send_runner_unavailable( + exec_cfg, + chat_id=chat_id, + user_msg_id=user_msg_id, + resume_token=resume_token, + runner=runner, + reason=reason, + thread_id=thread_id, + ) + return + try: + cwd = runtime.resolve_run_cwd(context) + except ConfigError as exc: + await reply(text=f"error:\n{exc}") + return + run_base_token = set_run_base_dir(cwd) + try: + run_fields = { + "chat_id": chat_id, + "user_msg_id": user_msg_id, + "engine": runner.engine, + "resume": resume_token.value if resume_token else None, + } + if context is not None: + run_fields["project"] = context.project + run_fields["branch"] = context.branch + if cwd is not None: + run_fields["cwd"] = str(cwd) + bind_run_context(**run_fields) + context_line = runtime.format_context_line(context) + incoming = RunnerIncomingMessage( + channel_id=chat_id, + message_id=user_msg_id, + text=text, + reply_to=reply_ref, + thread_id=thread_id, + ) + await handle_message( + exec_cfg, + runner=runner, + incoming=incoming, + resume_token=resume_token, + context=context, + context_line=context_line, + strip_resume_line=runtime.is_resume_line, + running_tasks=running_tasks, + on_thread_known=on_thread_known, + ) + finally: + reset_run_base_dir(run_base_token) + except Exception as exc: + logger.exception( + "handle.worker_failed", + error=str(exc), + error_type=exc.__class__.__name__, + ) + finally: + clear_context() + + +class _CaptureTransport: + def __init__(self) -> None: + self._next_id = 1 + self.last_message: RenderedMessage | None = None + + async def send( + self, + *, + channel_id: int | str, + message: RenderedMessage, + options: SendOptions | None = None, + ) -> MessageRef: + thread_id = options.thread_id if options is not None else None + ref = MessageRef(channel_id=channel_id, message_id=self._next_id) + self._next_id += 1 + self.last_message = message + return MessageRef( + channel_id=ref.channel_id, + message_id=ref.message_id, + thread_id=thread_id, + ) + + async def edit( + self, *, ref: MessageRef, message: RenderedMessage, wait: bool = True + ) -> MessageRef: + self.last_message = message + return ref + + async def delete(self, *, ref: MessageRef) -> bool: + return True + + async def close(self) -> None: + return None + + +class _TelegramCommandExecutor(CommandExecutor): + def __init__( + self, + *, + exec_cfg: ExecBridgeConfig, + runtime: TransportRuntime, + running_tasks: RunningTasks, + scheduler: ThreadScheduler, + on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None, + chat_id: int, + user_msg_id: int, + thread_id: int | None, + show_resume_line: bool, + stateful_mode: bool, + default_engine_override: EngineId | None, + ) -> None: + self._exec_cfg = exec_cfg + self._runtime = runtime + self._running_tasks = running_tasks + self._scheduler = scheduler + self._on_thread_known = on_thread_known + self._chat_id = chat_id + self._user_msg_id = user_msg_id + self._thread_id = thread_id + self._show_resume_line = show_resume_line + self._stateful_mode = stateful_mode + self._default_engine_override = default_engine_override + self._reply_ref = MessageRef( + channel_id=chat_id, + message_id=user_msg_id, + thread_id=thread_id, + ) + + def _apply_default_context(self, request: RunRequest) -> RunRequest: + if request.context is not None: + return request + context = self._runtime.default_context_for_chat(self._chat_id) + if context is None: + return request + return RunRequest( + prompt=request.prompt, + engine=request.engine, + context=context, + ) + + def _apply_default_engine(self, request: RunRequest) -> RunRequest: + if request.engine is not None or self._default_engine_override is None: + return request + return RunRequest( + prompt=request.prompt, + engine=self._default_engine_override, + context=request.context, + ) + + async def send( + self, + message: RenderedMessage | str, + *, + reply_to: MessageRef | None = None, + notify: bool = True, + ) -> MessageRef | None: + rendered = ( + message + if isinstance(message, RenderedMessage) + else RenderedMessage(text=message) + ) + reply_ref = self._reply_ref if reply_to is None else reply_to + return await self._exec_cfg.transport.send( + channel_id=self._chat_id, + message=rendered, + options=SendOptions( + reply_to=reply_ref, + notify=notify, + thread_id=self._thread_id, + ), + ) + + async def run_one( + self, request: RunRequest, *, mode: RunMode = "emit" + ) -> RunResult: + request = self._apply_default_context(request) + request = self._apply_default_engine(request) + effective_show_resume_line = _should_show_resume_line( + show_resume_line=self._show_resume_line, + stateful_mode=self._stateful_mode, + context=request.context, + ) + engine = self._runtime.resolve_engine( + engine_override=request.engine, + context=request.context, + ) + on_thread_known = ( + self._scheduler.note_thread_known + if self._on_thread_known is None + else self._on_thread_known + ) + if mode == "capture": + capture = _CaptureTransport() + exec_cfg = ExecBridgeConfig( + transport=capture, + presenter=self._exec_cfg.presenter, + final_notify=False, + ) + await _run_engine( + exec_cfg=exec_cfg, + runtime=self._runtime, + running_tasks={}, + chat_id=self._chat_id, + user_msg_id=self._user_msg_id, + text=request.prompt, + resume_token=None, + context=request.context, + reply_ref=self._reply_ref, + on_thread_known=on_thread_known, + engine_override=engine, + thread_id=self._thread_id, + show_resume_line=effective_show_resume_line, + ) + return RunResult(engine=engine, message=capture.last_message) + await _run_engine( + exec_cfg=self._exec_cfg, + runtime=self._runtime, + running_tasks=self._running_tasks, + chat_id=self._chat_id, + user_msg_id=self._user_msg_id, + text=request.prompt, + resume_token=None, + context=request.context, + reply_ref=self._reply_ref, + on_thread_known=on_thread_known, + engine_override=engine, + thread_id=self._thread_id, + show_resume_line=effective_show_resume_line, + ) + return RunResult(engine=engine, message=None) + + async def run_many( + self, + requests: Sequence[RunRequest], + *, + mode: RunMode = "emit", + parallel: bool = False, + ) -> list[RunResult]: + if not parallel: + return [await self.run_one(request, mode=mode) for request in requests] + results: list[RunResult | None] = [None] * len(requests) + + async with anyio.create_task_group() as tg: + + async def run_idx(idx: int, request: RunRequest) -> None: + results[idx] = await self.run_one(request, mode=mode) + + for idx, request in enumerate(requests): + tg.start_soon(run_idx, idx, request) + + return [result for result in results if result is not None] diff --git a/src/takopi/telegram/commands/file_transfer.py b/src/takopi/telegram/commands/file_transfer.py new file mode 100644 index 0000000..43266ec --- /dev/null +++ b/src/takopi/telegram/commands/file_transfer.py @@ -0,0 +1,589 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +from ...config import ConfigError +from ...context import RunContext +from ...directives import DirectiveError +from ...transport_runtime import ResolvedMessage +from ..context import _format_context +from ..files import ( + default_upload_name, + default_upload_path, + deny_reason, + format_bytes, + normalize_relative_path, + parse_file_command, + parse_file_prompt, + resolve_path_within_root, + write_bytes_atomic, + ZipTooLargeError, + zip_directory, +) +from ..topic_state import TopicStateStore +from ..topics import _maybe_update_topic_context, _topic_key +from ..types import TelegramDocument, TelegramIncomingMessage +from .reply import make_reply + +if TYPE_CHECKING: + from ..bridge import TelegramBridgeConfig + +FILE_PUT_USAGE = "usage: `/file put `" +FILE_GET_USAGE = "usage: `/file get `" + + +@dataclass(slots=True) +class _FilePutPlan: + resolved: ResolvedMessage + run_root: Path + path_value: str | None + force: bool + + +@dataclass(slots=True) +class _FilePutResult: + name: str + rel_path: Path | None + size: int | None + error: str | None + + +@dataclass(slots=True) +class _SavedFilePut: + context: RunContext | None + rel_path: Path + size: int + + +@dataclass(slots=True) +class _SavedFilePutGroup: + context: RunContext | None + base_dir: Path | None + saved: list[_FilePutResult] + failed: list[_FilePutResult] + + +def resolve_file_put_paths( + plan: _FilePutPlan, + *, + cfg: TelegramBridgeConfig, + require_dir: bool, +) -> tuple[Path | None, Path | None, str | None]: + path_value = plan.path_value + if not path_value: + return None, None, None + if require_dir or path_value.endswith("/"): + base_dir = normalize_relative_path(path_value) + if base_dir is None: + return None, None, "invalid upload path." + deny_rule = deny_reason(base_dir, cfg.files.deny_globs) + if deny_rule is not None: + return None, None, f"path denied by rule: {deny_rule}" + base_target = resolve_path_within_root(plan.run_root, base_dir) + if base_target is None: + return None, None, "upload path escapes the repo root." + if base_target.exists() and not base_target.is_dir(): + return None, None, "upload path is a file." + return base_dir, None, None + rel_path = normalize_relative_path(path_value) + if rel_path is None: + return None, None, "invalid upload path." + return None, rel_path, None + + +async def _check_file_permissions( + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage +) -> bool: + reply = make_reply(cfg, msg) + sender_id = msg.sender_id + if sender_id is None: + await reply(text="cannot verify sender for file transfer.") + return False + if cfg.files.allowed_user_ids: + if sender_id not in cfg.files.allowed_user_ids: + await reply(text="file transfer is not allowed for this user.") + return False + return True + is_private = msg.chat_type == "private" + if msg.chat_type is None: + is_private = msg.chat_id > 0 + if is_private: + return True + member = await cfg.bot.get_chat_member(msg.chat_id, sender_id) + if member is None: + await reply(text="failed to verify file transfer permissions.") + return False + if member.status in {"creator", "administrator"}: + return True + await reply(text="file transfer is restricted to group admins.") + return False + + +async def _prepare_file_put_plan( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + ambient_context: RunContext | None, + topic_store: TopicStateStore | None, +) -> _FilePutPlan | None: + reply = make_reply(cfg, msg) + if not await _check_file_permissions(cfg, msg): + return None + try: + resolved = cfg.runtime.resolve_message( + text=args_text, + reply_text=msg.reply_to_text, + ambient_context=ambient_context, + chat_id=msg.chat_id, + ) + except DirectiveError as exc: + await reply(text=f"error:\n{exc}") + return None + topic_key = _topic_key(msg, cfg) if topic_store is not None else None + await _maybe_update_topic_context( + cfg=cfg, + topic_store=topic_store, + topic_key=topic_key, + context=resolved.context, + context_source=resolved.context_source, + ) + if resolved.context is None or resolved.context.project is None: + await reply(text="no project context available for file upload.") + return None + try: + run_root = cfg.runtime.resolve_run_cwd(resolved.context) + except ConfigError as exc: + await reply(text=f"error:\n{exc}") + return None + if run_root is None: + await reply(text="no project context available for file upload.") + return None + path_value, force, error = parse_file_prompt(resolved.prompt, allow_empty=True) + if error is not None: + await reply(text=error) + return None + return _FilePutPlan( + resolved=resolved, + run_root=run_root, + path_value=path_value, + force=force, + ) + + +def _format_file_put_failures(failed: Sequence[_FilePutResult]) -> str | None: + if not failed: + return None + errors = ", ".join( + f"`{item.name}` ({item.error})" for item in failed if item.error is not None + ) + if not errors: + return None + return f"failed: {errors}" + + +async def _save_document_payload( + cfg: TelegramBridgeConfig, + *, + document: TelegramDocument, + run_root: Path, + rel_path: Path | None, + base_dir: Path | None, + force: bool, +) -> _FilePutResult: + name = default_upload_name(document.file_name, None) + if ( + document.file_size is not None + and document.file_size > cfg.files.max_upload_bytes + ): + return _FilePutResult( + name=name, + rel_path=None, + size=None, + error="file is too large to upload.", + ) + file_info = await cfg.bot.get_file(document.file_id) + if file_info is None: + return _FilePutResult( + name=name, + rel_path=None, + size=None, + error="failed to fetch file metadata.", + ) + file_path = file_info.file_path + name = default_upload_name(document.file_name, file_path) + resolved_path = rel_path + if resolved_path is None: + if base_dir is None: + resolved_path = default_upload_path( + cfg.files.uploads_dir, document.file_name, file_path + ) + else: + resolved_path = base_dir / name + deny_rule = deny_reason(resolved_path, cfg.files.deny_globs) + if deny_rule is not None: + return _FilePutResult( + name=name, + rel_path=None, + size=None, + error=f"path denied by rule: {deny_rule}", + ) + target = resolve_path_within_root(run_root, resolved_path) + if target is None: + return _FilePutResult( + name=name, + rel_path=None, + size=None, + error="upload path escapes the repo root.", + ) + if target.exists(): + if target.is_dir(): + return _FilePutResult( + name=name, + rel_path=None, + size=None, + error="upload target is a directory.", + ) + if not force: + return _FilePutResult( + name=name, + rel_path=None, + size=None, + error="file already exists; use --force to overwrite.", + ) + payload = await cfg.bot.download_file(file_path) + if payload is None: + return _FilePutResult( + name=name, + rel_path=None, + size=None, + error="failed to download file.", + ) + if len(payload) > cfg.files.max_upload_bytes: + return _FilePutResult( + name=name, + rel_path=None, + size=None, + error="file is too large to upload.", + ) + try: + write_bytes_atomic(target, payload) + except OSError as exc: + return _FilePutResult( + name=name, + rel_path=None, + size=None, + error=f"failed to write file: {exc}", + ) + return _FilePutResult( + name=name, + rel_path=resolved_path, + size=len(payload), + error=None, + ) + + +async def _handle_file_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + ambient_context: RunContext | None, + topic_store: TopicStateStore | None, +) -> None: + reply = make_reply(cfg, msg) + command, rest, error = parse_file_command(args_text) + if error is not None: + await reply(text=error) + return + if command == "put": + await _handle_file_put(cfg, msg, rest, ambient_context, topic_store) + else: + await _handle_file_get(cfg, msg, rest, ambient_context, topic_store) + + +async def _handle_file_put_default( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + ambient_context: RunContext | None, + topic_store: TopicStateStore | None, +) -> None: + await _handle_file_put(cfg, msg, "", ambient_context, topic_store) + + +async def _save_file_put( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + ambient_context: RunContext | None, + topic_store: TopicStateStore | None, +) -> _SavedFilePut | None: + reply = make_reply(cfg, msg) + document = msg.document + if document is None: + await reply(text=FILE_PUT_USAGE) + return None + plan = await _prepare_file_put_plan( + cfg, + msg, + args_text, + ambient_context, + topic_store, + ) + if plan is None: + return None + base_dir, rel_path, error = resolve_file_put_paths( + plan, + cfg=cfg, + require_dir=False, + ) + if error is not None: + await reply(text=error) + return None + result = await _save_document_payload( + cfg, + document=document, + run_root=plan.run_root, + rel_path=rel_path, + base_dir=base_dir, + force=plan.force, + ) + if result.error is not None: + await reply(text=result.error) + return None + if result.rel_path is None or result.size is None: + await reply(text="failed to save file.") + return None + return _SavedFilePut( + context=plan.resolved.context, + rel_path=result.rel_path, + size=result.size, + ) + + +async def _handle_file_put( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + ambient_context: RunContext | None, + topic_store: TopicStateStore | None, +) -> None: + reply = make_reply(cfg, msg) + saved = await _save_file_put( + cfg, + msg, + args_text, + ambient_context, + topic_store, + ) + if saved is None: + return + context_label = _format_context(cfg.runtime, saved.context) + await reply( + text=( + f"saved `{saved.rel_path.as_posix()}` " + f"in `{context_label}` ({format_bytes(saved.size)})" + ), + ) + + +async def _handle_file_put_group( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + messages: Sequence[TelegramIncomingMessage], + ambient_context: RunContext | None, + topic_store: TopicStateStore | None, +) -> None: + reply = make_reply(cfg, msg) + saved_group = await _save_file_put_group( + cfg, + msg, + args_text, + messages, + ambient_context, + topic_store, + ) + if saved_group is None: + return + context_label = _format_context(cfg.runtime, saved_group.context) + total_bytes = sum(item.size or 0 for item in saved_group.saved) + dir_label: Path | None = saved_group.base_dir + if dir_label is None and saved_group.saved: + first_path = saved_group.saved[0].rel_path + if first_path is not None: + dir_label = first_path.parent + if saved_group.saved: + saved_names = ", ".join(f"`{item.name}`" for item in saved_group.saved) + if dir_label is not None: + dir_text = dir_label.as_posix() + if not dir_text.endswith("/"): + dir_text = f"{dir_text}/" + text = ( + f"saved {saved_names} to `{dir_text}` " + f"in `{context_label}` ({format_bytes(total_bytes)})" + ) + else: + text = ( + f"saved {saved_names} in `{context_label}` " + f"({format_bytes(total_bytes)})" + ) + else: + text = "failed to upload files." + failure_text = _format_file_put_failures(saved_group.failed) + if failure_text is not None: + text = f"{text}\n\n{failure_text}" + await reply(text=text) + + +async def _save_file_put_group( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + messages: Sequence[TelegramIncomingMessage], + ambient_context: RunContext | None, + topic_store: TopicStateStore | None, +) -> _SavedFilePutGroup | None: + reply = make_reply(cfg, msg) + documents = [item.document for item in messages if item.document is not None] + if not documents: + await reply(text=FILE_PUT_USAGE) + return None + plan = await _prepare_file_put_plan( + cfg, + msg, + args_text, + ambient_context, + topic_store, + ) + if plan is None: + return None + base_dir, _, error = resolve_file_put_paths( + plan, + cfg=cfg, + require_dir=True, + ) + if error is not None: + await reply(text=error) + return None + saved: list[_FilePutResult] = [] + failed: list[_FilePutResult] = [] + for document in documents: + result = await _save_document_payload( + cfg, + document=document, + run_root=plan.run_root, + rel_path=None, + base_dir=base_dir, + force=plan.force, + ) + if result.error is None: + saved.append(result) + else: + failed.append(result) + return _SavedFilePutGroup( + context=plan.resolved.context, + base_dir=base_dir, + saved=saved, + failed=failed, + ) + + +async def _handle_file_get( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + ambient_context: RunContext | None, + topic_store: TopicStateStore | None, +) -> None: + reply = make_reply(cfg, msg) + if not await _check_file_permissions(cfg, msg): + return + try: + resolved = cfg.runtime.resolve_message( + text=args_text, + reply_text=msg.reply_to_text, + ambient_context=ambient_context, + chat_id=msg.chat_id, + ) + except DirectiveError as exc: + await reply(text=f"error:\n{exc}") + return + topic_key = _topic_key(msg, cfg) if topic_store is not None else None + await _maybe_update_topic_context( + cfg=cfg, + topic_store=topic_store, + topic_key=topic_key, + context=resolved.context, + context_source=resolved.context_source, + ) + if resolved.context is None or resolved.context.project is None: + await reply(text="no project context available for file download.") + return + try: + run_root = cfg.runtime.resolve_run_cwd(resolved.context) + except ConfigError as exc: + await reply(text=f"error:\n{exc}") + return + if run_root is None: + await reply(text="no project context available for file download.") + return + path_value = resolved.prompt + if not path_value.strip(): + await reply(text=FILE_GET_USAGE) + return + rel_path = normalize_relative_path(path_value) + if rel_path is None: + await reply(text="invalid download path.") + return + deny_rule = deny_reason(rel_path, cfg.files.deny_globs) + if deny_rule is not None: + await reply(text=f"path denied by rule: {deny_rule}") + return + target = resolve_path_within_root(run_root, rel_path) + if target is None: + await reply(text="download path escapes the repo root.") + return + if not target.exists(): + await reply(text="file does not exist.") + return + if target.is_dir(): + try: + payload = zip_directory( + run_root, + rel_path, + cfg.files.deny_globs, + max_bytes=cfg.files.max_download_bytes, + ) + except ZipTooLargeError: + await reply(text="file is too large to send.") + return + except OSError as exc: + await reply(text=f"failed to read directory: {exc}") + return + filename = f"{rel_path.name or 'archive'}.zip" + else: + try: + size = target.stat().st_size + if size > cfg.files.max_download_bytes: + await reply(text="file is too large to send.") + return + payload = target.read_bytes() + except OSError as exc: + await reply(text=f"failed to read file: {exc}") + return + filename = target.name + if len(payload) > cfg.files.max_download_bytes: + await reply(text="file is too large to send.") + return + sent = await cfg.bot.send_document( + chat_id=msg.chat_id, + filename=filename, + content=payload, + reply_to_message_id=msg.message_id, + message_thread_id=msg.thread_id, + ) + if sent is None: + await reply(text="failed to send file.") + return diff --git a/src/takopi/telegram/commands/media.py b/src/takopi/telegram/commands/media.py new file mode 100644 index 0000000..b38709e --- /dev/null +++ b/src/takopi/telegram/commands/media.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Sequence +from typing import TYPE_CHECKING + +from ...context import RunContext +from ...directives import DirectiveError +from ...transport_runtime import ResolvedMessage +from ..context import _merge_topic_context +from ..files import parse_file_command +from ..topic_state import TopicStateStore +from ..topics import _topic_key, _topics_chat_project +from ..types import TelegramIncomingMessage +from .file_transfer import ( + FILE_PUT_USAGE, + _format_file_put_failures, + _handle_file_put_group, + _save_file_put_group, +) +from .parse import _parse_slash_command +from .reply import make_reply + +if TYPE_CHECKING: + from ..bridge import TelegramBridgeConfig + + +async def _handle_media_group( + cfg: TelegramBridgeConfig, + messages: Sequence[TelegramIncomingMessage], + topic_store: TopicStateStore | None, + run_prompt: Callable[ + [TelegramIncomingMessage, str, ResolvedMessage], Awaitable[None] + ] + | None = None, + resolve_prompt: Callable[ + [TelegramIncomingMessage, str, RunContext | None], + Awaitable[ResolvedMessage | None], + ] + | None = None, +) -> None: + if not messages: + return + ordered = sorted(messages, key=lambda item: item.message_id) + command_msg = next( + (item for item in ordered if item.text.strip()), + ordered[0], + ) + reply = make_reply(cfg, command_msg) + topic_key = _topic_key(command_msg, cfg) if topic_store is not None else None + chat_project = _topics_chat_project(cfg, command_msg.chat_id) + bound_context = ( + await topic_store.get_context(*topic_key) + if topic_store is not None and topic_key is not None + else None + ) + ambient_context = _merge_topic_context( + chat_project=chat_project, bound=bound_context + ) + command_id, args_text = _parse_slash_command(command_msg.text) + if command_id == "file": + command, rest, error = parse_file_command(args_text) + if error is not None: + await reply(text=error) + return + if command == "put": + await _handle_file_put_group( + cfg, + command_msg, + rest, + ordered, + ambient_context, + topic_store, + ) + return + if cfg.files.enabled and cfg.files.auto_put: + caption_text = command_msg.text.strip() + if cfg.files.auto_put_mode == "prompt" and caption_text: + if resolve_prompt is None: + try: + resolved = cfg.runtime.resolve_message( + text=caption_text, + reply_text=command_msg.reply_to_text, + ambient_context=ambient_context, + chat_id=command_msg.chat_id, + ) + except DirectiveError as exc: + await reply(text=f"error:\n{exc}") + return + else: + resolved = await resolve_prompt( + command_msg, caption_text, ambient_context + ) + if resolved is None: + return + saved_group = await _save_file_put_group( + cfg, + command_msg, + "", + ordered, + resolved.context, + topic_store, + ) + if saved_group is None: + return + if not saved_group.saved: + failure_text = _format_file_put_failures(saved_group.failed) + text = "failed to upload files." + if failure_text is not None: + text = f"{text}\n\n{failure_text}" + await reply(text=text) + return + if saved_group.failed: + failure_text = _format_file_put_failures(saved_group.failed) + if failure_text is not None: + await reply(text=f"some files failed to upload.\n\n{failure_text}") + if run_prompt is None: + await reply(text=FILE_PUT_USAGE) + return + paths = [ + item.rel_path.as_posix() + for item in saved_group.saved + if item.rel_path is not None + ] + files_text = "\n".join(f"- {path}" for path in paths) + prompt_base = resolved.prompt + annotation = f"[uploaded files]\n{files_text}" + if prompt_base and prompt_base.strip(): + prompt = f"{prompt_base}\n\n{annotation}" + else: + prompt = annotation + await run_prompt(command_msg, prompt, resolved) + return + if not caption_text: + await _handle_file_put_group( + cfg, + command_msg, + "", + ordered, + ambient_context, + topic_store, + ) + return + await reply(text=FILE_PUT_USAGE) diff --git a/src/takopi/telegram/commands/menu.py b/src/takopi/telegram/commands/menu.py new file mode 100644 index 0000000..1cbfbaa --- /dev/null +++ b/src/takopi/telegram/commands/menu.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ...commands import get_command +from ...config import ConfigError +from ...ids import RESERVED_COMMAND_IDS, is_valid_id +from ...logging import get_logger +from ...plugins import COMMAND_GROUP, list_entrypoints +from ...transport_runtime import TransportRuntime + +if TYPE_CHECKING: + from ..bridge import TelegramBridgeConfig + +logger = get_logger(__name__) + +_MAX_BOT_COMMANDS = 100 + + +def build_bot_commands( + runtime: TransportRuntime, *, include_file: bool = True +) -> list[dict[str, str]]: + commands: list[dict[str, str]] = [] + seen: set[str] = set() + for engine_id in runtime.available_engine_ids(): + cmd = engine_id.lower() + if cmd in seen: + continue + commands.append({"command": cmd, "description": f"use agent: {cmd}"}) + seen.add(cmd) + for alias in runtime.project_aliases(): + cmd = alias.lower() + if cmd in seen: + continue + if not is_valid_id(cmd): + logger.debug( + "startup.command_menu.skip_project", + alias=alias, + ) + continue + commands.append({"command": cmd, "description": f"work on: {cmd}"}) + seen.add(cmd) + allowlist = runtime.allowlist + for ep in list_entrypoints( + COMMAND_GROUP, + allowlist=allowlist, + reserved_ids=RESERVED_COMMAND_IDS, + ): + try: + backend = get_command(ep.name, allowlist=allowlist) + except ConfigError as exc: + logger.info( + "startup.command_menu.skip_command", + command=ep.name, + error=str(exc), + ) + continue + cmd = backend.id.lower() + if cmd in seen: + continue + if not is_valid_id(cmd): + logger.debug( + "startup.command_menu.skip_command_id", + command=cmd, + ) + continue + description = backend.description or f"command: {cmd}" + commands.append({"command": cmd, "description": description}) + seen.add(cmd) + if include_file and "file" not in seen: + commands.append({"command": "file", "description": "upload or fetch files"}) + seen.add("file") + if "cancel" not in seen: + commands.append({"command": "cancel", "description": "cancel run"}) + if len(commands) > _MAX_BOT_COMMANDS: + logger.warning( + "startup.command_menu.too_many", + count=len(commands), + limit=_MAX_BOT_COMMANDS, + ) + commands = commands[:_MAX_BOT_COMMANDS] + if not any(cmd["command"] == "cancel" for cmd in commands): + commands[-1] = {"command": "cancel", "description": "cancel run"} + return commands + + +def _reserved_commands(runtime: TransportRuntime) -> set[str]: + return { + *{engine.lower() for engine in runtime.engine_ids}, + *{alias.lower() for alias in runtime.project_aliases()}, + *RESERVED_COMMAND_IDS, + } + + +async def _set_command_menu(cfg: TelegramBridgeConfig) -> None: + commands = build_bot_commands(cfg.runtime, include_file=cfg.files.enabled) + if not commands: + return + try: + ok = await cfg.bot.set_my_commands(commands) + except Exception as exc: # noqa: BLE001 + logger.info( + "startup.command_menu.failed", + error=str(exc), + error_type=exc.__class__.__name__, + ) + return + if not ok: + logger.info("startup.command_menu.rejected") + return + logger.info( + "startup.command_menu.updated", + commands=[cmd["command"] for cmd in commands], + ) diff --git a/src/takopi/telegram/commands/parse.py b/src/takopi/telegram/commands/parse.py new file mode 100644 index 0000000..77060c8 --- /dev/null +++ b/src/takopi/telegram/commands/parse.py @@ -0,0 +1,30 @@ +from __future__ import annotations + + +def is_cancel_command(text: str) -> bool: + stripped = text.strip() + if not stripped: + return False + command = stripped.split(maxsplit=1)[0] + return command == "/cancel" or command.startswith("/cancel@") + + +def _parse_slash_command(text: str) -> tuple[str | None, str]: + stripped = text.lstrip() + if not stripped.startswith("/"): + return None, text + lines = stripped.splitlines() + if not lines: + return None, text + first_line = lines[0] + token, _, rest = first_line.partition(" ") + command = token[1:] + if not command: + return None, text + if "@" in command: + command = command.split("@", 1)[0] + args_text = rest + if len(lines) > 1: + tail = "\n".join(lines[1:]) + args_text = f"{args_text}\n{tail}" if args_text else tail + return command.lower(), args_text diff --git a/src/takopi/telegram/commands/reply.py b/src/takopi/telegram/commands/reply.py new file mode 100644 index 0000000..4916149 --- /dev/null +++ b/src/takopi/telegram/commands/reply.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from functools import partial +from typing import TYPE_CHECKING + +from ..bridge import send_plain +from ..types import TelegramIncomingMessage + +if TYPE_CHECKING: + from ..bridge import TelegramBridgeConfig + + +def make_reply( + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage +) -> Callable[..., Awaitable[None]]: + return partial( + send_plain, + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + thread_id=msg.thread_id, + ) diff --git a/src/takopi/telegram/commands/topics.py b/src/takopi/telegram/commands/topics.py new file mode 100644 index 0000000..3d2f399 --- /dev/null +++ b/src/takopi/telegram/commands/topics.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ...markdown import MarkdownParts +from ...transport import RenderedMessage, SendOptions +from ..chat_sessions import ChatSessionStore +from ..context import ( + _format_context, + _format_ctx_status, + _merge_topic_context, + _parse_project_branch_args, + _usage_ctx_set, + _usage_topic, +) +from ..files import split_command_args +from ..render import prepare_telegram +from ..topic_state import TopicStateStore +from ..topics import ( + _maybe_rename_topic, + _topic_key, + _topic_title, + _topics_chat_project, + _topics_command_error, +) +from ..types import TelegramIncomingMessage +from .reply import make_reply + +if TYPE_CHECKING: + from ..bridge import TelegramBridgeConfig + + +async def _handle_ctx_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + store: TopicStateStore, + *, + resolved_scope: str | None = None, + scope_chat_ids: frozenset[int] | None = None, +) -> None: + reply = make_reply(cfg, msg) + error = _topics_command_error( + cfg, + msg.chat_id, + resolved_scope=resolved_scope, + scope_chat_ids=scope_chat_ids, + ) + if error is not None: + await reply(text=error) + return + chat_project = _topics_chat_project(cfg, msg.chat_id) + tkey = _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids) + if tkey is None: + await reply(text="this command only works inside a topic.") + return + tokens = split_command_args(args_text) + action = tokens[0].lower() if tokens else "show" + if action in {"show", ""}: + snapshot = await store.get_thread(*tkey) + bound = snapshot.context if snapshot is not None else None + ambient = _merge_topic_context(chat_project=chat_project, bound=bound) + resolved = cfg.runtime.resolve_message( + text="", + reply_text=msg.reply_to_text, + chat_id=msg.chat_id, + ambient_context=ambient, + ) + text = _format_ctx_status( + cfg=cfg, + runtime=cfg.runtime, + bound=bound, + resolved=resolved.context, + context_source=resolved.context_source, + snapshot=snapshot, + chat_project=chat_project, + ) + await reply(text=text) + return + if action == "set": + rest = " ".join(tokens[1:]) + context, error = _parse_project_branch_args( + rest, + runtime=cfg.runtime, + require_branch=False, + chat_project=chat_project, + ) + if error is not None: + await reply( + text=f"error:\n{error}\n{_usage_ctx_set(chat_project=chat_project)}", + ) + return + if context is None: + await reply( + text=f"error:\n{_usage_ctx_set(chat_project=chat_project)}", + ) + return + await store.set_context(*tkey, context) + await _maybe_rename_topic( + cfg, + store, + chat_id=tkey[0], + thread_id=tkey[1], + context=context, + ) + await reply( + text=f"topic bound to `{_format_context(cfg.runtime, context)}`", + ) + return + if action == "clear": + await store.clear_context(*tkey) + await reply(text="topic binding cleared.") + return + await reply( + text="unknown `/ctx` command. use `/ctx`, `/ctx set`, or `/ctx clear`.", + ) + + +async def _handle_new_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + store: TopicStateStore, + *, + resolved_scope: str | None = None, + scope_chat_ids: frozenset[int] | None = None, +) -> None: + reply = make_reply(cfg, msg) + error = _topics_command_error( + cfg, + msg.chat_id, + resolved_scope=resolved_scope, + scope_chat_ids=scope_chat_ids, + ) + if error is not None: + await reply(text=error) + return + tkey = _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids) + if tkey is None: + await reply(text="this command only works inside a topic.") + return + await store.clear_sessions(*tkey) + await reply(text="cleared stored sessions for this topic.") + + +async def _handle_chat_new_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + store: ChatSessionStore, + session_key: tuple[int, int | None] | None, +) -> None: + reply = make_reply(cfg, msg) + if session_key is None: + await reply(text="no stored sessions to clear for this chat.") + return + await store.clear_sessions(session_key[0], session_key[1]) + if msg.chat_type == "private": + text = "cleared stored sessions for this chat." + else: + text = "cleared stored sessions for you in this chat." + await reply(text=text) + + +async def _handle_topic_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + store: TopicStateStore, + *, + resolved_scope: str | None = None, + scope_chat_ids: frozenset[int] | None = None, +) -> None: + reply = make_reply(cfg, msg) + error = _topics_command_error( + cfg, + msg.chat_id, + resolved_scope=resolved_scope, + scope_chat_ids=scope_chat_ids, + ) + if error is not None: + await reply(text=error) + return + chat_project = _topics_chat_project(cfg, msg.chat_id) + context, error = _parse_project_branch_args( + args_text, + runtime=cfg.runtime, + require_branch=True, + chat_project=chat_project, + ) + if error is not None or context is None: + usage = _usage_topic(chat_project=chat_project) + text = f"error:\n{error}\n{usage}" if error else usage + await reply(text=text) + return + existing = await store.find_thread_for_context(msg.chat_id, context) + if existing is not None: + await reply( + text=f"topic already exists for {_format_context(cfg.runtime, context)} " + "in this chat.", + ) + return + title = _topic_title(runtime=cfg.runtime, context=context) + created = await cfg.bot.create_forum_topic(msg.chat_id, title) + if created is None: + await reply(text="failed to create topic.") + return + thread_id = created.message_thread_id + await store.set_context( + msg.chat_id, + thread_id, + context, + topic_title=title, + ) + await reply(text=f"created topic `{title}`.") + bound_text = f"topic bound to `{_format_context(cfg.runtime, context)}`" + rendered_text, entities = prepare_telegram(MarkdownParts(header=bound_text)) + await cfg.exec_cfg.transport.send( + channel_id=msg.chat_id, + message=RenderedMessage(text=rendered_text, extra={"entities": entities}), + options=SendOptions(thread_id=thread_id), + ) diff --git a/src/takopi/telegram/loop.py b/src/takopi/telegram/loop.py index 624ee42..2040495 100644 --- a/src/takopi/telegram/loop.py +++ b/src/takopi/telegram/loop.py @@ -20,26 +20,25 @@ from ..transport import MessageRef from ..transport_runtime import ResolvedMessage from ..context import RunContext from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain -from .commands import ( +from .commands.agent import _handle_agent_command +from .commands.cancel import handle_callback_cancel, handle_cancel +from .commands.dispatch import _dispatch_command +from .commands.executor import _run_engine, _should_show_resume_line +from .commands.file_transfer import ( FILE_PUT_USAGE, - _dispatch_command, - _handle_agent_command, - _handle_chat_new_command, - _handle_ctx_command, _handle_file_command, _handle_file_put_default, - _handle_media_group, + _save_file_put, +) +from .commands.media import _handle_media_group +from .commands.menu import _reserved_commands, _set_command_menu +from .commands.parse import _parse_slash_command, is_cancel_command +from .commands.reply import make_reply +from .commands.topics import ( + _handle_chat_new_command, + _handle_ctx_command, _handle_new_command, _handle_topic_command, - _parse_slash_command, - _reserved_commands, - _save_file_put, - _should_show_resume_line, - _run_engine, - _set_command_menu, - handle_callback_cancel, - handle_cancel, - is_cancel_command, ) from .context import _merge_topic_context, _usage_ctx_set, _usage_topic from .topics import ( @@ -519,13 +518,7 @@ async def run_main_loop( text: str, ambient_context: RunContext | None, ) -> ResolvedMessage | None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = make_reply(cfg, msg) try: resolved = cfg.runtime.resolve_message( text=text, @@ -757,13 +750,7 @@ async def run_main_loop( if reply_id is not None else None ) - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=chat_id, - user_msg_id=user_msg_id, - thread_id=msg.thread_id, - ) + reply = make_reply(cfg, msg) text = msg.text if msg.voice is not None: text = await transcribe_voice( diff --git a/src/takopi/telegram/onboarding.py b/src/takopi/telegram/onboarding.py index 9323354..9525b7a 100644 --- a/src/takopi/telegram/onboarding.py +++ b/src/takopi/telegram/onboarding.py @@ -286,8 +286,8 @@ def _confirm(message: str, *, default: bool = True) -> bool | None: exit_with_result(event) @bindings.add(Keys.Any) - def other(event): - _ = event + def other(_event): + return None question = Question( PromptSession(get_prompt_tokens, key_bindings=bindings, style=merged_style).app diff --git a/src/takopi/telegram/outbox.py b/src/takopi/telegram/outbox.py new file mode 100644 index 0000000..3c31aab --- /dev/null +++ b/src/takopi/telegram/outbox.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any, TYPE_CHECKING +from collections.abc import Awaitable, Callable, Hashable + +import anyio + +from .client_api import RetryAfter + +SEND_PRIORITY = 0 +DELETE_PRIORITY = 1 +EDIT_PRIORITY = 2 + + +@dataclass(slots=True) +class OutboxOp: + execute: Callable[[], Awaitable[Any]] + priority: int + queued_at: float + chat_id: int | None + label: str | None = None + done: anyio.Event = field(default_factory=anyio.Event) + result: Any = None + + def set_result(self, result: Any) -> None: + if self.done.is_set(): + return + self.result = result + self.done.set() + + +class TelegramOutbox: + def __init__( + self, + *, + interval_for_chat: Callable[[int | None], float], + clock: Callable[[], float] = time.monotonic, + sleep: Callable[[float], Awaitable[None]] = anyio.sleep, + on_error: Callable[[OutboxOp, Exception], None] | None = None, + on_outbox_error: Callable[[Exception], None] | None = None, + ) -> None: + self._interval_for_chat = interval_for_chat + self._clock = clock + self._sleep = sleep + self._on_error = on_error + self._on_outbox_error = on_outbox_error + self._pending: dict[Hashable, OutboxOp] = {} + self._cond = anyio.Condition() + self._start_lock = anyio.Lock() + self._closed = False + self._tg: TaskGroup | None = None + self.next_at = 0.0 + self.retry_at = 0.0 + + async def ensure_worker(self) -> None: + async with self._start_lock: + if self._tg is not None or self._closed: + return + self._tg = await anyio.create_task_group().__aenter__() + self._tg.start_soon(self.run) + + async def enqueue(self, *, key: Hashable, op: OutboxOp, wait: bool = True) -> Any: + await self.ensure_worker() + async with self._cond: + if self._closed: + op.set_result(None) + return op.result + previous = self._pending.get(key) + if previous is not None: + op.queued_at = previous.queued_at + previous.set_result(None) + self._pending[key] = op + self._cond.notify() + if not wait: + return None + await op.done.wait() + return op.result + + async def drop_pending(self, *, key: Hashable) -> None: + async with self._cond: + pending = self._pending.pop(key, None) + if pending is not None: + pending.set_result(None) + self._cond.notify() + + async def close(self) -> None: + async with self._cond: + self._closed = True + self.fail_pending() + self._cond.notify_all() + if self._tg is not None: + await self._tg.__aexit__(None, None, None) + self._tg = None + + def fail_pending(self) -> None: + for pending in list(self._pending.values()): + pending.set_result(None) + self._pending.clear() + + def pick_locked(self) -> tuple[Hashable, OutboxOp] | None: + if not self._pending: + return None + return min( + self._pending.items(), + key=lambda item: (item[1].priority, item[1].queued_at), + ) + + async def execute_op(self, op: OutboxOp) -> Any: + try: + return await op.execute() + except Exception as exc: # noqa: BLE001 + if isinstance(exc, RetryAfter): + raise + if self._on_error is not None: + self._on_error(op, exc) + return None + + async def sleep_until(self, deadline: float) -> None: + delay = deadline - self._clock() + if delay > 0: + await self._sleep(delay) + + async def run(self) -> None: + cancel_exc = anyio.get_cancelled_exc_class() + try: + while True: + async with self._cond: + while not self._pending and not self._closed: + await self._cond.wait() + if self._closed and not self._pending: + return + blocked_until = max(self.next_at, self.retry_at) + if self._clock() < blocked_until: + await self.sleep_until(blocked_until) + continue + async with self._cond: + if self._closed and not self._pending: + return + picked = self.pick_locked() + if picked is None: + continue + key, op = picked + self._pending.pop(key, None) + started_at = self._clock() + try: + result = await self.execute_op(op) + except RetryAfter as exc: + self.retry_at = max(self.retry_at, self._clock() + exc.retry_after) + async with self._cond: + if self._closed: + op.set_result(None) + elif key not in self._pending: + self._pending[key] = op + self._cond.notify() + else: + op.set_result(None) + continue + self.next_at = started_at + self._interval_for_chat(op.chat_id) + op.set_result(result) + except cancel_exc: + return + except Exception as exc: # noqa: BLE001 + async with self._cond: + self._closed = True + self.fail_pending() + self._cond.notify_all() + if self._on_outbox_error is not None: + self._on_outbox_error(exc) + return + + +if TYPE_CHECKING: + from anyio.abc import TaskGroup +else: + TaskGroup = object diff --git a/src/takopi/telegram/parsing.py b/src/takopi/telegram/parsing.py new file mode 100644 index 0000000..821a634 --- /dev/null +++ b/src/takopi/telegram/parsing.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +from typing import Any +from collections.abc import AsyncIterator, Callable, Iterable + +import anyio + +from ..logging import get_logger +from .api_models import Update +from .client_api import BotClient +from .types import ( + TelegramCallbackQuery, + TelegramDocument, + TelegramIncomingMessage, + TelegramIncomingUpdate, + TelegramVoice, +) + +logger = get_logger(__name__) + + +def parse_incoming_update( + update: Update | dict[str, Any], + *, + chat_id: int | None = None, + chat_ids: set[int] | None = None, +) -> TelegramIncomingUpdate | None: + if isinstance(update, Update): + msg = update.message + callback_query = update.callback_query + else: + msg = update.get("message") + callback_query = update.get("callback_query") + + if isinstance(msg, dict): + return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids) + if isinstance(callback_query, dict): + return _parse_callback_query( + callback_query, + chat_id=chat_id, + chat_ids=chat_ids, + ) + return None + + +def _parse_incoming_message( + msg: dict[str, Any], + *, + chat_id: int | None = None, + chat_ids: set[int] | None = None, +) -> TelegramIncomingMessage | None: + def _parse_document_payload(payload: dict[str, Any]) -> TelegramDocument | None: + file_id = payload.get("file_id") + if not isinstance(file_id, str) or not file_id: + return None + return TelegramDocument( + file_id=file_id, + file_name=payload.get("file_name") + if isinstance(payload.get("file_name"), str) + else None, + mime_type=payload.get("mime_type") + if isinstance(payload.get("mime_type"), str) + else None, + file_size=payload.get("file_size") + if isinstance(payload.get("file_size"), int) + and not isinstance(payload.get("file_size"), bool) + else None, + raw=payload, + ) + + raw_text = msg.get("text") + text = raw_text if isinstance(raw_text, str) else None + caption = msg.get("caption") + if text is None and isinstance(caption, str): + text = caption + if text is None: + text = "" + file_command = False + if isinstance(text, str): + stripped = text.lstrip() + if stripped.startswith("/"): + token = stripped.split(maxsplit=1)[0] + file_command = token.startswith("/file") + voice_payload: TelegramVoice | None = None + voice = msg.get("voice") + if isinstance(voice, dict): + file_id = voice.get("file_id") + if not isinstance(file_id, str) or not file_id: + file_id = None + if file_id is not None: + voice_payload = TelegramVoice( + file_id=file_id, + mime_type=voice.get("mime_type") + if isinstance(voice.get("mime_type"), str) + else None, + file_size=voice.get("file_size") + if isinstance(voice.get("file_size"), int) + and not isinstance(voice.get("file_size"), bool) + else None, + duration=voice.get("duration") + if isinstance(voice.get("duration"), int) + and not isinstance(voice.get("duration"), bool) + else None, + raw=voice, + ) + if not isinstance(raw_text, str) and not isinstance(caption, str): + text = "" + document_payload: TelegramDocument | None = None + document = msg.get("document") + if isinstance(document, dict): + document_payload = _parse_document_payload(document) + if document_payload is None: + video = msg.get("video") + if isinstance(video, dict): + document_payload = _parse_document_payload(video) + if document_payload is None: + photo = msg.get("photo") + if isinstance(photo, list): + best: dict[str, Any] | None = None + best_score = -1 + for item in photo: + if not isinstance(item, dict): + continue + file_id = item.get("file_id") + if not isinstance(file_id, str) or not file_id: + continue + size = item.get("file_size") + if isinstance(size, int) and not isinstance(size, bool): + score = size + else: + width = item.get("width") + height = item.get("height") + if isinstance(width, int) and isinstance(height, int): + score = width * height + else: + score = 0 + if score > best_score: + best_score = score + best = item + if best is not None: + document_payload = _parse_document_payload(best) + if document_payload is None and file_command: + sticker = msg.get("sticker") + if isinstance(sticker, dict): + document_payload = _parse_document_payload(sticker) + has_text = isinstance(raw_text, str) or isinstance(caption, str) + if not has_text and voice_payload is None and document_payload is None: + return None + chat = msg.get("chat") + if not isinstance(chat, dict): + return None + msg_chat_id = chat.get("id") + if not isinstance(msg_chat_id, int): + return None + chat_type = chat.get("type") if isinstance(chat.get("type"), str) else None + is_forum = chat.get("is_forum") + if not isinstance(is_forum, bool): + is_forum = None + allowed = chat_ids + if allowed is None and chat_id is not None: + allowed = {chat_id} + if allowed is not None and msg_chat_id not in allowed: + return None + message_id = msg.get("message_id") + if not isinstance(message_id, int): + return None + reply = msg.get("reply_to_message") + reply_to_message_id = None + reply_to_text = None + if isinstance(reply, dict): + reply_to_message_id = ( + reply.get("message_id") + if isinstance(reply.get("message_id"), int) + else None + ) + reply_to_text = ( + reply.get("text") if isinstance(reply.get("text"), str) else None + ) + sender = msg.get("from") + sender_id = ( + sender.get("id") + if isinstance(sender, dict) and isinstance(sender.get("id"), int) + else None + ) + media_group_id = msg.get("media_group_id") + if not isinstance(media_group_id, str): + media_group_id = None + thread_id = msg.get("message_thread_id") + if isinstance(thread_id, bool) or not isinstance(thread_id, int): + thread_id = None + is_topic_message = msg.get("is_topic_message") + if not isinstance(is_topic_message, bool): + is_topic_message = None + return TelegramIncomingMessage( + transport="telegram", + chat_id=msg_chat_id, + message_id=message_id, + text=text, + reply_to_message_id=reply_to_message_id, + reply_to_text=reply_to_text, + sender_id=sender_id, + media_group_id=media_group_id, + thread_id=thread_id, + is_topic_message=is_topic_message, + chat_type=chat_type, + is_forum=is_forum, + voice=voice_payload, + document=document_payload, + raw=msg, + ) + + +def _parse_callback_query( + query: dict[str, Any], + *, + chat_id: int | None = None, + chat_ids: set[int] | None = None, +) -> TelegramCallbackQuery | None: + callback_id = query.get("id") + if not isinstance(callback_id, str) or not callback_id: + return None + msg = query.get("message") + if not isinstance(msg, dict): + return None + chat = msg.get("chat") + if not isinstance(chat, dict): + return None + msg_chat_id = chat.get("id") + if not isinstance(msg_chat_id, int): + return None + allowed = chat_ids + if allowed is None and chat_id is not None: + allowed = {chat_id} + if allowed is not None and msg_chat_id not in allowed: + return None + message_id = msg.get("message_id") + if not isinstance(message_id, int): + return None + data = query.get("data") if isinstance(query.get("data"), str) else None + sender = query.get("from") + sender_id = ( + sender.get("id") + if isinstance(sender, dict) and isinstance(sender.get("id"), int) + else None + ) + return TelegramCallbackQuery( + transport="telegram", + chat_id=msg_chat_id, + message_id=message_id, + callback_query_id=callback_id, + data=data, + sender_id=sender_id, + raw=query, + ) + + +async def poll_incoming( + bot: BotClient, + *, + chat_id: int | None = None, + chat_ids: Iterable[int] | Callable[[], Iterable[int]] | None = None, + offset: int | None = None, +) -> AsyncIterator[TelegramIncomingUpdate]: + while True: + updates = await bot.get_updates( + offset=offset, + timeout_s=50, + allowed_updates=["message", "callback_query"], + ) + if updates is None: + logger.info("loop.get_updates.failed") + await anyio.sleep(2) + continue + logger.debug("loop.updates", updates=updates) + resolved_chat_ids = chat_ids() if callable(chat_ids) else chat_ids + allowed = set(resolved_chat_ids) if resolved_chat_ids is not None else None + if allowed is None and chat_id is not None: + allowed = {chat_id} + for upd in updates: + offset = upd.update_id + 1 + msg = parse_incoming_update(upd, chat_ids=allowed) + if msg is not None: + yield msg diff --git a/src/takopi/telegram/state_store.py b/src/takopi/telegram/state_store.py index 35582dd..8d5c504 100644 --- a/src/takopi/telegram/state_store.py +++ b/src/takopi/telegram/state_store.py @@ -1,14 +1,13 @@ from __future__ import annotations -import json -import os +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, Generic, Protocol, TypeVar +from typing import Any, Protocol import anyio import msgspec -T = TypeVar("T", bound="_VersionedState") +from ..utils.json_state import atomic_write_json class _Logger(Protocol): @@ -19,7 +18,7 @@ class _VersionedState(Protocol): version: int -class JsonStateStore(Generic[T]): +class JsonStateStore[T: _VersionedState]: def __init__( self, path: Path, @@ -84,11 +83,6 @@ class JsonStateStore(Generic[T]): self._state = payload def _save_locked(self) -> None: - self._path.parent.mkdir(parents=True, exist_ok=True) payload = msgspec.to_builtins(self._state) - tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp") - with open(tmp_path, "w", encoding="utf-8") as handle: - json.dump(payload, handle, indent=2, sort_keys=True) - handle.write("\n") - os.replace(tmp_path, self._path) + atomic_write_json(self._path, payload) self._mtime_ns = self._stat_mtime_ns() diff --git a/src/takopi/transport.py b/src/takopi/transport.py index f07b1ef..aba3949 100644 --- a/src/takopi/transport.py +++ b/src/takopi/transport.py @@ -1,11 +1,11 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Protocol, TypeAlias +from typing import Any, Protocol -ChannelId: TypeAlias = int | str -MessageId: TypeAlias = int | str -ThreadId: TypeAlias = int | str +type ChannelId = int | str +type MessageId = int | str +type ThreadId = int | str @dataclass(frozen=True, slots=True) diff --git a/src/takopi/transport_runtime.py b/src/takopi/transport_runtime.py index af1abf7..9200d70 100644 --- a/src/takopi/transport_runtime.py +++ b/src/takopi/transport_runtime.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Iterable, Mapping from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal, TypeAlias +from typing import Any, Literal from .config import ConfigError, ProjectsConfig from .context import RunContext @@ -19,7 +19,7 @@ from .router import AutoRouter, EngineStatus from .runner import Runner from .worktrees import WorktreeError, resolve_run_cwd -ContextSource: TypeAlias = Literal[ +type ContextSource = Literal[ "reply_ctx", "directives", "ambient", @@ -234,13 +234,13 @@ class TransportRuntime: project_key = ambient_context.project else: project_key = default_project - if branch is None: - if ( - ambient_context is not None - and ambient_context.branch is not None - and project_key == ambient_context.project - ): - branch = ambient_context.branch + if ( + branch is None + and ambient_context is not None + and ambient_context.branch is not None + and project_key == ambient_context.project + ): + branch = ambient_context.branch context: RunContext | None = None if project_key is not None or branch is not None: context = RunContext(project=project_key, branch=branch) diff --git a/src/takopi/transports.py b/src/takopi/transports.py index 21e625f..3c9268c 100644 --- a/src/takopi/transports.py +++ b/src/takopi/transports.py @@ -2,7 +2,8 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path -from typing import Iterable, Protocol, runtime_checkable +from typing import Protocol, runtime_checkable +from collections.abc import Iterable from .backends import EngineBackend, SetupIssue from .plugins import TRANSPORT_GROUP, list_ids, load_plugin_backend @@ -34,7 +35,7 @@ class TransportBackend(Protocol): def interactive_setup(self, *, force: bool) -> bool: ... def lock_token( - self, *, transport_config: object, config_path: Path + self, *, transport_config: object, _config_path: Path ) -> str | None: ... def build_and_run( diff --git a/src/takopi/utils/json_state.py b/src/takopi/utils/json_state.py new file mode 100644 index 0000000..76e6414 --- /dev/null +++ b/src/takopi/utils/json_state.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any + + +def atomic_write_json( + path: Path, + payload: Any, + *, + indent: int = 2, + sort_keys: bool = True, +) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(f"{path.suffix}.tmp") + with open(tmp_path, "w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=indent, sort_keys=sort_keys) + handle.write("\n") + os.replace(tmp_path, path) diff --git a/tests/factories.py b/tests/factories.py index 7f190c6..ca45636 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -14,7 +14,7 @@ from takopi.model import ( def session_started(engine: str, value: str, title: str = "Codex") -> TakopiEvent: - engine_id = EngineId(engine) + engine_id: EngineId = engine return StartedEvent( engine=engine_id, resume=ResumeToken(engine=engine_id, value=value), @@ -29,7 +29,7 @@ def action_started( detail: dict[str, Any] | None = None, engine: str = "codex", ) -> TakopiEvent: - engine_id = EngineId(engine) + engine_id: EngineId = engine return ActionEvent( engine=engine_id, action=Action( @@ -50,7 +50,7 @@ def action_completed( detail: dict[str, Any] | None = None, engine: str = "codex", ) -> TakopiEvent: - engine_id = EngineId(engine) + engine_id: EngineId = engine return ActionEvent( engine=engine_id, action=Action( diff --git a/tests/plugin_fixtures.py b/tests/plugin_fixtures.py index ded3273..e179e19 100644 --- a/tests/plugin_fixtures.py +++ b/tests/plugin_fixtures.py @@ -1,7 +1,8 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, Iterable +from typing import Any +from collections.abc import Callable, Iterable @dataclass(frozen=True, slots=True) diff --git a/tests/test_exec_bridge.py b/tests/test_exec_bridge.py index 1a3ed5a..6a89f3c 100644 --- a/tests/test_exec_bridge.py +++ b/tests/test_exec_bridge.py @@ -5,7 +5,7 @@ import pytest from takopi.runner_bridge import ExecBridgeConfig, IncomingMessage, handle_message from takopi.markdown import MarkdownParts, MarkdownPresenter -from takopi.model import EngineId, ResumeToken, TakopiEvent +from takopi.model import ResumeToken, TakopiEvent from takopi.telegram.render import prepare_telegram from takopi.runners.codex import CodexRunner from takopi.runners.mock import Advance, Emit, Raise, Return, ScriptRunner, Wait @@ -13,7 +13,7 @@ from takopi.settings import load_settings, require_telegram from takopi.transport import MessageRef, RenderedMessage, SendOptions from tests.factories import action_completed, action_started -CODEX_ENGINE = EngineId("codex") +CODEX_ENGINE = "codex" class _FakeTransport: diff --git a/tests/test_exec_runner.py b/tests/test_exec_runner.py index 8a696c5..bff08ae 100644 --- a/tests/test_exec_runner.py +++ b/tests/test_exec_runner.py @@ -7,14 +7,13 @@ from collections.abc import AsyncIterator from takopi.model import ( ActionEvent, CompletedEvent, - EngineId, ResumeToken, StartedEvent, TakopiEvent, ) from takopi.runners.codex import CodexRunner, find_exec_only_flag -CODEX_ENGINE = EngineId("codex") +CODEX_ENGINE = "codex" @pytest.mark.anyio diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py index 06834bd..70f9024 100644 --- a/tests/test_git_utils.py +++ b/tests/test_git_utils.py @@ -45,11 +45,10 @@ def test_resolve_default_base_prefers_master_over_main(monkeypatch) -> None: return None def _fake_ok(args, **kwargs): - if args == ["show-ref", "--verify", "--quiet", "refs/heads/master"]: - return True - if args == ["show-ref", "--verify", "--quiet", "refs/heads/main"]: - return True - return False + return args in ( + ["show-ref", "--verify", "--quiet", "refs/heads/master"], + ["show-ref", "--verify", "--quiet", "refs/heads/main"], + ) monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout) monkeypatch.setattr("takopi.utils.git.git_ok", _fake_ok) diff --git a/tests/test_runner_contract.py b/tests/test_runner_contract.py index 9f2cb64..e0ab958 100644 --- a/tests/test_runner_contract.py +++ b/tests/test_runner_contract.py @@ -7,7 +7,6 @@ from takopi.model import ( Action, ActionEvent, CompletedEvent, - EngineId, ResumeToken, StartedEvent, TakopiEvent, @@ -15,7 +14,7 @@ from takopi.model import ( from takopi.runners.mock import Emit, Return, ScriptRunner, Wait from tests.factories import action_started -CODEX_ENGINE = EngineId("codex") +CODEX_ENGINE = "codex" @pytest.mark.anyio @@ -84,7 +83,7 @@ async def test_runner_releases_lock_when_consumer_closes() -> None: gate = anyio.Event() runner = ScriptRunner([Wait(gate)], engine=CODEX_ENGINE, resume_value="sid") - gen = cast(AsyncGenerator[TakopiEvent, None], runner.run("hello", None)) + gen = cast(AsyncGenerator[TakopiEvent], runner.run("hello", None)) try: while True: evt = await anext(gen) @@ -94,7 +93,7 @@ async def test_runner_releases_lock_when_consumer_closes() -> None: await gen.aclose() gen2 = cast( - AsyncGenerator[TakopiEvent, None], + AsyncGenerator[TakopiEvent], runner.run("again", ResumeToken(engine=CODEX_ENGINE, value="sid")), ) try: diff --git a/tests/test_runner_utils.py b/tests/test_runner_utils.py index f03d01b..dde0aad 100644 --- a/tests/test_runner_utils.py +++ b/tests/test_runner_utils.py @@ -8,7 +8,6 @@ import takopi.runner as runner_module from takopi.model import ( ActionEvent, CompletedEvent, - EngineId, ResumeToken, StartedEvent, TakopiEvent, @@ -22,7 +21,7 @@ from takopi.runner import ( class _DummyRunner(ResumeTokenMixin, BaseRunner): - engine = EngineId("dummy") + engine = "dummy" resume_re = re.compile(r"(?im)^`?dummy resume (?P[^`\s]+)`?$") async def run_impl( @@ -39,7 +38,7 @@ class _DummyRunner(ResumeTokenMixin, BaseRunner): class _DummyJsonlRunner(JsonlSubprocessRunner): - engine = EngineId("dummy-jsonl") + engine = "dummy-jsonl" def command(self) -> str: return "dummy" @@ -67,7 +66,7 @@ class _DummyJsonlRunner(JsonlSubprocessRunner): class _BareJsonlRunner(JsonlSubprocessRunner): - engine = EngineId("bare-jsonl") + engine = "bare-jsonl" class _RunJsonlRunner(_DummyJsonlRunner): @@ -177,7 +176,7 @@ async def test_base_runner_run_locked_handles_resume() -> None: @pytest.mark.anyio async def test_base_runner_rejects_wrong_resume_engine() -> None: runner = _DummyRunner() - bad_resume = ResumeToken(engine=EngineId("other"), value="oops") + bad_resume = ResumeToken(engine="other", value="oops") with pytest.raises(RuntimeError): _ = [evt async for evt in runner.run("hello", bad_resume)] @@ -185,7 +184,7 @@ async def test_base_runner_rejects_wrong_resume_engine() -> None: @pytest.mark.anyio async def test_base_runner_run_impl_not_implemented() -> None: class _BareRunner(BaseRunner): - engine = EngineId("bare") + engine = "bare" runner = _BareRunner() with pytest.raises(NotImplementedError): @@ -204,7 +203,7 @@ def test_resume_token_format_and_extract() -> None: assert runner.extract_resume(None) is None with pytest.raises(RuntimeError): - runner.format_resume(ResumeToken(engine=EngineId("other"), value="bad")) + runner.format_resume(ResumeToken(engine="other", value="bad")) def test_session_lock_reuse() -> None: @@ -294,7 +293,7 @@ def test_jsonl_helpers() -> None: assert found == resume assert emit is False - mismatch = StartedEvent(engine=EngineId("other"), resume=resume, title="t") + mismatch = StartedEvent(engine="other", resume=resume, title="t") with pytest.raises(RuntimeError): runner.handle_started_event(mismatch, expected_session=None, found_session=None) diff --git a/tests/test_telegram_backend.py b/tests/test_telegram_backend.py index a67da12..b644a6e 100644 --- a/tests/test_telegram_backend.py +++ b/tests/test_telegram_backend.py @@ -6,7 +6,6 @@ from typing import Any import pytest from takopi.config import ProjectsConfig -from takopi.model import EngineId from takopi.router import AutoRouter, RunnerEntry from takopi.runners.mock import Return, ScriptRunner from takopi.settings import ( @@ -19,8 +18,8 @@ from takopi.transport_runtime import TransportRuntime def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None: - codex = EngineId("codex") - pi = EngineId("pi") + codex = "codex" + pi = "pi" runner = ScriptRunner([Return(answer="ok")], engine=codex) missing = ScriptRunner([Return(answer="ok")], engine=pi) router = AutoRouter( @@ -53,9 +52,9 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None: def test_build_startup_message_surfaces_unavailable_engine_reasons( tmp_path: Path, ) -> None: - codex = EngineId("codex") - pi = EngineId("pi") - claude = EngineId("claude") + codex = "codex" + pi = "pi" + claude = "claude" runner = ScriptRunner([Return(answer="ok")], engine=codex) bad_cfg = ScriptRunner([Return(answer="ok")], engine=pi) load_err = ScriptRunner([Return(answer="ok")], engine=claude) @@ -100,7 +99,7 @@ def test_telegram_backend_build_and_run_wires_config( encoding="utf-8", ) - codex = EngineId("codex") + codex = "codex" runner = ScriptRunner([Return(answer="ok")], engine=codex) router = AutoRouter( entries=[RunnerEntry(engine=codex, runner=runner)], diff --git a/tests/test_telegram_bridge.py b/tests/test_telegram_bridge.py index 5731763..0968141 100644 --- a/tests/test_telegram_bridge.py +++ b/tests/test_telegram_bridge.py @@ -6,8 +6,9 @@ import anyio import pytest from takopi import commands, plugins +from takopi.telegram.commands.executor import _CaptureTransport, _run_engine +from takopi.telegram.commands.file_transfer import _handle_file_get, _handle_file_put import takopi.telegram.loop as telegram_loop -import takopi.telegram.commands as telegram_commands import takopi.telegram.topics as telegram_topics from takopi.directives import parse_directives from takopi.telegram.api_models import ( @@ -39,7 +40,7 @@ from takopi.context import RunContext from takopi.config import ProjectConfig, ProjectsConfig from takopi.runner_bridge import ExecBridgeConfig, RunningTask from takopi.markdown import MarkdownPresenter -from takopi.model import EngineId, ResumeToken +from takopi.model import ResumeToken from takopi.progress import ProgressTracker from takopi.router import AutoRouter, RunnerEntry from takopi.transport_runtime import TransportRuntime @@ -52,7 +53,7 @@ from takopi.telegram.types import ( from takopi.transport import MessageRef, RenderedMessage, SendOptions from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints -CODEX_ENGINE = EngineId("codex") +CODEX_ENGINE = "codex" def _empty_projects() -> ProjectsConfig: @@ -905,9 +906,7 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None: ), ) - await telegram_commands._handle_file_put( - cfg, msg, "/proj uploads/hello.txt", None, None - ) + await _handle_file_put(cfg, msg, "/proj uploads/hello.txt", None, None) target = tmp_path / "uploads" / "hello.txt" assert target.read_bytes() == payload @@ -966,7 +965,7 @@ async def test_handle_file_get_sends_document_for_allowed_user( chat_type="supergroup", ) - await telegram_commands._handle_file_get(cfg, msg, "/proj hello.txt", None, None) + await _handle_file_get(cfg, msg, "/proj hello.txt", None, None) assert bot.document_calls assert bot.document_calls[0]["filename"] == "hello.txt" @@ -1263,7 +1262,7 @@ async def test_send_with_resume_reports_when_missing() -> None: @pytest.mark.anyio async def test_run_engine_hides_resume_line_in_topics() -> None: - transport = telegram_commands._CaptureTransport() + transport = _CaptureTransport() runner = ScriptRunner( [Return(answer="ok")], engine=CODEX_ENGINE, @@ -1279,7 +1278,7 @@ async def test_run_engine_hides_resume_line_in_topics() -> None: projects=_empty_projects(), ) - await telegram_commands._run_engine( + await _run_engine( exec_cfg=exec_cfg, runtime=runtime, running_tasks={}, @@ -1456,14 +1455,14 @@ async def test_run_main_loop_auto_resumes_topic_default_engine( 123, 77, ResumeToken(engine=CODEX_ENGINE, value="resume-codex") ) await store.set_session_resume( - 123, 77, ResumeToken(engine=EngineId("claude"), value="resume-claude") + 123, 77, ResumeToken(engine="claude", value="resume-claude") ) await store.set_default_engine(123, 77, "claude") transport = _FakeTransport() bot = _FakeBot() codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) - claude_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("claude")) + claude_runner = ScriptRunner([Return(answer="ok")], engine="claude") router = AutoRouter( entries=[ RunnerEntry(engine=codex_runner.engine, runner=codex_runner), @@ -1521,7 +1520,7 @@ async def test_run_main_loop_auto_resumes_topic_default_engine( assert codex_runner.calls == [] assert len(claude_runner.calls) == 1 assert claude_runner.calls[0][1] == ResumeToken( - engine=EngineId("claude"), value="resume-claude" + engine="claude", value="resume-claude" ) @@ -2443,7 +2442,7 @@ async def test_run_main_loop_command_uses_project_default_engine( transport = _FakeTransport() bot = _FakeBot() codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) - pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi")) + pi_runner = ScriptRunner([Return(answer="ok")], engine="pi") router = AutoRouter( entries=[ RunnerEntry(engine=codex_runner.engine, runner=codex_runner), @@ -2525,7 +2524,7 @@ async def test_run_main_loop_command_defaults_to_chat_project( transport = _FakeTransport() bot = _FakeBot() codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) - pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi")) + pi_runner = ScriptRunner([Return(answer="ok")], engine="pi") router = AutoRouter( entries=[ RunnerEntry(engine=codex_runner.engine, runner=codex_runner), diff --git a/tests/test_telegram_client.py b/tests/test_telegram_client.py index f34d962..3c0d4fd 100644 --- a/tests/test_telegram_client.py +++ b/tests/test_telegram_client.py @@ -3,6 +3,7 @@ import pytest from takopi.logging import setup_logging from takopi.telegram.client import TelegramClient, TelegramRetryAfter +from takopi.telegram.client_api import HttpBotClient @pytest.mark.anyio @@ -25,9 +26,9 @@ async def test_telegram_429_no_retry() -> None: client = httpx.AsyncClient(transport=transport) try: - tg = TelegramClient("123:abcDEF_ghij", http_client=client) + api = HttpBotClient("123:abcDEF_ghij", http_client=client) with pytest.raises(TelegramRetryAfter) as exc: - await tg._post("sendMessage", {"chat_id": 1, "text": "hi"}) + await api._post("sendMessage", {"chat_id": 1, "text": "hi"}) finally: await client.aclose() @@ -49,8 +50,8 @@ async def test_no_token_in_logs_on_http_error( client = httpx.AsyncClient(transport=transport) try: - tg = TelegramClient(token, http_client=client) - await tg._post("getUpdates", {"timeout": 1}) + api = HttpBotClient(token, http_client=client) + await api._post("getUpdates", {"timeout": 1}) finally: await client.aclose() @@ -79,9 +80,9 @@ async def test_telegram_429_no_retry_post_form() -> None: client = httpx.AsyncClient(transport=transport) try: - tg = TelegramClient("123:abcDEF_ghij", http_client=client) + api = HttpBotClient("123:abcDEF_ghij", http_client=client) with pytest.raises(TelegramRetryAfter) as exc: - await tg._post_form( + await api._post_form( "sendDocument", {"chat_id": 1}, files={"document": ("note.txt", b"hi")}, @@ -102,9 +103,9 @@ async def test_telegram_429_defaults_retry_after_on_bad_body() -> None: client = httpx.AsyncClient(transport=transport) try: - tg = TelegramClient("123:abcDEF_ghij", http_client=client) + api = HttpBotClient("123:abcDEF_ghij", http_client=client) with pytest.raises(TelegramRetryAfter) as exc: - await tg._post("sendMessage", {"chat_id": 1, "text": "hi"}) + await api._post("sendMessage", {"chat_id": 1, "text": "hi"}) finally: await client.aclose() @@ -124,8 +125,8 @@ async def test_telegram_ok_false_returns_none() -> None: client = httpx.AsyncClient(transport=transport) try: - tg = TelegramClient("123:abcDEF_ghij", http_client=client) - result = await tg._post("getUpdates", {"timeout": 1}) + api = HttpBotClient("123:abcDEF_ghij", http_client=client) + result = await api._post("getUpdates", {"timeout": 1}) finally: await client.aclose() @@ -141,8 +142,8 @@ async def test_telegram_invalid_payload_returns_none() -> None: client = httpx.AsyncClient(transport=transport) try: - tg = TelegramClient("123:abcDEF_ghij", http_client=client) - result = await tg._post("getUpdates", {"timeout": 1}) + api = HttpBotClient("123:abcDEF_ghij", http_client=client) + result = await api._post("getUpdates", {"timeout": 1}) finally: await client.aclose() diff --git a/tests/test_telegram_engine_defaults.py b/tests/test_telegram_engine_defaults.py index f4d48ce..d43a7f2 100644 --- a/tests/test_telegram_engine_defaults.py +++ b/tests/test_telegram_engine_defaults.py @@ -4,7 +4,6 @@ import pytest from takopi.config import ProjectConfig, ProjectsConfig from takopi.context import RunContext -from takopi.model import EngineId from takopi.router import AutoRouter, RunnerEntry from takopi.runners.mock import Return, ScriptRunner from takopi.telegram.chat_prefs import ChatPrefsStore @@ -15,8 +14,8 @@ from takopi.transport_runtime import TransportRuntime @pytest.mark.anyio async def test_resolve_engine_for_message_sources(tmp_path) -> None: - codex = ScriptRunner([Return(answer="ok")], engine=EngineId("codex")) - pi = ScriptRunner([Return(answer="ok")], engine=EngineId("pi")) + codex = ScriptRunner([Return(answer="ok")], engine="codex") + pi = ScriptRunner([Return(answer="ok")], engine="pi") router = AutoRouter( entries=[ RunnerEntry(engine=codex.engine, runner=codex), @@ -42,7 +41,7 @@ async def test_resolve_engine_for_message_sources(tmp_path) -> None: resolved = await resolve_engine_for_message( runtime=runtime, context=RunContext(project="proj"), - explicit_engine=EngineId("codex"), + explicit_engine="codex", chat_id=1, topic_key=(1, 10), topic_store=topic_store, diff --git a/tests/test_transport_registry.py b/tests/test_transport_registry.py index bad214a..7569576 100644 --- a/tests/test_transport_registry.py +++ b/tests/test_transport_registry.py @@ -15,8 +15,8 @@ class DummyTransport: def interactive_setup(self, *, force: bool) -> bool: raise NotImplementedError - def lock_token(self, *, transport_config: object, config_path): - _ = transport_config, config_path + def lock_token(self, *, transport_config: object, _config_path): + _ = transport_config, _config_path raise NotImplementedError def build_and_run(