From f856338b94e9e3d379103889435cf879b139e192 Mon Sep 17 00:00:00 2001 From: banteg <4562643+banteg@users.noreply.github.com> Date: Fri, 9 Jan 2026 03:23:57 +0400 Subject: [PATCH] feat: plugins and public api (#71) --- docs/adding-a-runner.md | 25 +- docs/architecture.md | 73 ++- docs/developing.md | 34 +- docs/plugins.md | 307 ++++++++++++ docs/public-api.md | 252 ++++++++++ pyproject.toml | 9 + readme.md | 13 +- src/takopi/api.py | 84 ++++ src/takopi/cli.py | 294 ++++++++++-- src/takopi/commands.py | 151 ++++++ src/takopi/directives.py | 143 ++++++ src/takopi/engines.py | 112 +++-- src/takopi/ids.py | 15 + src/takopi/plugins.py | 283 +++++++++++ src/takopi/settings.py | 14 + src/takopi/telegram/__init__.py | 7 +- src/takopi/telegram/backend.py | 38 +- src/takopi/telegram/bridge.py | 778 +++++++++++++++++-------------- src/takopi/telegram/client.py | 8 +- src/takopi/telegram/types.py | 16 + src/takopi/transport.py | 12 - src/takopi/transport_runtime.py | 192 ++++++++ src/takopi/transports.py | 80 ++-- tests/conftest.py | 7 + tests/plugin_fixtures.py | 47 ++ tests/test_command_registry.py | 47 ++ tests/test_engine_discovery.py | 47 +- tests/test_plugins.py | 184 ++++++++ tests/test_projects_config.py | 6 +- tests/test_telegram_bridge.py | 378 ++++++++++++--- tests/test_transport_registry.py | 56 ++- tests/test_transport_runtime.py | 45 ++ 32 files changed, 3135 insertions(+), 622 deletions(-) create mode 100644 docs/plugins.md create mode 100644 docs/public-api.md create mode 100644 src/takopi/api.py create mode 100644 src/takopi/commands.py create mode 100644 src/takopi/directives.py create mode 100644 src/takopi/ids.py create mode 100644 src/takopi/plugins.py create mode 100644 src/takopi/telegram/types.py create mode 100644 src/takopi/transport_runtime.py create mode 100644 tests/plugin_fixtures.py create mode 100644 tests/test_command_registry.py create mode 100644 tests/test_plugins.py create mode 100644 tests/test_transport_runtime.py diff --git a/docs/adding-a-runner.md b/docs/adding-a-runner.md index e0aa35a..1eee034 100644 --- a/docs/adding-a-runner.md +++ b/docs/adding-a-runner.md @@ -5,10 +5,15 @@ This guide explains how to add a **new engine runner** to Takopi. A *runner* is the adapter between an engine-specific CLI (Codex, Claude Code, …) and Takopi’s **normalized event model** (`StartedEvent`, `ActionEvent`, `CompletedEvent`). +If you are building an external plugin package, read `docs/plugins.md` first. + Takopi is designed so that adding a runner usually means **adding one new module** under `src/takopi/runners/` plus a small **msgspec schema** module under `src/takopi/schemas/`— no changes to the bridge, renderer, or CLI. +When writing code intended for plugins, prefer importing from `takopi.api` +instead of internal modules. + The walkthrough below uses an **imaginary engine** named **Acme** (`acme`) and intentionally mirrors the patterns used in `runners/claude.py`. @@ -74,6 +79,12 @@ Choose a stable engine id string. This string becomes: - The CLI subcommand (`takopi acme`) - The `ResumeToken.engine` +Engine ids must match the plugin ID regex: + +``` +^[a-z0-9_]{1,32}$ +``` + For Acme we’ll use: - Engine id: `"acme"` @@ -114,8 +125,18 @@ src/takopi/runners/ acme.py # ← new ``` -Takopi discovers engines by importing modules in `takopi.runners` and looking for a -module-level `BACKEND: EngineBackend` (see `takopi.engines`). +Takopi discovers engines via **entrypoints**. Every engine backend must be exposed +as an entrypoint under `takopi.engine_backends`, and the entrypoint name must match +the backend id. + +For in-repo engines, add an entrypoint in `pyproject.toml`: + +```toml +[project.entry-points."takopi.engine_backends"] +acme = "takopi.runners.acme:BACKEND" +``` + +For external plugins, use your package’s `pyproject.toml` with the same group. --- diff --git a/docs/architecture.md b/docs/architecture.md index bb88862..8d16d34 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -9,10 +9,19 @@ flowchart TB cli_desc["Entry point, config loading, lock file"] end + subgraph Plugins["Plugin Layer"] + entrypoints[plugins.py
entrypoint discovery] + engines[engines.py] + transports[transports.py] + commands[commands.py] + api[api.py
public plugin API] + end + subgraph Orchestration["Orchestration Layer"] router[AutoRouter
router.py] scheduler[ThreadScheduler
scheduler.py] projects[ProjectsConfig
config.py] + runtime[TransportRuntime
transport_runtime.py] end subgraph Bridge["Bridge Layer"] @@ -42,8 +51,18 @@ flowchart TB cli --> router cli --> scheduler cli --> projects + cli --> engines + cli --> transports + cli --> commands + engines --> entrypoints + transports --> entrypoints + commands --> entrypoints + router --> runtime + projects --> runtime router --> tg_bridge scheduler --> tg_bridge + runtime --> tg_bridge + tg_bridge --> commands tg_bridge --> runner_bridge runner_bridge --> runner_proto runner_proto --> runners @@ -59,6 +78,21 @@ flowchart TB --- +## Plugin Architecture + +Takopi discovers plugins via Python entrypoints and keeps loading lazy: + +- **Engine backends** (`takopi.engine_backends`) +- **Transport backends** (`takopi.transport_backends`) +- **Command backends** (`takopi.command_backends`) + +Entrypoint names become plugin IDs, are validated up front (reserved names, regex), +and are only loaded when needed. The public surface for plugin authors lives in +`takopi.api`, while transports and commands interact with core routing via +`TransportRuntime`. + +--- + ## Domain Model ```mermaid @@ -120,19 +154,27 @@ sequenceDiagram participant RunnerBridge as runner_bridge.py participant Runner participant AgentCLI as Agent CLI + participant Command as Command Plugin User->>Telegram: Send message Telegram->>Bridge: poll_incoming() - Bridge->>Bridge: Parse directives
(/engine, /project, @branch) - Bridge->>Bridge: Extract resume token
from reply - Bridge->>Bridge: Resolve worktree
(if @branch) + Bridge->>Bridge: Parse slash command + alt Command plugin + Bridge->>Command: handle(ctx) + Command->>RunnerBridge: run_one/run_many (optional) + RunnerBridge->>Telegram: Send progress/final + else Default routing + Bridge->>Bridge: Parse directives
(/engine, /project, @branch) + Bridge->>Bridge: Extract resume token
from reply + Bridge->>Bridge: Resolve worktree
(if @branch) - Bridge->>Scheduler: enqueue(ThreadJob) - Scheduler->>RunnerBridge: handle_message() + Bridge->>Scheduler: enqueue(ThreadJob) + Scheduler->>RunnerBridge: handle_message() - RunnerBridge->>Telegram: Send progress message - RunnerBridge->>Runner: run(prompt, resume) + RunnerBridge->>Telegram: Send progress message + RunnerBridge->>Runner: run(prompt, resume) + end Runner->>AgentCLI: Spawn subprocess @@ -217,8 +259,14 @@ sequenceDiagram flowchart TD cli[cli.py] --> config[config.py] cli --> engines[engines.py] + cli --> transports[transports.py] + cli --> commands[commands.py] cli --> lockfile[lockfile.py] + engines --> plugins[plugins.py] + transports --> plugins + commands --> plugins + engines --> backends[backends.py] backends --> runners[runners/] @@ -244,7 +292,10 @@ flowchart TD pi --> pi_s cli --> router[router.py] - router --> tg_bridge[telegram/bridge.py] + tg_bridge --> runtime[transport_runtime.py] + runtime --> router + runtime --> config + tg_bridge --> commands runner --> runner_bridge[runner_bridge.py] runner_bridge --> tg_bridge @@ -274,12 +325,13 @@ flowchart LR subgraph toml_contents["takopi.toml"] direction TB - global["transport
default_engine"] + global["transport
default_engine
default_project"] telegram_cfg["[transports.telegram]
bot_token = ...
chat_id = ..."] + plugins_cfg["[plugins]
enabled = [\"...\"]"] + plugins_extra["[plugins.mycommand]
setting = ..."] claude_cfg["[claude]
model = ..."] codex_cfg["[codex]
model = ..."] projects_cfg["[projects.alias]
path = ...
worktrees_dir = ...
default_engine = ..."] - default_proj["[projects]
default = ..."] end toml --> toml_contents @@ -335,6 +387,7 @@ flowchart TD | Layer | Components | Responsibility | |-------|------------|----------------| | **CLI** | `cli.py` | Entry point, config, lock | +| **Plugins** | `plugins.py`, `engines.py`, `transports.py`, `commands.py`, `api.py` | Entrypoint discovery, plugin loading, public API boundary | | **Orchestration** | `router.py`, `scheduler.py`, `config.py` | Engine selection, job queuing, project config | | **Bridge** | `telegram/bridge.py`, `runner_bridge.py` | Message handling, execution coordination | | **Runner** | `runner.py`, `runners/*.py`, `schemas/*.py` | Agent CLI subprocess, JSONL parsing, event translation | diff --git a/docs/developing.md b/docs/developing.md index a3f1d72..6211e33 100644 --- a/docs/developing.md +++ b/docs/developing.md @@ -77,9 +77,14 @@ Defines `Transport`, `MessageRef`, `RenderedMessage`, and `SendOptions`. Defines a renderer that converts `ProgressState` into `RenderedMessage` outputs. -### `transports.py` - Transport registry +### `transport_runtime.py` - Transport runtime facade -Defines the transport backend protocol, registry helpers, and built-in transport registration. +Provides the `TransportRuntime` helper used by transport backends to resolve +messages, select runners, and format context without depending on internal types. + +### `transports.py` - Transport backend loading + +Defines the transport backend protocol and entrypoint-backed loading helpers. ### `config_migrations.py` - Config migrations @@ -165,9 +170,32 @@ See `docs/transports/telegram.md` for outbox behavior, rate limiting, and retry Defines `EngineBackend`, `SetupIssue`, and the `EngineConfig` type used by runner modules. +### `plugins.py` - Entrypoint discovery + +Centralizes plugin discovery and lazy loading: + +- lists IDs without importing plugin modules +- loads a specific entrypoint on demand +- captures load errors for diagnostics +- filters by enabled list (distribution names) + +### `commands.py` - Command backend loading + +Defines the command backend protocol, command context/executor helpers, and +entrypoint-backed loading for slash-command plugins. + +### `ids.py` - Plugin ID validation + +Defines the shared ID regex used for plugin IDs and Telegram command names. + +### `api.py` - Public plugin API + +Re-exports the supported plugin surface from `takopi.api` (stable API boundary). + ### `engines.py` - Engine backend discovery -Auto-discovers runner modules in `takopi.runners` that export `BACKEND`. +Loads engine backends via entrypoints (`takopi.engine_backends`), with lazy loading +and enabled list support. ### `runners/` - Runner implementations diff --git a/docs/plugins.md b/docs/plugins.md new file mode 100644 index 0000000..ad01dfc --- /dev/null +++ b/docs/plugins.md @@ -0,0 +1,307 @@ +# Plugins + +Takopi supports **entrypoint-based plugins** for: + +- **Engine backends** (new runner implementations) +- **Transport backends** (new chat/command transports) +- **Command backends** (custom `/command` handlers) + +Plugins are **discovered lazily**: Takopi lists IDs without importing plugin code, +and loads a plugin only when it is needed (or when you explicitly request it). + +This keeps `takopi --help` fast and prevents broken plugins from bricking the CLI. + +See `public-api.md` for the stable API surface you should depend on. + +--- + +## Entrypoint groups + +Takopi uses two Python entrypoint groups: + +```toml +[project.entry-points."takopi.engine_backends"] +myengine = "myengine.backend:BACKEND" + +[project.entry-points."takopi.transport_backends"] +mytransport = "mytransport.backend:BACKEND" + +[project.entry-points."takopi.command_backends"] +mycommand = "mycommand.backend:BACKEND" +``` + +**Rules:** + +- The entrypoint **name** is the plugin ID. +- The entrypoint value must resolve to a **backend object**: + - Engine backend -> `EngineBackend` + - Transport backend -> `TransportBackend` +- The backend object **must** have `id == entrypoint name`. + +Takopi validates this at load time and will report errors via `takopi plugins --load`. + +--- + +## ID rules + +Plugin IDs are used in the CLI and (for engines/projects) in Telegram commands. +They must match: + +``` +^[a-z0-9_]{1,32}$ +``` + +If an ID does not match, it is skipped and reported as an error. + +**Reserved IDs (engines):** + +- `cancel` (core chat command) +- `init`, `plugins` (CLI commands) + +Engines using these IDs are skipped and reported as errors. + +**Reserved IDs (commands):** + +- `cancel`, `init`, `plugins` +- Any engine id or project alias (checked at runtime) + +Command backends using reserved IDs are skipped and reported as errors. + +--- + +## Enabling plugins + +Takopi supports a simple enabled list to control which plugins are visible. + +```toml +[plugins] +enabled = ["takopi-transport-slack", "takopi-engine-acme"] +auto_install = false +``` + +- `enabled = []` (default) -> load all installed plugins. +- If `enabled` is non-empty, **only distributions with matching names** are visible. +- Distribution names are taken from package metadata (case-insensitive). +- If a plugin has no resolvable distribution name and an enabled list is set, it is hidden. +- `auto_install` is **reserved** and not implemented yet. + +This enabled list affects: + +- Engine subcommands registered in the CLI +- `takopi plugins` output +- Runtime resolution of engines/transports/commands + +--- + +## Discovering plugins + +Use the CLI to inspect plugins: + +```sh +takopi plugins +takopi plugins --load +``` + +Behavior: + +- `takopi plugins` lists discovered entrypoints **without loading them**. +- `--load` loads each plugin to validate type and surface import errors. +- Errors are shown at the end, grouped by engine/transport and distribution. +- If `[plugins] enabled` is set, entries are still listed but marked `enabled`/`disabled`. + +--- + +## Engine backend plugins + +Engine plugins implement a runner for a new engine CLI and expose +an `EngineBackend` object. + +Minimal example: + +```py +# myengine/backend.py +from __future__ import annotations + +from pathlib import Path + +from takopi.api import EngineBackend, EngineConfig, Runner + +def build_runner(config: EngineConfig, config_path: Path) -> Runner: + _ = config_path + # Parse config if needed; raise ConfigError for invalid config. + return MyEngineRunner(config) + +BACKEND = EngineBackend( + id="myengine", + build_runner=build_runner, + cli_cmd="myengine", + install_cmd="pip install myengine", +) +``` + +`EngineConfig` is the raw config table (dict) from `takopi.toml`: + +```toml +[myengine] +model = "..." +``` + +Read it with `settings.engine_config("myengine", config_path=...)` in Takopi, +or just consume the dict directly in your runner builder. + +See `public-api.md` for the runner contract and helper classes like +`JsonlSubprocessRunner` and `EventFactory`. + +--- + +## Transport backend plugins + +Transport plugins connect Takopi to new messaging systems (Slack, Discord, etc). + +You must provide a `TransportBackend` object with: + +- `id` and `description` +- `check_setup()` -> returns `SetupResult` (issues + config path) +- `interactive_setup()` -> optional interactive setup flow +- `lock_token()` -> token fingerprinting for config locks +- `build_and_run()` -> build transport and start the main loop + +Minimal skeleton: + +```py +# mytransport/backend.py +from __future__ import annotations + +from pathlib import Path + +from takopi.api import ( + EngineBackend, + SetupResult, + TransportBackend, + TransportRuntime, +) + +class MyTransportBackend: + id = "mytransport" + description = "MyTransport bot" + + def check_setup( + self, engine_backend: EngineBackend, *, transport_override: str | None = None + ) -> SetupResult: + _ = engine_backend, transport_override + return SetupResult(issues=[], config_path=Path("takopi.toml")) + + def interactive_setup(self, *, force: bool) -> bool: + _ = force + return True + + def lock_token( + self, *, transport_config: dict[str, object], config_path: Path + ) -> str | None: + _ = transport_config, config_path + return None + + def build_and_run( + self, + *, + transport_config: dict[str, object], + config_path: Path, + runtime: TransportRuntime, + final_notify: bool, + default_engine_override: str | None, + ) -> None: + _ = ( + transport_config, + config_path, + runtime, + final_notify, + default_engine_override, + ) + raise NotImplementedError + +BACKEND = MyTransportBackend() +``` + +For most transports, you will want to call `handle_message()` from `takopi.api` +inside your message loop. That function implements progress updates, resume handling, +and cancellation semantics. + +--- + +## Command backend plugins + +Command plugins add custom `/command` handlers. A command only runs when the +message starts with `/command` and does **not** collide with engine ids, +project aliases, or reserved command names. + +Minimal example: + +```py +# mycommand/backend.py +from __future__ import annotations + +from takopi.api import CommandContext, CommandResult, RunRequest + +class MultiCommand: + id = "multi" + description = "run the prompt on every engine" + + async def handle(self, ctx: CommandContext) -> CommandResult | None: + prompt = ctx.args_text.strip() + if not prompt: + return CommandResult(text="usage: /multi ") + requests = [ + RunRequest(prompt=prompt, engine=engine) + for engine in ctx.runtime.available_engine_ids() + ] + results = await ctx.executor.run_many( + requests, + mode="capture", + parallel=True, + ) + blocks = [] + for result in results: + text = result.message.text if result.message else "no output" + blocks.append(f"## {result.engine}\n{text}") + return CommandResult(text="\n\n".join(blocks)) + +BACKEND = MultiCommand() +``` + +### Command plugin configuration + +Configure command plugins under `[plugins.]`: + +```toml +[plugins.multi] +engines = ["codex", "claude"] +``` + +The parsed dict is available as `ctx.plugin_config` inside `handle()`. + +--- + +## Versioning & compatibility + +Takopi exposes a **stable plugin API** via `takopi.api`. + +- `TAKOPI_PLUGIN_API_VERSION = 1` is the current API version. +- Depend on a compatible Takopi version range, for example: + +```toml +dependencies = ["takopi>=0.11,<0.12"] +``` + +When the plugin API changes, Takopi will bump the API version and document +any compatibility guidance. + +--- + +## Troubleshooting + +Common issues: + +- **Plugin missing from CLI**: check the enabled list in `[plugins] enabled`. +- **Plugin not listed**: verify entrypoint group and ID regex. +- **Load failures**: run `takopi plugins --load` and inspect errors. +- **ID mismatch**: ensure `BACKEND.id == entrypoint name`. diff --git a/docs/public-api.md b/docs/public-api.md new file mode 100644 index 0000000..657ff68 --- /dev/null +++ b/docs/public-api.md @@ -0,0 +1,252 @@ +# Public Plugin API + +Takopi's **public plugin API** is exported from: + +``` +takopi.api +``` + +Anything not imported from `takopi.api` should be considered **internal** and +subject to change. The API version is tracked by `TAKOPI_PLUGIN_API_VERSION`. + +--- + +## Versioning + +- Current API version: `TAKOPI_PLUGIN_API_VERSION = 1` +- Plugins should pin to a compatible Takopi range, e.g.: + +```toml +dependencies = ["takopi>=0.11,<0.12"] +``` + +--- + +## Exported symbols + +### Engine backends and runners + +| Symbol | Purpose | +|--------|---------| +| `EngineBackend` | Declares an engine backend (id + runner builder) | +| `EngineConfig` | Dict-based engine config table | +| `Runner` | Runner protocol | +| `BaseRunner` | Helper base class with resume locking | +| `JsonlSubprocessRunner` | Helper for JSONL-streaming CLIs | +| `EventFactory` | Helper for building takopi events | + +### Transport backends + +| Symbol | Purpose | +|--------|---------| +| `TransportBackend` | Transport backend protocol | +| `SetupIssue` | Setup issue for onboarding / validation | +| `SetupResult` | Setup issues + config path | +| `Transport` | Transport protocol (send/edit/delete) | +| `Presenter` | Renders progress to `RenderedMessage` | +| `RenderedMessage` | Rendered text + transport metadata | +| `SendOptions` | Reply/notify/replace flags | +| `MessageRef` | Transport-specific message reference | +| `TransportRuntime` | Transport runtime facade (routers/projects hidden) | +| `ResolvedMessage` | Parsed prompt + resume/context resolution | +| `ResolvedRunner` | Runner selection result | + +### Command backends + +| Symbol | Purpose | +|--------|---------| +| `CommandBackend` | Slash command plugin protocol | +| `CommandContext` | Context passed to a command handler | +| `CommandExecutor` | Helper to send messages or run engines | +| `CommandResult` | Simple response payload for a command | +| `RunRequest` | Engine run request used by commands | +| `RunResult` | Engine run result (captured output) | +| `RunMode` | `"emit"` (send) or `"capture"` (collect) | + +### Core types and helpers + +| Symbol | Purpose | +|--------|---------| +| `EngineId` | Engine id type alias | +| `ResumeToken` | Resume token (engine + value) | +| `StartedEvent` / `ActionEvent` / `CompletedEvent` | Core event types | +| `Action` | Action metadata for `ActionEvent` | +| `RunContext` | Project/branch context | +| `ConfigError` | Configuration error type | +| `DirectiveError` | Error raised when parsing directives | +| `RunnerUnavailableError` | Router error when a runner is unavailable | + +### Bridge helpers (for transport plugins) + +| Symbol | Purpose | +|--------|---------| +| `ExecBridgeConfig` | Transport + presenter config | +| `IncomingMessage` | Normalized incoming message | +| `RunningTask` / `RunningTasks` | Per-message run coordination | +| `handle_message()` | Core message handler used by transports | + +--- + +## Runner contract (engine plugins) + +Runners emit events in a strict sequence (see `tests/test_runner_contract.py`): + +- Exactly **one** `StartedEvent` +- Exactly **one** `CompletedEvent` +- `CompletedEvent` is **last** +- `CompletedEvent.resume == StartedEvent.resume` + +Action events are optional. The minimal valid run is: + +``` +StartedEvent -> CompletedEvent +``` + +### Resume tokens + +Runners own the resume format: + +- `format_resume(token)` returns a command line users can paste +- `extract_resume(text)` parses resume tokens from user text +- `is_resume_line(line)` lets Takopi strip resume lines before running + +--- + +## EngineBackend + +```py +EngineBackend( + id: str, + build_runner: Callable[[EngineConfig, Path], Runner], + cli_cmd: str | None = None, + install_cmd: str | None = None, +) +``` + +- `id` must match the entrypoint name and the ID regex. +- `build_runner` should raise `ConfigError` for invalid config. +- `cli_cmd` is used to check whether the engine CLI is on `PATH`. +- `install_cmd` is surfaced in onboarding output. + +--- + +## TransportBackend + +```py +class TransportBackend(Protocol): + id: str + description: str + + def check_setup(...) -> SetupResult: ... + def interactive_setup(self, *, force: bool) -> bool: ... + def lock_token( + self, *, transport_config: dict[str, object], config_path: Path + ) -> str | None: ... + def build_and_run( + self, + *, + transport_config: dict[str, object], + config_path: Path, + runtime: TransportRuntime, + final_notify: bool, + default_engine_override: str | None, + ) -> None: ... +``` + +Transport backends are responsible for: + +- Validating config and onboarding users (`check_setup`, `interactive_setup`) +- Providing a lock token so Takopi can prevent parallel runs +- Starting the transport loop in `build_and_run` + +--- + +## CommandBackend + +```py +class CommandBackend(Protocol): + id: str + description: str + + async def handle(self, ctx: CommandContext) -> CommandResult | None: ... +``` + +Command handlers receive a `CommandContext` with: + +- the raw command text and parsed args +- the original message + reply metadata +- `config_path` for the active `takopi.toml` (when known) +- `plugin_config` from `[plugins.]` (dict, defaults to `{}`) +- `runtime` (engine/project resolution) +- `executor` (send messages or run engines) + +Use `ctx.executor.run_one(...)` or `ctx.executor.run_many(...)` to reuse Takopi's +engine pipeline. Use `mode="capture"` to collect results and build a custom reply. + +--- + +## TransportRuntime helpers + +`TransportRuntime` keeps transports away from internal router/project types. Key helpers: + +- `resolve_message(text, reply_text)` → `ResolvedMessage` (prompt, resume token, context) +- `resolve_engine(engine_override, context)` → `EngineId` +- `resolve_runner(resume_token, engine_override)` → `ResolvedRunner` (runner + availability info) +- `resolve_run_cwd(context)` → `Path | None` (raises `ConfigError` for project/worktree issues) +- `format_context_line(context)` → `str | None` +- `available_engine_ids()` / `missing_engine_ids()` / `engine_ids` / `default_engine` +- `project_aliases()` +- `config_path` (active config path when available) +- `plugin_config(plugin_id)` → `dict` from `[plugins.]` + +--- + +## Bridge usage (transport plugins) + +Most transports can delegate message handling to `handle_message()`. Use +`TransportRuntime` to resolve messages and select a runner: + +```py +from takopi.api import ( + ExecBridgeConfig, + IncomingMessage, + RunningTask, + RunningTasks, + TransportRuntime, + handle_message, +) + +async def on_message(...): + resolved = runtime.resolve_message(text=text, reply_text=reply_text) + entry = runtime.resolve_runner( + resume_token=resolved.resume_token, + engine_override=resolved.engine_override, + ) + context_line = runtime.format_context_line(resolved.context) + incoming = IncomingMessage( + channel_id=..., + message_id=..., + text=..., + reply_to=..., + ) + await handle_message( + exec_cfg, + runner=entry.runner, + incoming=incoming, + resume_token=resolved.resume_token, + context=resolved.context, + context_line=context_line, + strip_resume_line=runtime.is_resume_line, + running_tasks=running_tasks, + on_thread_known=on_thread_known, + ) +``` + +`handle_message()` implements: + +- Progress updates and throttling +- Resume handling +- Cancellation propagation +- Final rendering + +This keeps transport backends thin and consistent with core behavior. diff --git a/pyproject.toml b/pyproject.toml index 06e0936..1a34d46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,15 @@ Issues = "https://github.com/banteg/takopi/issues" [project.scripts] takopi = "takopi.cli:main" +[project.entry-points."takopi.engine_backends"] +codex = "takopi.runners.codex:BACKEND" +claude = "takopi.runners.claude:BACKEND" +opencode = "takopi.runners.opencode:BACKEND" +pi = "takopi.runners.pi:BACKEND" + +[project.entry-points."takopi.transport_backends"] +telegram = "takopi.telegram.backend:BACKEND" + [build-system] requires = ["uv_build>=0.9.18,<0.10.0"] build-backend = "uv_build" diff --git a/readme.md b/readme.md index 2ff8159..038ed0d 100644 --- a/readme.md +++ b/readme.md @@ -123,10 +123,10 @@ takopi opencode takopi pi ``` -list available transports (and override in a run): +list available plugins (engines/transports), and override in a run: ```sh -takopi transports +takopi plugins takopi --transport telegram ``` @@ -145,6 +145,15 @@ default: progress is silent, final answer is sent as a new message so you receiv if you prefer no notifications, `--no-final-notify` edits the progress message into the final answer. +## plugins + +Takopi supports entrypoint-based plugins for engines and transports. + +See: + +- `docs/plugins.md` +- `docs/public-api.md` + ## notes * the bot only responds to the configured `chat_id` (private or group) diff --git a/src/takopi/api.py b/src/takopi/api.py new file mode 100644 index 0000000..8f8379c --- /dev/null +++ b/src/takopi/api.py @@ -0,0 +1,84 @@ +"""Stable public API for Takopi plugins.""" + +from __future__ import annotations + +from .backends import EngineBackend, EngineConfig, SetupIssue +from .commands import ( + CommandBackend, + CommandContext, + CommandExecutor, + CommandResult, + RunMode, + RunRequest, + RunResult, +) +from .config import ConfigError +from .context import RunContext +from .directives import DirectiveError +from .events import EventFactory +from .model import ( + Action, + ActionEvent, + CompletedEvent, + EngineId, + ResumeToken, + StartedEvent, +) +from .presenter import Presenter +from .router import RunnerUnavailableError +from .runner import BaseRunner, JsonlSubprocessRunner, Runner +from .runner_bridge import ( + ExecBridgeConfig, + IncomingMessage, + RunningTask, + RunningTasks, + handle_message, +) +from .transport import MessageRef, RenderedMessage, SendOptions, Transport +from .transport_runtime import ResolvedMessage, ResolvedRunner, TransportRuntime +from .transports import SetupResult, TransportBackend + +TAKOPI_PLUGIN_API_VERSION = 1 + +__all__ = [ + "Action", + "ActionEvent", + "BaseRunner", + "CompletedEvent", + "ConfigError", + "CommandBackend", + "CommandContext", + "CommandExecutor", + "CommandResult", + "EngineBackend", + "EngineConfig", + "EngineId", + "ExecBridgeConfig", + "EventFactory", + "IncomingMessage", + "JsonlSubprocessRunner", + "MessageRef", + "DirectiveError", + "Presenter", + "RenderedMessage", + "ResumeToken", + "RunMode", + "RunRequest", + "RunResult", + "ResolvedMessage", + "ResolvedRunner", + "RunContext", + "Runner", + "RunnerUnavailableError", + "RunningTask", + "RunningTasks", + "SendOptions", + "SetupIssue", + "SetupResult", + "StartedEvent", + "TAKOPI_PLUGIN_API_VERSION", + "Transport", + "TransportBackend", + "TransportRuntime", + "handle_message", +] diff --git a/src/takopi/cli.py b/src/takopi/cli.py index de29ea9..1131398 100644 --- a/src/takopi/cli.py +++ b/src/takopi/cli.py @@ -4,6 +4,7 @@ import os import shutil import sys from collections.abc import Callable +from importlib.metadata import EntryPoint from pathlib import Path import typer @@ -12,7 +13,9 @@ from . import __version__ from .backends import EngineBackend from .config import ConfigError, load_or_init_config, write_config from .config_migrations import migrate_config -from .engines import get_backend, list_backends +from .commands import get_command +from .engines import get_backend, list_backend_ids +from .ids import RESERVED_COMMAND_IDS, RESERVED_ENGINE_IDS from .lockfile import LockError, LockHandle, acquire_lock, token_fingerprint from .logging import get_logger, setup_logging from .router import AutoRouter, RunnerEntry @@ -22,12 +25,46 @@ from .settings import ( load_settings_if_exists, validate_settings_data, ) -from .transports import SetupResult, get_transport, list_transports +from .plugins import ( + COMMAND_GROUP, + ENGINE_GROUP, + TRANSPORT_GROUP, + entrypoint_distribution_name, + get_load_errors, + is_entrypoint_allowed, + list_entrypoints, + normalize_allowlist, +) +from .transports import SetupResult, get_transport +from .transport_runtime import TransportRuntime from .utils.git import resolve_default_base, resolve_main_worktree_root logger = get_logger(__name__) +def _load_settings_optional() -> tuple[TakopiSettings | None, Path | None]: + try: + loaded = load_settings_if_exists() + except ConfigError: + return None, None + if loaded is None: + return None, None + return loaded + + +def _resolve_plugins_allowlist( + settings: TakopiSettings | None, +) -> list[str] | None: + if settings is None: + return None + enabled = [ + value.strip() + for value in settings.plugins.enabled + if isinstance(value, str) and value.strip() + ] + return enabled or None + + def _print_version_and_exit() -> None: typer.echo(__version__) raise typer.Exit() @@ -72,16 +109,16 @@ def acquire_config_lock(config_path: Path, token: str | None) -> LockHandle: raise typer.Exit(code=1) from exc -def _default_engine_for_setup(override: str | None) -> str: +def _default_engine_for_setup( + override: str | None, + *, + settings: TakopiSettings | None, + config_path: Path | None, +) -> str: if override: return override - try: - loaded = load_settings_if_exists() - except ConfigError: + if settings is None or config_path is None: return "codex" - if loaded is None: - return "codex" - settings, config_path = loaded value = settings.default_engine if not isinstance(value, str) or not value.strip(): raise ConfigError( @@ -95,7 +132,7 @@ def _resolve_default_engine( override: str | None, settings: TakopiSettings, config_path: Path, - backends: list[EngineBackend], + engine_ids: list[str], ) -> str: default_engine = override or settings.default_engine or "codex" if not isinstance(default_engine, str) or not default_engine.strip(): @@ -103,9 +140,8 @@ def _resolve_default_engine( f"Invalid `default_engine` in {config_path}; expected a non-empty string." ) default_engine = default_engine.strip() - backend_ids = {backend.id for backend in backends} - if default_engine not in backend_ids: - available = ", ".join(sorted(backend_ids)) + if default_engine not in engine_ids: + available = ", ".join(sorted(engine_ids)) raise ConfigError( f"Unknown default engine {default_engine!r}. Available: {available}." ) @@ -176,6 +212,30 @@ def _build_router( return AutoRouter(entries=entries, default_engine=default_engine) +def _load_backends( + *, + engine_ids: list[str], + allowlist: list[str] | None, + default_engine: str, +) -> list[EngineBackend]: + backends: list[EngineBackend] = [] + load_issues: list[str] = [] + for engine_id in engine_ids: + try: + backend = get_backend(engine_id, allowlist=allowlist) + except ConfigError as exc: + if engine_id == default_engine: + raise + load_issues.append(f"{engine_id}: {exc}") + continue + backends.append(backend) + if not backends: + raise ConfigError("No engine backends are available.") + for issue in load_issues: + logger.warning("setup.warning", issue=issue) + return backends + + def _config_path_display(path: Path) -> str: home = Path.home() try: @@ -214,10 +274,16 @@ def _run_auto_router( setup_logging(debug=debug) lock_handle: LockHandle | None = None try: - default_engine = _default_engine_for_setup(default_engine_override) - engine_backend = get_backend(default_engine) + settings_hint, config_hint = _load_settings_optional() + allowlist = _resolve_plugins_allowlist(settings_hint) + default_engine = _default_engine_for_setup( + default_engine_override, + settings=settings_hint, + config_path=config_hint, + ) + engine_backend = get_backend(default_engine, allowlist=allowlist) transport_id = _resolve_transport_id(transport_override) - transport_backend = get_transport(transport_id) + transport_backend = get_transport(transport_id, allowlist=allowlist) except ConfigError as e: typer.echo(f"error: {e}", err=True) raise typer.Exit(code=1) @@ -227,8 +293,14 @@ def _run_auto_router( raise typer.Exit(code=1) if not transport_backend.interactive_setup(force=True): raise typer.Exit(code=1) - default_engine = _default_engine_for_setup(default_engine_override) - engine_backend = get_backend(default_engine) + settings_hint, config_hint = _load_settings_optional() + allowlist = _resolve_plugins_allowlist(settings_hint) + default_engine = _default_engine_for_setup( + default_engine_override, + settings=settings_hint, + config_path=config_hint, + ) + engine_backend = get_backend(default_engine, allowlist=allowlist) setup = transport_backend.check_setup( engine_backend, transport_override=transport_override, @@ -243,15 +315,27 @@ def _run_auto_router( default=False, ) if run_onboard and transport_backend.interactive_setup(force=True): - default_engine = _default_engine_for_setup(default_engine_override) - engine_backend = get_backend(default_engine) + settings_hint, config_hint = _load_settings_optional() + allowlist = _resolve_plugins_allowlist(settings_hint) + default_engine = _default_engine_for_setup( + default_engine_override, + settings=settings_hint, + config_path=config_hint, + ) + engine_backend = get_backend(default_engine, allowlist=allowlist) setup = transport_backend.check_setup( engine_backend, transport_override=transport_override, ) elif transport_backend.interactive_setup(force=False): - default_engine = _default_engine_for_setup(default_engine_override) - engine_backend = get_backend(default_engine) + settings_hint, config_hint = _load_settings_optional() + allowlist = _resolve_plugins_allowlist(settings_hint) + default_engine = _default_engine_for_setup( + default_engine_override, + settings=settings_hint, + config_path=config_hint, + ) + engine_backend = get_backend(default_engine, allowlist=allowlist) setup = transport_backend.check_setup( engine_backend, transport_override=transport_override, @@ -267,17 +351,23 @@ def _run_auto_router( settings, config_path = load_settings() if transport_override and transport_override != settings.transport: settings = settings.model_copy(update={"transport": transport_override}) - backends = list_backends() + allowlist = _resolve_plugins_allowlist(settings) + engine_ids = list_backend_ids(allowlist=allowlist) projects = settings.to_projects_config( config_path=config_path, - engine_ids=[backend.id for backend in backends], + engine_ids=engine_ids, reserved=("cancel",), ) default_engine = _resolve_default_engine( override=default_engine_override, settings=settings, config_path=config_path, - backends=backends, + engine_ids=engine_ids, + ) + backends = _load_backends( + engine_ids=engine_ids, + allowlist=allowlist, + default_engine=default_engine, ) router = _build_router( settings=settings, @@ -285,18 +375,27 @@ def _run_auto_router( backends=backends, default_engine=default_engine, ) + transport_config = settings.transport_config( + settings.transport, config_path=config_path + ) lock_token = transport_backend.lock_token( - settings=settings, + transport_config=transport_config, config_path=config_path, ) lock_handle = acquire_config_lock(config_path, lock_token) + runtime = TransportRuntime( + router=router, + projects=projects, + allowlist=allowlist, + config_path=config_path, + plugin_configs=settings.plugins.model_extra, + ) transport_backend.build_and_run( final_notify=final_notify, default_engine_override=default_engine_override, - settings=settings, config_path=config_path, - router=router, - projects=projects, + transport_config=transport_config, + runtime=runtime, ) except ConfigError as e: typer.echo(f"error: {e}", err=True) @@ -364,8 +463,9 @@ def init( default_alias = _default_alias_from_path(project_path) alias = _prompt_alias(alias, default_alias=default_alias) - engine_ids = [backend.id for backend in list_backends()] settings = validate_settings_data(config, config_path=config_path) + allowlist = _resolve_plugins_allowlist(settings) + engine_ids = list_backend_ids(allowlist=allowlist) projects_cfg = settings.to_projects_config( config_path=config_path, engine_ids=engine_ids, @@ -414,25 +514,92 @@ def init( typer.echo(f"saved project {alias!r} to {_config_path_display(config_path)}") -def transports_cmd() -> None: - """List available transport backends.""" - ids = list_transports() - for transport_id in ids: - typer.echo(transport_id) +def _print_entrypoints( + label: str, entrypoints: list[EntryPoint], *, allowlist: set[str] | None +) -> None: + typer.echo(f"{label}:") + if not entrypoints: + typer.echo(" (none)") + return + for ep in entrypoints: + dist = entrypoint_distribution_name(ep) or "unknown" + status = "" + if allowlist is not None: + allowed = is_entrypoint_allowed(ep, allowlist) + status = " enabled" if allowed else " disabled" + typer.echo(f" {ep.name} ({dist}){status}") -app = typer.Typer( - add_completion=False, - invoke_without_command=True, - help="Run takopi with auto-router (subcommands override the default engine).", -) +def plugins_cmd( + load: bool = typer.Option( + False, + "--load/--no-load", + help="Load plugins to validate and surface import errors.", + ), +) -> None: + """List discovered plugins and optionally validate them.""" + settings_hint, _ = _load_settings_optional() + allowlist = _resolve_plugins_allowlist(settings_hint) + + allowlist_set = normalize_allowlist(allowlist) + engine_eps = list_entrypoints( + ENGINE_GROUP, + reserved_ids=RESERVED_ENGINE_IDS, + ) + transport_eps = list_entrypoints(TRANSPORT_GROUP) + command_eps = list_entrypoints( + COMMAND_GROUP, + reserved_ids=RESERVED_COMMAND_IDS, + ) + + _print_entrypoints("engine backends", engine_eps, allowlist=allowlist_set) + _print_entrypoints("transport backends", transport_eps, allowlist=allowlist_set) + _print_entrypoints("command backends", command_eps, allowlist=allowlist_set) + + if load: + for ep in engine_eps: + if allowlist_set is not None and not is_entrypoint_allowed( + ep, allowlist_set + ): + continue + try: + get_backend(ep.name, allowlist=allowlist) + except ConfigError: + continue + for ep in transport_eps: + if allowlist_set is not None and not is_entrypoint_allowed( + ep, allowlist_set + ): + continue + try: + get_transport(ep.name, allowlist=allowlist) + except ConfigError: + continue + for ep in command_eps: + if allowlist_set is not None and not is_entrypoint_allowed( + ep, allowlist_set + ): + continue + try: + get_command(ep.name, allowlist=allowlist) + except ConfigError: + continue + + errors = get_load_errors() + if errors: + typer.echo("errors:") + for err in errors: + group = err.group + if group == ENGINE_GROUP: + group = "engine" + elif group == TRANSPORT_GROUP: + group = "transport" + elif group == COMMAND_GROUP: + group = "command" + dist = err.distribution or "unknown" + typer.echo(f" {group} {err.name} ({dist}): {err.error}") -app.command(name="init")(init) -app.command(name="transports")(transports_cmd) - - -@app.callback() def app_main( ctx: typer.Context, version: bool = typer.Option( @@ -510,16 +677,43 @@ def make_engine_cmd(engine_id: str) -> Callable[..., None]: return _cmd -def register_engine_commands() -> None: - for backend in list_backends(): - help_text = f"Run with the {backend.id} engine." - app.command(name=backend.id, help=help_text)(make_engine_cmd(backend.id)) +def _engine_ids_for_cli() -> list[str]: + allowlist: list[str] | None = None + try: + config, _ = load_or_init_config() + except ConfigError: + return list_backend_ids() + raw_plugins = config.get("plugins") + if isinstance(raw_plugins, dict): + enabled = raw_plugins.get("enabled") + if isinstance(enabled, list): + allowlist = [ + value.strip() + for value in enabled + if isinstance(value, str) and value.strip() + ] + if not allowlist: + allowlist = None + return list_backend_ids(allowlist=allowlist) -register_engine_commands() +def create_app() -> typer.Typer: + app = typer.Typer( + add_completion=False, + invoke_without_command=True, + help="Run takopi with auto-router (subcommands override the default engine).", + ) + app.command(name="init")(init) + app.command(name="plugins")(plugins_cmd) + app.callback()(app_main) + for engine_id in _engine_ids_for_cli(): + help_text = f"Run with the {engine_id} engine." + app.command(name=engine_id, help=help_text)(make_engine_cmd(engine_id)) + return app def main() -> None: + app = create_app() app() diff --git a/src/takopi/commands.py b/src/takopi/commands.py new file mode 100644 index 0000000..6b8d26c --- /dev/null +++ b/src/takopi/commands.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal, Protocol, overload, runtime_checkable + +from .config import ConfigError +from .context import RunContext +from .ids import RESERVED_COMMAND_IDS +from .model import EngineId +from .plugins import ( + COMMAND_GROUP, + PluginLoadFailed, + PluginNotFound, + load_entrypoint, + list_ids, +) +from .transport import MessageRef, RenderedMessage +from .transport_runtime import TransportRuntime + +RunMode = Literal["emit", "capture"] + + +@dataclass(frozen=True, slots=True) +class RunRequest: + prompt: str + engine: EngineId | None = None + context: RunContext | None = None + + +@dataclass(frozen=True, slots=True) +class RunResult: + engine: EngineId + message: RenderedMessage | None + + +class CommandExecutor(Protocol): + async def send( + self, + message: RenderedMessage | str, + *, + reply_to: MessageRef | None = None, + notify: bool = True, + ) -> MessageRef | None: ... + + async def run_one( + self, request: RunRequest, *, mode: RunMode = "emit" + ) -> RunResult: ... + + async def run_many( + self, + requests: Sequence[RunRequest], + *, + mode: RunMode = "emit", + parallel: bool = False, + ) -> list[RunResult]: ... + + +@dataclass(frozen=True, slots=True) +class CommandContext: + command: str + text: str + args_text: str + args: tuple[str, ...] + message: MessageRef + reply_to: MessageRef | None + reply_text: str | None + config_path: Path | None + plugin_config: dict[str, Any] + runtime: TransportRuntime + executor: CommandExecutor + + +@dataclass(frozen=True, slots=True) +class CommandResult: + text: str + notify: bool = True + reply_to: MessageRef | None = None + + +@runtime_checkable +class CommandBackend(Protocol): + id: str + description: str + + async def handle(self, ctx: CommandContext) -> CommandResult | None: ... + + +def _validate_command_backend(backend: object, ep) -> None: + if not isinstance(backend, CommandBackend): + raise TypeError(f"{ep.value} is not a CommandBackend") + if backend.id != ep.name: + raise ValueError( + f"{ep.value} command id {backend.id!r} does not match entrypoint {ep.name!r}" + ) + + +@overload +def get_command( + command_id: str, + *, + allowlist: Iterable[str] | None = None, + required: Literal[True] = True, +) -> CommandBackend: ... + + +@overload +def get_command( + command_id: str, + *, + allowlist: Iterable[str] | None = None, + required: Literal[False], +) -> CommandBackend | None: ... + + +def get_command( + command_id: str, + *, + allowlist: Iterable[str] | None = None, + required: bool = True, +) -> CommandBackend | None: + if command_id.lower() in RESERVED_COMMAND_IDS: + raise ConfigError(f"Command id {command_id!r} is reserved.") + try: + backend = load_entrypoint( + COMMAND_GROUP, + command_id, + allowlist=allowlist, + validator=_validate_command_backend, + ) + except PluginNotFound as exc: + if not required: + return None + if exc.available: + available = ", ".join(exc.available) + message = f"Unknown command {command_id!r}. Available: {available}." + else: + message = f"Unknown command {command_id!r}." + raise ConfigError(message) from exc + except PluginLoadFailed as exc: + raise ConfigError(f"Failed to load command {command_id!r}: {exc}") from exc + return backend + + +def list_command_ids(*, allowlist: Iterable[str] | None = None) -> list[str]: + return list_ids( + COMMAND_GROUP, + allowlist=allowlist, + reserved_ids=RESERVED_COMMAND_IDS, + ) diff --git a/src/takopi/directives.py b/src/takopi/directives.py new file mode 100644 index 0000000..1cd1dfb --- /dev/null +++ b/src/takopi/directives.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from .config import ProjectsConfig +from .context import RunContext +from .model import EngineId + + +@dataclass(frozen=True, slots=True) +class ParsedDirectives: + prompt: str + engine: EngineId | None + project: str | None + branch: str | None + + +class DirectiveError(RuntimeError): + pass + + +def parse_directives( + text: str, + *, + engine_ids: tuple[EngineId, ...], + projects: ProjectsConfig, +) -> ParsedDirectives: + if not text: + return ParsedDirectives(prompt="", engine=None, project=None, branch=None) + + lines = text.splitlines() + idx = next((i for i, line in enumerate(lines) if line.strip()), None) + if idx is None: + return ParsedDirectives(prompt=text, engine=None, project=None, branch=None) + + line = lines[idx].lstrip() + tokens = line.split() + if not tokens: + return ParsedDirectives(prompt=text, engine=None, project=None, branch=None) + + engine_map = {engine.lower(): engine for engine in engine_ids} + project_map = {alias.lower(): alias for alias in projects.projects} + + engine: EngineId | None = None + project: str | None = None + branch: str | None = None + consumed = 0 + + for token in tokens: + if token.startswith("/"): + name = token[1:] + if "@" in name: + name = name.split("@", 1)[0] + if not name: + break + key = name.lower() + engine_candidate = engine_map.get(key) + project_candidate = project_map.get(key) + if engine_candidate is not None: + if engine is not None: + raise DirectiveError("multiple engine directives") + engine = engine_candidate + consumed += 1 + continue + if project_candidate is not None: + if project is not None: + raise DirectiveError("multiple project directives") + project = project_candidate + consumed += 1 + continue + break + if token.startswith("@"): + value = token[1:] + if not value: + break + if branch is not None: + raise DirectiveError("multiple @branch directives") + branch = value + consumed += 1 + continue + break + + if consumed == 0: + return ParsedDirectives(prompt=text, engine=None, project=None, branch=None) + + if consumed < len(tokens): + remainder = " ".join(tokens[consumed:]) + lines[idx] = remainder + else: + lines.pop(idx) + + prompt = "\n".join(lines).strip() + return ParsedDirectives( + prompt=prompt, engine=engine, project=project, branch=branch + ) + + +def parse_context_line( + text: str | None, *, projects: ProjectsConfig +) -> RunContext | None: + if not text: + return None + ctx: RunContext | None = None + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("`") and stripped.endswith("`") and len(stripped) > 1: + stripped = stripped[1:-1].strip() + elif stripped.startswith("`"): + stripped = stripped[1:].strip() + elif stripped.endswith("`"): + stripped = stripped[:-1].strip() + if not stripped.lower().startswith("ctx:"): + continue + content = stripped.split(":", 1)[1].strip() + if not content: + continue + tokens = content.split() + if not tokens: + continue + project = tokens[0] + branch = None + if len(tokens) >= 2: + if tokens[1] == "@" and len(tokens) >= 3: + branch = tokens[2] + elif tokens[1].startswith("@"): + branch = tokens[1][1:] + project_key = project.lower() + if project_key not in projects.projects: + raise DirectiveError(f"unknown project {project!r} in ctx line") + ctx = RunContext(project=project_key, branch=branch) + return ctx + + +def format_context_line( + context: RunContext | None, *, projects: ProjectsConfig +) -> str | None: + if context is None or context.project is None: + return None + project_cfg = projects.projects.get(context.project) + alias = project_cfg.alias if project_cfg is not None else context.project + if context.branch: + return f"`ctx: {alias} @ {context.branch}`" + return f"`ctx: {alias}`" diff --git a/src/takopi/engines.py b/src/takopi/engines.py index e7e449b..0065d40 100644 --- a/src/takopi/engines.py +++ b/src/takopi/engines.py @@ -1,71 +1,67 @@ from __future__ import annotations -import importlib -import pkgutil -from collections.abc import Mapping -from functools import cache -from pathlib import Path -from types import MappingProxyType -from typing import Any +from typing import Iterable -from .backends import EngineBackend, EngineConfig +from .backends import EngineBackend from .config import ConfigError +from .plugins import ( + ENGINE_GROUP, + PluginLoadFailed, + PluginNotFound, + load_entrypoint, + list_ids, +) +from .ids import RESERVED_ENGINE_IDS -def _discover_backends() -> dict[str, EngineBackend]: - import takopi.runners as runners_pkg +def _validate_engine_backend(backend: object, ep) -> None: + if not isinstance(backend, EngineBackend): + raise TypeError(f"{ep.value} is not an EngineBackend") + if backend.id != ep.name: + raise ValueError( + f"{ep.value} engine id {backend.id!r} does not match entrypoint {ep.name!r}" + ) - backends: dict[str, EngineBackend] = {} - prefix = runners_pkg.__name__ + "." - for module_info in pkgutil.iter_modules(runners_pkg.__path__, prefix): - module_name = module_info.name - mod = importlib.import_module(module_name) +def get_backend( + engine_id: str, *, allowlist: Iterable[str] | None = None +) -> EngineBackend: + if engine_id.lower() in RESERVED_ENGINE_IDS: + raise ConfigError(f"Engine id {engine_id!r} is reserved.") + try: + backend = load_entrypoint( + ENGINE_GROUP, + engine_id, + allowlist=allowlist, + validator=_validate_engine_backend, + ) + except PluginNotFound as exc: + if exc.available: + available = ", ".join(exc.available) + message = f"Unknown engine {engine_id!r}. Available: {available}." + else: + message = f"Unknown engine {engine_id!r}." + raise ConfigError(message) from exc + except PluginLoadFailed as exc: + raise ConfigError(f"Failed to load engine {engine_id!r}: {exc}") from exc + return backend - backend = getattr(mod, "BACKEND", None) - if backend is None: + +def list_backends(*, allowlist: Iterable[str] | None = None) -> list[EngineBackend]: + backends: list[EngineBackend] = [] + for engine_id in list_backend_ids(allowlist=allowlist): + try: + backends.append(get_backend(engine_id, allowlist=allowlist)) + except ConfigError: continue - if not isinstance(backend, EngineBackend): - raise RuntimeError(f"{module_name}.BACKEND is not an EngineBackend") - if backend.id in backends: - raise RuntimeError(f"Duplicate backend id: {backend.id}") - backends[backend.id] = backend - + if not backends: + raise ConfigError("No engine backends are available.") return backends -@cache -def _backends() -> Mapping[str, EngineBackend]: - backends = _discover_backends() - return MappingProxyType(backends) - - -def get_backend(engine_id: str) -> EngineBackend: - backends = _backends() - try: - return backends[engine_id] - except KeyError as exc: - available = ", ".join(sorted(backends)) - raise ConfigError( - f"Unknown engine {engine_id!r}. Available: {available}." - ) from exc - - -def list_backends() -> list[EngineBackend]: - backends = _backends() - return [backends[key] for key in sorted(backends)] - - -def list_backend_ids() -> list[str]: - return sorted(_backends()) - - -def get_engine_config( - config: dict[str, Any], engine_id: str, config_path: Path -) -> EngineConfig: - engine_cfg = config.get(engine_id) or {} - if not isinstance(engine_cfg, dict): - raise ConfigError( - f"Invalid `{engine_id}` config in {config_path}; expected a table." - ) - return engine_cfg +def list_backend_ids(*, allowlist: Iterable[str] | None = None) -> list[str]: + return list_ids( + ENGINE_GROUP, + allowlist=allowlist, + reserved_ids=RESERVED_ENGINE_IDS, + ) diff --git a/src/takopi/ids.py b/src/takopi/ids.py new file mode 100644 index 0000000..cf19a37 --- /dev/null +++ b/src/takopi/ids.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import re + +ID_PATTERN = r"^[a-z0-9_]{1,32}$" +_ID_RE = re.compile(ID_PATTERN) + +RESERVED_CLI_COMMANDS = frozenset({"init", "plugins"}) +RESERVED_CHAT_COMMANDS = frozenset({"cancel"}) +RESERVED_ENGINE_IDS = RESERVED_CLI_COMMANDS | RESERVED_CHAT_COMMANDS +RESERVED_COMMAND_IDS = RESERVED_CLI_COMMANDS | RESERVED_CHAT_COMMANDS + + +def is_valid_id(value: str) -> bool: + return bool(_ID_RE.fullmatch(value)) diff --git a/src/takopi/plugins.py b/src/takopi/plugins.py new file mode 100644 index 0000000..79cae68 --- /dev/null +++ b/src/takopi/plugins.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from importlib.metadata import EntryPoint, entry_points +from typing import Any, Callable + +from .ids import ID_PATTERN, is_valid_id + +ENGINE_GROUP = "takopi.engine_backends" +TRANSPORT_GROUP = "takopi.transport_backends" +COMMAND_GROUP = "takopi.command_backends" + + +@dataclass(frozen=True, slots=True) +class PluginLoadError: + group: str + name: str + value: str + distribution: str | None + error: str + + +class PluginLoadFailed(RuntimeError): + def __init__(self, error: PluginLoadError) -> None: + super().__init__(error.error) + self.error = error + + +class PluginNotFound(LookupError): + def __init__(self, group: str, name: str, available: Iterable[str]) -> None: + self.group = group + self.name = name + self.available = tuple(sorted(available)) + message = f"{group} plugin {name!r} not found" + if self.available: + message = f"{message}. Available: {', '.join(self.available)}." + super().__init__(message) + + +_LOAD_ERRORS: list[PluginLoadError] = [] +_LOAD_ERROR_KEYS: set[tuple[str, str, str, str | None, str]] = set() +_LOADED: dict[tuple[str, str], Any] = {} + + +def _error_key(error: PluginLoadError) -> tuple[str, str, str, str | None, str]: + return (error.group, error.name, error.value, error.distribution, error.error) + + +def _record_error(error: PluginLoadError) -> None: + key = _error_key(error) + if key in _LOAD_ERROR_KEYS: + return + _LOAD_ERROR_KEYS.add(key) + _LOAD_ERRORS.append(error) + + +def get_load_errors() -> tuple[PluginLoadError, ...]: + return tuple(_LOAD_ERRORS) + + +def clear_load_errors(*, group: str | None = None, name: str | None = None) -> None: + if group is None and name is None: + _LOAD_ERRORS.clear() + _LOAD_ERROR_KEYS.clear() + return + remaining: list[PluginLoadError] = [] + _LOAD_ERROR_KEYS.clear() + for error in _LOAD_ERRORS: + if group is not None and error.group != group: + remaining.append(error) + _LOAD_ERROR_KEYS.add(_error_key(error)) + continue + if name is not None and error.name != name: + remaining.append(error) + _LOAD_ERROR_KEYS.add(_error_key(error)) + continue + _LOAD_ERRORS[:] = remaining + + +def reset_plugin_state() -> None: + clear_load_errors() + _LOADED.clear() + + +def _select_entrypoints(group: str) -> list[EntryPoint]: + eps = entry_points() + if hasattr(eps, "select"): + return list(eps.select(group=group)) + if isinstance(eps, Mapping): + return list(eps.get(group, [])) + return [] + + +def entrypoint_distribution_name(ep: EntryPoint) -> str | None: + dist = getattr(ep, "dist", None) + if dist is None: + return None + name = getattr(dist, "name", None) + if name: + return name + metadata = getattr(dist, "metadata", None) + if metadata is None: + return None + try: + return metadata["Name"] + except Exception: + return None + + +def normalize_allowlist(allowlist: Iterable[str] | None) -> set[str] | None: + if allowlist is None: + return None + cleaned = {item.strip().lower() for item in allowlist if item and item.strip()} + return cleaned or None + + +def is_entrypoint_allowed(ep: EntryPoint, allowlist: set[str] | None) -> bool: + if allowlist is None: + return True + dist_name = entrypoint_distribution_name(ep) + if dist_name is None: + return False + return dist_name.lower() in allowlist + + +def _entrypoint_sort_key(ep: EntryPoint) -> tuple[str, str, str]: + dist = entrypoint_distribution_name(ep) or "" + return (ep.name, dist, ep.value) + + +def _normalize_reserved(reserved: Iterable[str] | None) -> set[str] | None: + if reserved is None: + return None + cleaned = {item.strip().lower() for item in reserved if item and item.strip()} + return cleaned or None + + +def _discover_entrypoints( + group: str, + *, + allowlist: Iterable[str] | None = None, + reserved_ids: Iterable[str] | None = None, +) -> tuple[dict[str, EntryPoint], dict[str, list[EntryPoint]]]: + allow = normalize_allowlist(allowlist) + reserved = _normalize_reserved(reserved_ids) + raw_eps = _select_entrypoints(group) + eps = [ep for ep in raw_eps if is_entrypoint_allowed(ep, allow)] + eps.sort(key=_entrypoint_sort_key) + + by_name: dict[str, EntryPoint] = {} + duplicates: dict[str, list[EntryPoint]] = {} + + for ep in eps: + if not is_valid_id(ep.name): + _record_error( + PluginLoadError( + group=group, + name=ep.name, + value=ep.value, + distribution=entrypoint_distribution_name(ep), + error=(f"invalid plugin id {ep.name!r}; must match {ID_PATTERN}"), + ) + ) + continue + if reserved is not None and ep.name.lower() in reserved: + _record_error( + PluginLoadError( + group=group, + name=ep.name, + value=ep.value, + distribution=entrypoint_distribution_name(ep), + error=f"reserved plugin id {ep.name!r} is not allowed", + ) + ) + continue + existing = by_name.get(ep.name) + if existing is None: + by_name[ep.name] = ep + continue + duplicates.setdefault(ep.name, [existing]).append(ep) + + for name, items in duplicates.items(): + providers = ", ".join( + sorted( + {entrypoint_distribution_name(item) or "" for item in items} + ) + ) + message = f"duplicate plugin id {name!r} from {providers}" + for item in items: + _record_error( + PluginLoadError( + group=group, + name=name, + value=item.value, + distribution=entrypoint_distribution_name(item), + error=message, + ) + ) + by_name.pop(name, None) + + return by_name, duplicates + + +def list_entrypoints( + group: str, + *, + allowlist: Iterable[str] | None = None, + reserved_ids: Iterable[str] | None = None, +) -> list[EntryPoint]: + by_name, _ = _discover_entrypoints( + group, allowlist=allowlist, reserved_ids=reserved_ids + ) + return [by_name[name] for name in sorted(by_name)] + + +def list_ids( + group: str, + *, + allowlist: Iterable[str] | None = None, + reserved_ids: Iterable[str] | None = None, +) -> list[str]: + return sorted( + ep.name + for ep in list_entrypoints( + group, allowlist=allowlist, reserved_ids=reserved_ids + ) + ) + + +def load_entrypoint( + group: str, + name: str, + *, + allowlist: Iterable[str] | None = None, + validator: Callable[[Any, EntryPoint], None] | None = None, +) -> Any: + by_name, duplicates = _discover_entrypoints(group, allowlist=allowlist) + if name in duplicates: + items = duplicates[name] + providers = ", ".join( + sorted( + {entrypoint_distribution_name(item) or "" for item in items} + ) + ) + error = PluginLoadError( + group=group, + name=name, + value=items[0].value, + distribution=entrypoint_distribution_name(items[0]), + error=f"duplicate plugin id {name!r} from {providers}", + ) + _record_error(error) + raise PluginLoadFailed(error) + + ep = by_name.get(name) + if ep is None: + raise PluginNotFound(group, name, by_name) + + key = (group, name) + if key in _LOADED: + return _LOADED[key] + + try: + loaded = ep.load() + if validator is not None: + validator(loaded, ep) + except PluginLoadFailed: + raise + except Exception as exc: + error = PluginLoadError( + group=group, + name=ep.name, + value=ep.value, + distribution=entrypoint_distribution_name(ep), + error=str(exc), + ) + _record_error(error) + raise PluginLoadFailed(error) from exc + + _LOADED[key] = loaded + clear_load_errors(group=group, name=name) + return loaded diff --git a/src/takopi/settings.py b/src/takopi/settings.py index 605e8e2..53009df 100644 --- a/src/takopi/settings.py +++ b/src/takopi/settings.py @@ -323,6 +323,20 @@ def require_telegram(settings: TakopiSettings, config_path: Path) -> tuple[str, return tg.bot_token.get_secret_value().strip(), tg.chat_id +def require_telegram_config( + config: dict[str, object], config_path: Path +) -> tuple[str, int]: + raw_token = config.get("bot_token") + if raw_token is None or not isinstance(raw_token, str) or not raw_token.strip(): + raise ConfigError(f"Missing bot token in {config_path}.") + raw_chat_id = config.get("chat_id") + if raw_chat_id is None: + raise ConfigError(f"Missing chat_id in {config_path}.") + if isinstance(raw_chat_id, bool) or not isinstance(raw_chat_id, int): + raise ConfigError(f"Invalid `chat_id` in {config_path}; expected an integer.") + return raw_token.strip(), raw_chat_id + + def _resolve_config_path(path: str | Path | None) -> Path: return Path(path).expanduser() if path else HOME_CONFIG_PATH diff --git a/src/takopi/telegram/__init__.py b/src/takopi/telegram/__init__.py index 1b583fa..b133e53 100644 --- a/src/takopi/telegram/__init__.py +++ b/src/takopi/telegram/__init__.py @@ -1,5 +1,10 @@ """Telegram-specific clients and adapters.""" from .client import parse_incoming_update, poll_incoming +from .types import TelegramIncomingMessage -__all__ = ["parse_incoming_update", "poll_incoming"] +__all__ = [ + "TelegramIncomingMessage", + "parse_incoming_update", + "poll_incoming", +] diff --git a/src/takopi/telegram/backend.py b/src/takopi/telegram/backend.py index e97de6f..718c260 100644 --- a/src/takopi/telegram/backend.py +++ b/src/takopi/telegram/backend.py @@ -6,11 +6,10 @@ from pathlib import Path import anyio from ..backends import EngineBackend -from ..config import ProjectsConfig -from ..router import AutoRouter from ..runner_bridge import ExecBridgeConfig -from ..settings import TakopiSettings, require_telegram +from ..settings import require_telegram_config from ..transports import SetupResult, TransportBackend +from ..transport_runtime import TransportRuntime from .bridge import ( TelegramBridgeConfig, TelegramPresenter, @@ -22,24 +21,22 @@ from .onboarding import check_setup, interactive_setup def _build_startup_message( - router: AutoRouter, - projects: ProjectsConfig, + runtime: TransportRuntime, *, startup_pwd: str, ) -> str: - available_engines = [entry.engine for entry in router.available_entries] - missing_engines = [entry.engine for entry in router.entries if not entry.available] + available_engines = list(runtime.available_engine_ids()) + missing_engines = list(runtime.missing_engine_ids()) engine_list = ", ".join(available_engines) if available_engines else "none" if missing_engines: engine_list = f"{engine_list} (not installed: {', '.join(missing_engines)})" project_aliases = sorted( - {project.alias for project in projects.projects.values()}, - key=str.lower, + {alias for alias in runtime.project_aliases()}, key=str.lower ) project_list = ", ".join(project_aliases) if project_aliases else "none" return ( f"\N{OCTOPUS} **takopi is ready**\n\n" - f"default: `{router.default_engine}` \n" + f"default: `{runtime.default_engine}` \n" f"agents: `{engine_list}` \n" f"projects: `{project_list}` \n" f"working in: `{startup_pwd}`" @@ -61,24 +58,25 @@ class TelegramBackend(TransportBackend): def interactive_setup(self, *, force: bool) -> bool: return interactive_setup(force=force) - def lock_token(self, *, settings: TakopiSettings, config_path: Path) -> str | None: - token, _ = require_telegram(settings, config_path) + def lock_token( + self, *, transport_config: dict[str, object], config_path: Path + ) -> str | None: + token, _ = require_telegram_config(transport_config, config_path) return token def build_and_run( self, *, - settings: TakopiSettings, + transport_config: dict[str, object], config_path: Path, - router: AutoRouter, - projects: ProjectsConfig, + runtime: TransportRuntime, final_notify: bool, default_engine_override: str | None, ) -> None: - token, chat_id = require_telegram(settings, config_path) + _ = default_engine_override + token, chat_id = require_telegram_config(transport_config, config_path) startup_msg = _build_startup_message( - router, - projects, + runtime, startup_pwd=os.getcwd(), ) bot = TelegramClient(token) @@ -91,13 +89,13 @@ class TelegramBackend(TransportBackend): ) cfg = TelegramBridgeConfig( bot=bot, - router=router, + runtime=runtime, chat_id=chat_id, startup_msg=startup_msg, exec_cfg=exec_cfg, - projects=projects, ) anyio.run(run_main_loop, cfg) telegram_backend = TelegramBackend() +BACKEND = telegram_backend diff --git a/src/takopi/telegram/bridge.py b/src/takopi/telegram/bridge.py index 7d4ac45..055f1cc 100644 --- a/src/takopi/telegram/bridge.py +++ b/src/takopi/telegram/bridge.py @@ -1,13 +1,24 @@ from __future__ import annotations -from collections.abc import AsyncIterator, Awaitable, Callable -from dataclasses import dataclass, field -import re +import shlex +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence +from dataclasses import dataclass import anyio -from ..config import ProjectsConfig, empty_projects_config +from ..commands import ( + CommandContext, + CommandExecutor, + RunMode, + RunRequest, + RunResult, + get_command, + list_command_ids, +) from ..context import RunContext +from ..config import ConfigError +from ..directives import DirectiveError +from ..ids import RESERVED_COMMAND_IDS, is_valid_id from ..runner_bridge import ( ExecBridgeConfig, IncomingMessage as RunnerIncomingMessage, @@ -19,31 +30,22 @@ from ..logging import bind_run_context, clear_context, get_logger from ..markdown import MarkdownFormatter, MarkdownParts from ..model import EngineId, ResumeToken from ..progress import ProgressState, ProgressTracker -from ..router import AutoRouter, RunnerUnavailableError +from ..router import RunnerUnavailableError from ..runner import Runner from ..scheduler import ThreadJob, ThreadScheduler -from ..transport import ( - IncomingMessage as TransportIncomingMessage, - MessageRef, - RenderedMessage, - SendOptions, - Transport, -) +from ..transport import MessageRef, RenderedMessage, SendOptions, Transport +from ..plugins import COMMAND_GROUP, list_entrypoints from ..utils.paths import reset_run_base_dir, set_run_base_dir -from ..worktrees import WorktreeError, resolve_run_cwd +from ..transport_runtime import TransportRuntime from .client import BotClient, poll_incoming +from .types import TelegramIncomingMessage from .render import prepare_telegram logger = get_logger(__name__) -_COMMAND_RE = re.compile(r"^[a-z0-9_]{1,32}$") _MAX_BOT_COMMANDS = 100 -def _is_valid_bot_command(command: str) -> bool: - return bool(_COMMAND_RE.fullmatch(command)) - - def _is_cancel_command(text: str) -> bool: stripped = text.strip() if not stripped: @@ -52,264 +54,75 @@ def _is_cancel_command(text: str) -> bool: return command == "/cancel" or command.startswith("/cancel@") -def _strip_engine_command( - text: str, *, engine_ids: tuple[EngineId, ...] -) -> tuple[str, EngineId | None]: - if not text: - return text, None - - if not engine_ids: - return text, None - - engine_map = {engine.lower(): engine for engine in engine_ids} - lines = text.splitlines() - idx = next((i for i, line in enumerate(lines) if line.strip()), None) - if idx is None: - return text, None - - line = lines[idx].lstrip() - if not line.startswith("/"): - return text, None - - parts = line.split(maxsplit=1) - command = parts[0][1:] +def _parse_slash_command(text: str) -> tuple[str | None, str]: + stripped = text.lstrip() + if not stripped.startswith("/"): + return None, text + lines = stripped.splitlines() + if not lines: + return None, text + first_line = lines[0] + token, _, rest = first_line.partition(" ") + command = token[1:] + if not command: + return None, text if "@" in command: command = command.split("@", 1)[0] - engine = engine_map.get(command.lower()) - if engine is None: - return text, None - - remainder = parts[1] if len(parts) > 1 else "" - if remainder: - lines[idx] = remainder - else: - lines.pop(idx) - return "\n".join(lines).strip(), engine + args_text = rest + if len(lines) > 1: + tail = "\n".join(lines[1:]) + args_text = f"{args_text}\n{tail}" if args_text else tail + return command.lower(), args_text -@dataclass(frozen=True, slots=True) -class ParsedDirectives: - prompt: str - engine: EngineId | None - project: str | None - branch: str | None - - -@dataclass(frozen=True, slots=True) -class ResolvedMessage: - prompt: str - resume_token: ResumeToken | None - engine_override: EngineId | None - context: RunContext | None - - -class DirectiveError(RuntimeError): - pass - - -def _parse_directives( - text: str, - *, - engine_ids: tuple[EngineId, ...], - projects: ProjectsConfig, -) -> ParsedDirectives: - if not text: - return ParsedDirectives(prompt="", engine=None, project=None, branch=None) - - lines = text.splitlines() - idx = next((i for i, line in enumerate(lines) if line.strip()), None) - if idx is None: - return ParsedDirectives(prompt=text, engine=None, project=None, branch=None) - - line = lines[idx].lstrip() - tokens = line.split() - if not tokens: - return ParsedDirectives(prompt=text, engine=None, project=None, branch=None) - - engine_map = {engine.lower(): engine for engine in engine_ids} - project_map = {alias.lower(): alias for alias in projects.projects} - - engine: EngineId | None = None - project: str | None = None - branch: str | None = None - consumed = 0 - - for token in tokens: - if token.startswith("/"): - name = token[1:] - if "@" in name: - name = name.split("@", 1)[0] - if not name: - break - key = name.lower() - engine_candidate = engine_map.get(key) - project_candidate = project_map.get(key) - if engine_candidate is not None: - if engine is not None: - raise DirectiveError("multiple engine directives") - engine = engine_candidate - consumed += 1 - continue - if project_candidate is not None: - if project is not None: - raise DirectiveError("multiple project directives") - project = project_candidate - consumed += 1 - continue - break - if token.startswith("@"): - value = token[1:] - if not value: - break - if branch is not None: - raise DirectiveError("multiple @branch directives") - branch = value - consumed += 1 - continue - break - - if consumed == 0: - return ParsedDirectives(prompt=text, engine=None, project=None, branch=None) - - if consumed < len(tokens): - remainder = " ".join(tokens[consumed:]) - lines[idx] = remainder - else: - lines.pop(idx) - - prompt = "\n".join(lines).strip() - return ParsedDirectives( - prompt=prompt, engine=engine, project=project, branch=branch - ) - - -def _parse_ctx_line(text: str | None, *, projects: ProjectsConfig) -> RunContext | None: - if not text: - return None - ctx: RunContext | None = None - for line in text.splitlines(): - stripped = line.strip() - if stripped.startswith("`") and stripped.endswith("`") and len(stripped) > 1: - stripped = stripped[1:-1].strip() - elif stripped.startswith("`"): - stripped = stripped[1:].strip() - elif stripped.endswith("`"): - stripped = stripped[:-1].strip() - if not stripped.lower().startswith("ctx:"): - continue - content = stripped.split(":", 1)[1].strip() - if not content: - continue - tokens = content.split() - if not tokens: - continue - project = tokens[0] - branch = None - if len(tokens) >= 2: - if tokens[1] == "@" and len(tokens) >= 3: - branch = tokens[2] - elif tokens[1].startswith("@"): - branch = tokens[1][1:] - project_key = project.lower() - if project_key not in projects.projects: - raise DirectiveError(f"unknown project {project!r} in ctx line") - ctx = RunContext(project=project_key, branch=branch) - return ctx - - -def _format_context_line( - context: RunContext | None, *, projects: ProjectsConfig -) -> str | None: - if context is None or context.project is None: - return None - project_cfg = projects.projects.get(context.project) - alias = project_cfg.alias if project_cfg is not None else context.project - if context.branch: - return f"`ctx: {alias} @ {context.branch}`" - return f"`ctx: {alias}`" - - -def _resolve_message( - *, - text: str, - reply_text: str | None, - router: AutoRouter, - projects: ProjectsConfig, -) -> ResolvedMessage: - directives = _parse_directives( - text, - engine_ids=router.engine_ids, - projects=projects, - ) - reply_ctx = _parse_ctx_line(reply_text, projects=projects) - resume_token = router.resolve_resume(directives.prompt, reply_text) - - if resume_token is not None: - return ResolvedMessage( - prompt=directives.prompt, - resume_token=resume_token, - engine_override=None, - context=reply_ctx, - ) - - if reply_ctx is not None: - engine_override = None - if reply_ctx.project is not None: - project = projects.projects.get(reply_ctx.project) - if project is not None and project.default_engine is not None: - engine_override = project.default_engine - return ResolvedMessage( - prompt=directives.prompt, - resume_token=None, - engine_override=engine_override, - context=reply_ctx, - ) - - project_key = directives.project - if project_key is None and projects.default_project is not None: - project_key = projects.default_project - - context = None - if project_key is not None or directives.branch is not None: - context = RunContext(project=project_key, branch=directives.branch) - - engine_override = directives.engine - if engine_override is None and project_key is not None: - project = projects.projects.get(project_key) - if project is not None and project.default_engine is not None: - engine_override = project.default_engine - - return ResolvedMessage( - prompt=directives.prompt, - resume_token=None, - engine_override=engine_override, - context=context, - ) - - -def _build_bot_commands( - router: AutoRouter, projects: ProjectsConfig -) -> list[dict[str, str]]: +def _build_bot_commands(runtime: TransportRuntime) -> list[dict[str, str]]: commands: list[dict[str, str]] = [] seen: set[str] = set() - for entry in router.available_entries: - cmd = entry.engine.lower() + for engine_id in runtime.available_engine_ids(): + cmd = engine_id.lower() if cmd in seen: continue commands.append({"command": cmd, "description": f"use agent: {cmd}"}) seen.add(cmd) - for alias, project in projects.projects.items(): + for alias in runtime.project_aliases(): cmd = alias.lower() if cmd in seen: continue - if not _is_valid_bot_command(cmd): + if not is_valid_id(cmd): logger.debug( "startup.command_menu.skip_project", - alias=project.alias, + alias=alias, ) continue commands.append({"command": cmd, "description": f"work on: {cmd}"}) seen.add(cmd) + allowlist = runtime.allowlist + for ep in list_entrypoints( + COMMAND_GROUP, + allowlist=allowlist, + reserved_ids=RESERVED_COMMAND_IDS, + ): + try: + backend = get_command(ep.name, allowlist=allowlist) + except ConfigError as exc: + logger.info( + "startup.command_menu.skip_command", + command=ep.name, + error=str(exc), + ) + continue + cmd = backend.id.lower() + if cmd in seen: + continue + if not is_valid_id(cmd): + logger.debug( + "startup.command_menu.skip_command_id", + command=cmd, + ) + continue + description = backend.description or f"command: {cmd}" + commands.append({"command": cmd, "description": description}) + seen.add(cmd) if "cancel" not in seen: commands.append({"command": "cancel", "description": "cancel run"}) if len(commands) > _MAX_BOT_COMMANDS: @@ -325,7 +138,7 @@ def _build_bot_commands( async def _set_command_menu(cfg: TelegramBridgeConfig) -> None: - commands = _build_bot_commands(cfg.router, cfg.projects) + commands = _build_bot_commands(cfg.runtime) if not commands: return try: @@ -468,11 +281,10 @@ class TelegramTransport: @dataclass(frozen=True) class TelegramBridgeConfig: bot: BotClient - router: AutoRouter + runtime: TransportRuntime chat_id: int startup_msg: str exec_cfg: ExecBridgeConfig - projects: ProjectsConfig = field(default_factory=empty_projects_config) async def _send_plain( @@ -524,7 +336,7 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int | async def poll_updates( cfg: TelegramBridgeConfig, -) -> AsyncIterator[TransportIncomingMessage]: +) -> AsyncIterator[TelegramIncomingMessage]: offset: int | None = None offset = await _drain_backlog(cfg, offset) await _send_startup(cfg) @@ -535,7 +347,7 @@ async def poll_updates( async def _handle_cancel( cfg: TelegramBridgeConfig, - msg: TransportIncomingMessage, + msg: TelegramIncomingMessage, running_tasks: RunningTasks, ) -> None: chat_id = msg.chat_id @@ -623,7 +435,7 @@ async def _send_with_resume( async def _send_runner_unavailable( - cfg: TelegramBridgeConfig, + exec_cfg: ExecBridgeConfig, *, chat_id: int, user_msg_id: int, @@ -634,30 +446,345 @@ async def _send_runner_unavailable( tracker = ProgressTracker(engine=runner.engine) tracker.set_resume(resume_token) state = tracker.snapshot(resume_formatter=runner.format_resume) - message = cfg.exec_cfg.presenter.render_final( + message = exec_cfg.presenter.render_final( state, elapsed_s=0.0, status="error", answer=f"error:\n{reason}", ) reply_to = MessageRef(channel_id=chat_id, message_id=user_msg_id) - await cfg.exec_cfg.transport.send( + await exec_cfg.transport.send( channel_id=chat_id, message=message, options=SendOptions(reply_to=reply_to, notify=True), ) +async def _run_engine( + *, + exec_cfg: ExecBridgeConfig, + runtime: TransportRuntime, + running_tasks: RunningTasks | None, + chat_id: int, + user_msg_id: int, + text: str, + resume_token: ResumeToken | None, + context: RunContext | None, + reply_ref: MessageRef | None = None, + on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] + | None = None, + engine_override: EngineId | None = None, +) -> None: + try: + try: + entry = runtime.resolve_runner( + resume_token=resume_token, + engine_override=engine_override, + ) + except RunnerUnavailableError as exc: + await _send_plain( + exec_cfg.transport, + chat_id=chat_id, + user_msg_id=user_msg_id, + text=f"error:\n{exc}", + ) + return + if not entry.available: + reason = entry.issue or "engine unavailable" + await _send_runner_unavailable( + exec_cfg, + chat_id=chat_id, + user_msg_id=user_msg_id, + resume_token=resume_token, + runner=entry.runner, + reason=reason, + ) + return + try: + cwd = runtime.resolve_run_cwd(context) + except ConfigError as exc: + await _send_plain( + exec_cfg.transport, + chat_id=chat_id, + user_msg_id=user_msg_id, + text=f"error:\n{exc}", + ) + return + run_base_token = set_run_base_dir(cwd) + try: + run_fields = { + "chat_id": chat_id, + "user_msg_id": user_msg_id, + "engine": entry.runner.engine, + "resume": resume_token.value if resume_token else None, + } + if context is not None: + run_fields["project"] = context.project + run_fields["branch"] = context.branch + if cwd is not None: + run_fields["cwd"] = str(cwd) + bind_run_context(**run_fields) + context_line = runtime.format_context_line(context) + incoming = RunnerIncomingMessage( + channel_id=chat_id, + message_id=user_msg_id, + text=text, + reply_to=reply_ref, + ) + await handle_message( + exec_cfg, + runner=entry.runner, + incoming=incoming, + resume_token=resume_token, + context=context, + context_line=context_line, + strip_resume_line=runtime.is_resume_line, + running_tasks=running_tasks, + on_thread_known=on_thread_known, + ) + finally: + reset_run_base_dir(run_base_token) + except Exception as exc: + logger.exception( + "handle.worker_failed", + error=str(exc), + error_type=exc.__class__.__name__, + ) + finally: + clear_context() + + +def _split_command_args(text: str) -> tuple[str, ...]: + if not text.strip(): + return () + try: + return tuple(shlex.split(text)) + except ValueError: + return tuple(text.split()) + + +class _CaptureTransport: + def __init__(self) -> None: + self._next_id = 1 + self.last_message: RenderedMessage | None = None + + async def send( + self, + *, + channel_id: int | str, + message: RenderedMessage, + options: SendOptions | None = None, + ) -> MessageRef: + _ = options + ref = MessageRef(channel_id=channel_id, message_id=self._next_id) + self._next_id += 1 + self.last_message = message + return ref + + async def edit( + self, *, ref: MessageRef, message: RenderedMessage, wait: bool = True + ) -> MessageRef: + _ = ref, wait + self.last_message = message + return ref + + async def delete(self, *, ref: MessageRef) -> bool: + _ = ref + return True + + async def close(self) -> None: + return None + + +class _TelegramCommandExecutor(CommandExecutor): + def __init__( + self, + *, + exec_cfg: ExecBridgeConfig, + runtime: TransportRuntime, + running_tasks: RunningTasks, + scheduler: ThreadScheduler, + chat_id: int, + user_msg_id: int, + ) -> None: + self._exec_cfg = exec_cfg + self._runtime = runtime + self._running_tasks = running_tasks + self._scheduler = scheduler + self._chat_id = chat_id + self._user_msg_id = user_msg_id + self._reply_ref = MessageRef(channel_id=chat_id, message_id=user_msg_id) + + async def send( + self, + message: RenderedMessage | str, + *, + reply_to: MessageRef | None = None, + notify: bool = True, + ) -> MessageRef | None: + rendered = ( + message + if isinstance(message, RenderedMessage) + else RenderedMessage(text=message) + ) + reply_ref = self._reply_ref if reply_to is None else reply_to + return await self._exec_cfg.transport.send( + channel_id=self._chat_id, + message=rendered, + options=SendOptions(reply_to=reply_ref, notify=notify), + ) + + async def run_one( + self, request: RunRequest, *, mode: RunMode = "emit" + ) -> RunResult: + engine = self._runtime.resolve_engine( + engine_override=request.engine, + context=request.context, + ) + if mode == "capture": + capture = _CaptureTransport() + exec_cfg = ExecBridgeConfig( + transport=capture, + presenter=self._exec_cfg.presenter, + final_notify=False, + ) + await _run_engine( + exec_cfg=exec_cfg, + runtime=self._runtime, + running_tasks={}, + chat_id=self._chat_id, + user_msg_id=self._user_msg_id, + text=request.prompt, + resume_token=None, + context=request.context, + reply_ref=self._reply_ref, + on_thread_known=None, + engine_override=engine, + ) + return RunResult(engine=engine, message=capture.last_message) + await _run_engine( + exec_cfg=self._exec_cfg, + runtime=self._runtime, + running_tasks=self._running_tasks, + chat_id=self._chat_id, + user_msg_id=self._user_msg_id, + text=request.prompt, + resume_token=None, + context=request.context, + reply_ref=self._reply_ref, + on_thread_known=self._scheduler.note_thread_known, + engine_override=engine, + ) + return RunResult(engine=engine, message=None) + + async def run_many( + self, + requests: Sequence[RunRequest], + *, + mode: RunMode = "emit", + parallel: bool = False, + ) -> list[RunResult]: + if not parallel: + return [await self.run_one(request, mode=mode) for request in requests] + results: list[RunResult | None] = [None] * len(requests) + + async with anyio.create_task_group() as tg: + + async def run_idx(idx: int, request: RunRequest) -> None: + results[idx] = await self.run_one(request, mode=mode) + + for idx, request in enumerate(requests): + tg.start_soon(run_idx, idx, request) + + return [result for result in results if result is not None] + + +async def _dispatch_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + command_id: str, + args_text: str, + running_tasks: RunningTasks, + scheduler: ThreadScheduler, +) -> None: + allowlist = cfg.runtime.allowlist + chat_id = msg.chat_id + user_msg_id = msg.message_id + reply_ref = ( + MessageRef(channel_id=chat_id, message_id=msg.reply_to_message_id) + if msg.reply_to_message_id is not None + else None + ) + executor = _TelegramCommandExecutor( + exec_cfg=cfg.exec_cfg, + runtime=cfg.runtime, + running_tasks=running_tasks, + scheduler=scheduler, + chat_id=chat_id, + user_msg_id=user_msg_id, + ) + message_ref = MessageRef(channel_id=chat_id, message_id=user_msg_id) + try: + backend = get_command(command_id, allowlist=allowlist, required=False) + except ConfigError as exc: + await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True) + return + if backend is None: + return + try: + plugin_config = cfg.runtime.plugin_config(command_id) + except ConfigError as exc: + await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True) + return + ctx = CommandContext( + command=command_id, + text=msg.text, + args_text=args_text, + args=_split_command_args(args_text), + message=message_ref, + reply_to=reply_ref, + reply_text=msg.reply_to_text, + config_path=cfg.runtime.config_path, + plugin_config=plugin_config, + runtime=cfg.runtime, + executor=executor, + ) + try: + result = await backend.handle(ctx) + except Exception as exc: + logger.exception( + "command.failed", + command=command_id, + error=str(exc), + error_type=exc.__class__.__name__, + ) + await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True) + return + if result is not None: + reply_to = message_ref if result.reply_to is None else result.reply_to + await executor.send(result.text, reply_to=reply_to, notify=result.notify) + return None + + async def run_main_loop( cfg: TelegramBridgeConfig, - poller: Callable[ - [TelegramBridgeConfig], AsyncIterator[TransportIncomingMessage] - ] = poll_updates, + poller: Callable[[TelegramBridgeConfig], AsyncIterator[TelegramIncomingMessage]] = ( + poll_updates + ), ) -> None: running_tasks: RunningTasks = {} try: await _set_command_menu(cfg) + allowlist = cfg.runtime.allowlist + command_ids = { + command_id.lower() for command_id in list_command_ids(allowlist=allowlist) + } + reserved_commands = { + *{engine.lower() for engine in cfg.runtime.engine_ids}, + *{alias.lower() for alias in cfg.runtime.project_aliases()}, + *RESERVED_COMMAND_IDS, + } async with anyio.create_task_group() as tg: async def run_job( @@ -671,86 +798,19 @@ async def run_main_loop( | None = None, engine_override: EngineId | None = None, ) -> None: - try: - try: - entry = ( - cfg.router.entry_for_engine(engine_override) - if resume_token is None - else cfg.router.entry_for(resume_token) - ) - except RunnerUnavailableError as exc: - await _send_plain( - cfg.exec_cfg.transport, - chat_id=chat_id, - user_msg_id=user_msg_id, - text=f"error:\n{exc}", - ) - return - if not entry.available: - reason = entry.issue or "engine unavailable" - await _send_runner_unavailable( - cfg, - chat_id=chat_id, - user_msg_id=user_msg_id, - resume_token=resume_token, - runner=entry.runner, - reason=reason, - ) - return - try: - cwd = resolve_run_cwd(context, projects=cfg.projects) - except WorktreeError as exc: - await _send_plain( - cfg.exec_cfg.transport, - chat_id=chat_id, - user_msg_id=user_msg_id, - text=f"error:\n{exc}", - ) - return - run_base_token = set_run_base_dir(cwd) - try: - run_fields = { - "chat_id": chat_id, - "user_msg_id": user_msg_id, - "engine": entry.runner.engine, - "resume": resume_token.value if resume_token else None, - } - if context is not None: - run_fields["project"] = context.project - run_fields["branch"] = context.branch - if cwd is not None: - run_fields["cwd"] = str(cwd) - bind_run_context(**run_fields) - context_line = _format_context_line( - context, projects=cfg.projects - ) - incoming = RunnerIncomingMessage( - channel_id=chat_id, - message_id=user_msg_id, - text=text, - reply_to=reply_ref, - ) - await handle_message( - cfg.exec_cfg, - runner=entry.runner, - incoming=incoming, - resume_token=resume_token, - context=context, - context_line=context_line, - strip_resume_line=cfg.router.is_resume_line, - running_tasks=running_tasks, - on_thread_known=on_thread_known, - ) - finally: - reset_run_base_dir(run_base_token) - except Exception as exc: - logger.exception( - "handle.worker_failed", - error=str(exc), - error_type=exc.__class__.__name__, - ) - finally: - clear_context() + await _run_engine( + exec_cfg=cfg.exec_cfg, + runtime=cfg.runtime, + running_tasks=running_tasks, + chat_id=chat_id, + user_msg_id=user_msg_id, + text=text, + resume_token=resume_token, + context=context, + reply_ref=reply_ref, + on_thread_known=on_thread_known, + engine_override=engine_override, + ) async def run_thread_job(job: ThreadJob) -> None: await run_job( @@ -779,13 +839,29 @@ async def run_main_loop( tg.start_soon(_handle_cancel, cfg, msg, running_tasks) continue + command_id, args_text = _parse_slash_command(text) + if command_id is not None and command_id not in reserved_commands: + if command_id not in command_ids: + command_ids = { + cid.lower() for cid in list_command_ids(allowlist=allowlist) + } + if command_id in command_ids: + tg.start_soon( + _dispatch_command, + cfg, + msg, + command_id, + args_text, + running_tasks, + scheduler, + ) + continue + reply_text = msg.reply_to_text try: - resolved = _resolve_message( + resolved = cfg.runtime.resolve_message( text=text, reply_text=reply_text, - router=cfg.router, - projects=cfg.projects, ) except DirectiveError as exc: await _send_plain( diff --git a/src/takopi/telegram/client.py b/src/takopi/telegram/client.py index 83f44a5..b21fe57 100644 --- a/src/takopi/telegram/client.py +++ b/src/takopi/telegram/client.py @@ -18,7 +18,7 @@ import httpx import anyio from ..logging import get_logger -from ..transport import IncomingMessage +from .types import TelegramIncomingMessage logger = get_logger(__name__) @@ -45,7 +45,7 @@ def is_group_chat_id(chat_id: int) -> bool: def parse_incoming_update( update: dict[str, Any], *, chat_id: int -) -> IncomingMessage | None: +) -> TelegramIncomingMessage | None: msg = update.get("message") if not isinstance(msg, dict): return None @@ -79,7 +79,7 @@ def parse_incoming_update( if isinstance(sender, dict) and isinstance(sender.get("id"), int) else None ) - return IncomingMessage( + return TelegramIncomingMessage( transport="telegram", chat_id=msg_chat_id, message_id=message_id, @@ -96,7 +96,7 @@ async def poll_incoming( *, chat_id: int, offset: int | None = None, -) -> AsyncIterator[IncomingMessage]: +) -> AsyncIterator[TelegramIncomingMessage]: while True: updates = await bot.get_updates( offset=offset, timeout_s=50, allowed_updates=["message"] diff --git a/src/takopi/telegram/types.py b/src/takopi/telegram/types.py new file mode 100644 index 0000000..5425702 --- /dev/null +++ b/src/takopi/telegram/types.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class TelegramIncomingMessage: + transport: str + chat_id: int + message_id: int + text: str + reply_to_message_id: int | None + reply_to_text: str | None + sender_id: int | None + raw: dict[str, Any] | None = None diff --git a/src/takopi/transport.py b/src/takopi/transport.py index fab11d8..b0a8789 100644 --- a/src/takopi/transport.py +++ b/src/takopi/transport.py @@ -7,18 +7,6 @@ ChannelId: TypeAlias = int | str MessageId: TypeAlias = int | str -@dataclass(frozen=True, slots=True) -class IncomingMessage: - transport: str - chat_id: int - message_id: int - text: str - reply_to_message_id: int | None - reply_to_text: str | None - sender_id: int | None - raw: dict[str, Any] | None = None - - @dataclass(frozen=True, slots=True) class MessageRef: channel_id: ChannelId diff --git a/src/takopi/transport_runtime.py b/src/takopi/transport_runtime.py new file mode 100644 index 0000000..455299e --- /dev/null +++ b/src/takopi/transport_runtime.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .config import ConfigError, ProjectsConfig +from .context import RunContext +from .directives import format_context_line, parse_context_line, parse_directives +from .model import EngineId, ResumeToken +from .plugins import normalize_allowlist +from .router import AutoRouter +from .runner import Runner +from .worktrees import WorktreeError, resolve_run_cwd + + +@dataclass(frozen=True, slots=True) +class ResolvedMessage: + prompt: str + resume_token: ResumeToken | None + engine_override: EngineId | None + context: RunContext | None + + +@dataclass(frozen=True, slots=True) +class ResolvedRunner: + engine: EngineId + runner: Runner + available: bool + issue: str | None = None + + +class TransportRuntime: + __slots__ = ( + "_router", + "_projects", + "_allowlist", + "_config_path", + "_plugin_configs", + ) + + def __init__( + self, + *, + router: AutoRouter, + projects: ProjectsConfig, + allowlist: Iterable[str] | None = None, + config_path: Path | None = None, + plugin_configs: Mapping[str, Any] | None = None, + ) -> None: + self._router = router + self._projects = projects + self._allowlist = normalize_allowlist(allowlist) + self._config_path = config_path + self._plugin_configs = dict(plugin_configs or {}) + + @property + def default_engine(self) -> EngineId: + return self._router.default_engine + + def resolve_engine( + self, + *, + engine_override: EngineId | None, + context: RunContext | None, + ) -> EngineId: + if engine_override is not None: + return engine_override + if context is None or context.project is None: + return self._router.default_engine + project = self._projects.projects.get(context.project) + if project is None: + return self._router.default_engine + return project.default_engine or self._router.default_engine + + @property + def engine_ids(self) -> tuple[EngineId, ...]: + return self._router.engine_ids + + def available_engine_ids(self) -> tuple[EngineId, ...]: + return tuple(entry.engine for entry in self._router.available_entries) + + def missing_engine_ids(self) -> tuple[EngineId, ...]: + return tuple( + entry.engine for entry in self._router.entries if not entry.available + ) + + def project_aliases(self) -> tuple[str, ...]: + return tuple(project.alias for project in self._projects.projects.values()) + + @property + def allowlist(self) -> set[str] | None: + return self._allowlist + + @property + def config_path(self) -> Path | None: + return self._config_path + + def plugin_config(self, plugin_id: str) -> dict[str, Any]: + if not self._plugin_configs: + return {} + raw = self._plugin_configs.get(plugin_id) + if raw is None: + return {} + if not isinstance(raw, dict): + path = self._config_path or Path("") + raise ConfigError( + f"Invalid `plugins.{plugin_id}` in {path}; expected a table." + ) + return dict(raw) + + def resolve_message(self, *, text: str, reply_text: str | None) -> ResolvedMessage: + directives = parse_directives( + text, + engine_ids=self._router.engine_ids, + projects=self._projects, + ) + reply_ctx = parse_context_line(reply_text, projects=self._projects) + resume_token = self._router.resolve_resume(directives.prompt, reply_text) + + if resume_token is not None: + return ResolvedMessage( + prompt=directives.prompt, + resume_token=resume_token, + engine_override=None, + context=reply_ctx, + ) + + if reply_ctx is not None: + engine_override = None + if reply_ctx.project is not None: + project = self._projects.projects.get(reply_ctx.project) + if project is not None and project.default_engine is not None: + engine_override = project.default_engine + return ResolvedMessage( + prompt=directives.prompt, + resume_token=None, + engine_override=engine_override, + context=reply_ctx, + ) + + project_key = directives.project + if project_key is None and self._projects.default_project is not None: + project_key = self._projects.default_project + + context = None + if project_key is not None or directives.branch is not None: + context = RunContext(project=project_key, branch=directives.branch) + + engine_override = directives.engine + if engine_override is None and project_key is not None: + project = self._projects.projects.get(project_key) + if project is not None and project.default_engine is not None: + engine_override = project.default_engine + + return ResolvedMessage( + prompt=directives.prompt, + resume_token=None, + engine_override=engine_override, + context=context, + ) + + def resolve_runner( + self, + *, + resume_token: ResumeToken | None, + engine_override: EngineId | None, + ) -> ResolvedRunner: + entry = ( + self._router.entry_for_engine(engine_override) + if resume_token is None + else self._router.entry_for(resume_token) + ) + return ResolvedRunner( + engine=entry.engine, + runner=entry.runner, + available=entry.available, + issue=entry.issue, + ) + + def is_resume_line(self, line: str) -> bool: + return self._router.is_resume_line(line) + + def resolve_run_cwd(self, context: RunContext | None) -> Path | None: + try: + return resolve_run_cwd(context, projects=self._projects) + except WorktreeError as exc: + raise ConfigError(str(exc)) from exc + + def format_context_line(self, context: RunContext | None) -> str | None: + return format_context_line(context, projects=self._projects) diff --git a/src/takopi/transports.py b/src/takopi/transports.py index e05f2ad..111ddc5 100644 --- a/src/takopi/transports.py +++ b/src/takopi/transports.py @@ -2,12 +2,18 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path -from typing import Protocol +from typing import Iterable, Protocol, runtime_checkable from .backends import EngineBackend, SetupIssue -from .config import ConfigError, ProjectsConfig -from .router import AutoRouter -from .settings import TakopiSettings +from .config import ConfigError +from .plugins import ( + PluginLoadFailed, + PluginNotFound, + TRANSPORT_GROUP, + load_entrypoint, + list_ids, +) +from .transport_runtime import TransportRuntime @dataclass(frozen=True, slots=True) @@ -20,6 +26,7 @@ class SetupResult: return not self.issues +@runtime_checkable class TransportBackend(Protocol): id: str description: str @@ -34,53 +41,50 @@ class TransportBackend(Protocol): def interactive_setup(self, *, force: bool) -> bool: ... def lock_token( - self, *, settings: TakopiSettings, config_path: Path + self, *, transport_config: dict[str, object], config_path: Path ) -> str | None: ... def build_and_run( self, *, - settings: TakopiSettings, + transport_config: dict[str, object], config_path: Path, - router: AutoRouter, - projects: ProjectsConfig, + runtime: TransportRuntime, final_notify: bool, default_engine_override: str | None, ) -> None: ... -_registry: dict[str, TransportBackend] = {} -_builtins_loaded = False +def _validate_transport_backend(backend: object, ep) -> None: + if not isinstance(backend, TransportBackend): + raise TypeError(f"{ep.value} is not a TransportBackend") + if backend.id != ep.name: + raise ValueError( + f"{ep.value} transport id {backend.id!r} does not match entrypoint {ep.name!r}" + ) -def register_transport(backend: TransportBackend) -> None: - existing = _registry.get(backend.id) - if existing is not None and existing is not backend: - raise ConfigError(f"Transport {backend.id!r} is already registered.") - _registry[backend.id] = backend - - -def register_builtin_transports() -> None: - global _builtins_loaded - if _builtins_loaded: - return - from .telegram.backend import telegram_backend - - register_transport(telegram_backend) - _builtins_loaded = True - - -def get_transport(transport_id: str) -> TransportBackend: - register_builtin_transports() +def get_transport( + transport_id: str, *, allowlist: Iterable[str] | None = None +) -> TransportBackend: try: - return _registry[transport_id] - except KeyError: - available = ", ".join(sorted(_registry)) - raise ConfigError( - f"Unknown transport {transport_id!r}. Available: {available}." - ) from None + backend = load_entrypoint( + TRANSPORT_GROUP, + transport_id, + allowlist=allowlist, + validator=_validate_transport_backend, + ) + except PluginNotFound as exc: + if exc.available: + available = ", ".join(exc.available) + message = f"Unknown transport {transport_id!r}. Available: {available}." + else: + message = f"Unknown transport {transport_id!r}." + raise ConfigError(message) from exc + except PluginLoadFailed as exc: + raise ConfigError(f"Failed to load transport {transport_id!r}: {exc}") from exc + return backend -def list_transports() -> list[str]: - register_builtin_transports() - return sorted(_registry) +def list_transports(*, allowlist: Iterable[str] | None = None) -> list[str]: + return list_ids(TRANSPORT_GROUP, allowlist=allowlist) diff --git a/tests/conftest.py b/tests/conftest.py index 3fa616b..34383ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,3 +9,10 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) @pytest.fixture def anyio_backend() -> str: return "asyncio" + + +@pytest.fixture(autouse=True) +def reset_plugins_state() -> None: + import takopi.plugins as plugins + + plugins.reset_plugin_state() diff --git a/tests/plugin_fixtures.py b/tests/plugin_fixtures.py new file mode 100644 index 0000000..ded3273 --- /dev/null +++ b/tests/plugin_fixtures.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Iterable + + +@dataclass(frozen=True, slots=True) +class FakeDist: + name: str + + +class FakeEntryPoint: + def __init__( + self, + name: str, + value: str, + group: str, + *, + loader: Callable[[], Any] | None = None, + dist_name: str | None = "takopi", + ) -> None: + self.name = name + self.value = value + self.group = group + self._loader = loader or (lambda: None) + self.dist = FakeDist(dist_name) if dist_name else None + + def load(self) -> Any: + return self._loader() + + +class FakeEntryPoints(list): + def select(self, *, group: str) -> list[FakeEntryPoint]: + return [ep for ep in self if ep.group == group] + + def get(self, group: str, default: Iterable[Any] | None = None) -> list[Any]: + _ = default + return [ep for ep in self if ep.group == group] + + +def install_entrypoints(monkeypatch, entrypoints: Iterable[FakeEntryPoint]) -> None: + from takopi import plugins + + def _entry_points() -> FakeEntryPoints: + return FakeEntryPoints(entrypoints) + + monkeypatch.setattr(plugins, "entry_points", _entry_points) diff --git a/tests/test_command_registry.py b/tests/test_command_registry.py new file mode 100644 index 0000000..b86fd76 --- /dev/null +++ b/tests/test_command_registry.py @@ -0,0 +1,47 @@ +import pytest + +from takopi import commands, plugins +from takopi.config import ConfigError +from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints + + +class DummyCommand: + id = "hello" + description = "Hello command" + + async def handle(self, ctx): + _ = ctx + return None + + +@pytest.fixture +def command_entrypoints(monkeypatch): + entrypoints = [ + FakeEntryPoint( + "hello", + "takopi.commands.hello:BACKEND", + plugins.COMMAND_GROUP, + loader=DummyCommand, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + return entrypoints + + +def test_command_registry_lists_ids(command_entrypoints) -> None: + ids = commands.list_command_ids() + assert "hello" in ids + + +def test_command_registry_gets_command(command_entrypoints) -> None: + backend = commands.get_command("hello") + assert backend.id == "hello" + + +def test_command_registry_unknown(command_entrypoints) -> None: + with pytest.raises(ConfigError, match="Unknown command"): + commands.get_command("nope") + + +def test_command_registry_optional_missing(command_entrypoints) -> None: + assert commands.get_command("nope", required=False) is None diff --git a/tests/test_engine_discovery.py b/tests/test_engine_discovery.py index 2aae74c..2fb862a 100644 --- a/tests/test_engine_discovery.py +++ b/tests/test_engine_discovery.py @@ -1,28 +1,57 @@ from typing import cast +import pytest + import click import typer -from takopi import cli, engines +from takopi import cli, engines, plugins +from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints -def test_engine_discovery_skips_non_backend() -> None: +@pytest.fixture +def engine_entrypoints(monkeypatch): + entrypoints = [ + FakeEntryPoint( + "codex", + "takopi.runners.codex:BACKEND", + plugins.ENGINE_GROUP, + ), + FakeEntryPoint( + "claude", + "takopi.runners.claude:BACKEND", + plugins.ENGINE_GROUP, + ), + FakeEntryPoint( + "bad-id", + "takopi.runners.bad:BACKEND", + plugins.ENGINE_GROUP, + ), + ] + install_entrypoints(monkeypatch, entrypoints) + monkeypatch.setattr(cli, "_load_settings_optional", lambda: (None, None)) + return entrypoints + + +def test_engine_discovery_filters_invalid_ids(engine_entrypoints) -> None: ids = engines.list_backend_ids() - assert "codex" in ids - assert "claude" in ids - assert "mock" not in ids + assert ids == ["claude", "codex"] -def test_cli_registers_engine_commands_sorted() -> None: - command_names = [cmd.name for cmd in cli.app.registered_commands] +def test_cli_registers_engine_commands_sorted(engine_entrypoints) -> None: + app = cli.create_app() + command_names = [cmd.name for cmd in app.registered_commands] engine_ids = engines.list_backend_ids() assert set(engine_ids) <= set(command_names) engine_commands = [name for name in command_names if name in engine_ids] assert engine_commands == engine_ids -def test_engine_commands_do_not_expose_engine_id_option() -> None: - group = cast(click.Group, typer.main.get_command(cli.app)) +def test_engine_commands_do_not_expose_engine_id_option( + engine_entrypoints, +) -> None: + app = cli.create_app() + group = cast(click.Group, typer.main.get_command(app)) engine_ids = engines.list_backend_ids() ctx = group.make_context("takopi", []) diff --git a/tests/test_plugins.py b/tests/test_plugins.py new file mode 100644 index 0000000..332ad8f --- /dev/null +++ b/tests/test_plugins.py @@ -0,0 +1,184 @@ +import pytest + +from takopi import plugins +from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints + + +def test_list_ids_does_not_load_entrypoints(monkeypatch) -> None: + calls = {"count": 0} + + def loader(): + calls["count"] += 1 + return object() + + entrypoints = [ + FakeEntryPoint( + "codex", + "takopi.runners.codex:BACKEND", + plugins.ENGINE_GROUP, + loader=loader, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + + ids = plugins.list_ids(plugins.ENGINE_GROUP) + assert ids == ["codex"] + assert calls["count"] == 0 + + +def test_load_entrypoint_records_errors(monkeypatch) -> None: + def loader(): + raise RuntimeError("boom") + + entrypoints = [ + FakeEntryPoint( + "broken", + "takopi.runners.broken:BACKEND", + plugins.ENGINE_GROUP, + loader=loader, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + + with pytest.raises(plugins.PluginLoadFailed): + plugins.load_entrypoint(plugins.ENGINE_GROUP, "broken") + + errors = plugins.get_load_errors() + assert errors + assert errors[0].name == "broken" + assert "boom" in errors[0].error + + +def test_duplicate_entrypoints_are_rejected(monkeypatch) -> None: + entrypoints = [ + FakeEntryPoint( + "dup", + "takopi.runners.one:BACKEND", + plugins.ENGINE_GROUP, + dist_name="one", + ), + FakeEntryPoint( + "dup", + "takopi.runners.two:BACKEND", + plugins.ENGINE_GROUP, + dist_name="two", + ), + ] + install_entrypoints(monkeypatch, entrypoints) + + ids = plugins.list_ids(plugins.ENGINE_GROUP) + assert ids == [] + + with pytest.raises(plugins.PluginLoadFailed): + plugins.load_entrypoint(plugins.ENGINE_GROUP, "dup") + + errors = plugins.get_load_errors() + assert any("duplicate plugin id" in err.error for err in errors) + + +def test_allowlist_filters_by_distribution(monkeypatch) -> None: + entrypoints = [ + FakeEntryPoint( + "codex", + "takopi.runners.codex:BACKEND", + plugins.ENGINE_GROUP, + dist_name="takopi", + ), + FakeEntryPoint( + "thirdparty", + "takopi_thirdparty.backend:BACKEND", + plugins.ENGINE_GROUP, + dist_name="takopi-thirdparty", + ), + ] + install_entrypoints(monkeypatch, entrypoints) + + ids = plugins.list_ids(plugins.ENGINE_GROUP, allowlist=["takopi"]) + assert ids == ["codex"] + + +def test_validator_errors_are_captured(monkeypatch) -> None: + entrypoints = [ + FakeEntryPoint( + "bad", + "takopi.runners.bad:BACKEND", + plugins.ENGINE_GROUP, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + + def validator(obj, ep): + raise TypeError("not valid") + + with pytest.raises(plugins.PluginLoadFailed): + plugins.load_entrypoint(plugins.ENGINE_GROUP, "bad", validator=validator) + + errors = plugins.get_load_errors() + assert any("not valid" in err.error for err in errors) + + +def test_reset_plugin_state_clears_cache(monkeypatch) -> None: + calls = {"count": 0} + + def loader(): + calls["count"] += 1 + return object() + + entrypoints = [ + FakeEntryPoint( + "codex", + "takopi.runners.codex:BACKEND", + plugins.ENGINE_GROUP, + loader=loader, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + + plugins.load_entrypoint(plugins.ENGINE_GROUP, "codex") + plugins.load_entrypoint(plugins.ENGINE_GROUP, "codex") + assert calls["count"] == 1 + + plugins.reset_plugin_state() + plugins.load_entrypoint(plugins.ENGINE_GROUP, "codex") + assert calls["count"] == 2 + + +def test_clear_load_errors_filters(monkeypatch) -> None: + def loader(): + raise RuntimeError("boom") + + entrypoints = [ + FakeEntryPoint( + "broken_engine", + "takopi.runners.broken:BACKEND", + plugins.ENGINE_GROUP, + loader=loader, + dist_name="engine-dist", + ), + FakeEntryPoint( + "broken_transport", + "takopi.transports.broken:BACKEND", + plugins.TRANSPORT_GROUP, + loader=loader, + dist_name="transport-dist", + ), + ] + install_entrypoints(monkeypatch, entrypoints) + + with pytest.raises(plugins.PluginLoadFailed): + plugins.load_entrypoint(plugins.ENGINE_GROUP, "broken_engine") + with pytest.raises(plugins.PluginLoadFailed): + plugins.load_entrypoint(plugins.TRANSPORT_GROUP, "broken_transport") + + errors = plugins.get_load_errors() + assert {err.group for err in errors} == { + plugins.ENGINE_GROUP, + plugins.TRANSPORT_GROUP, + } + + plugins.clear_load_errors(group=plugins.ENGINE_GROUP) + errors = plugins.get_load_errors() + assert {err.group for err in errors} == {plugins.TRANSPORT_GROUP} + + plugins.clear_load_errors(name="broken_transport") + assert plugins.get_load_errors() == () diff --git a/tests/test_projects_config.py b/tests/test_projects_config.py index 0c330b1..b718683 100644 --- a/tests/test_projects_config.py +++ b/tests/test_projects_config.py @@ -35,13 +35,14 @@ def test_init_writes_project(monkeypatch, tmp_path) -> None: config_path = tmp_path / "takopi.toml" monkeypatch.setattr("takopi.config.HOME_CONFIG_PATH", config_path) monkeypatch.setattr(cli, "resolve_default_base", lambda _: "main") + monkeypatch.setattr(cli, "_load_settings_optional", lambda: (None, None)) repo_path = tmp_path / "repo" repo_path.mkdir() monkeypatch.chdir(repo_path) runner = CliRunner() - result = runner.invoke(cli.app, ["init", "z80"]) + result = runner.invoke(cli.create_app(), ["init", "z80"]) assert result.exit_code == 0 saved = config_path.read_text(encoding="utf-8") @@ -56,13 +57,14 @@ def test_init_migrates_legacy_config(monkeypatch, tmp_path) -> None: config_path.write_text('bot_token = "token"\nchat_id = 123\n', encoding="utf-8") monkeypatch.setattr("takopi.config.HOME_CONFIG_PATH", config_path) monkeypatch.setattr(cli, "resolve_default_base", lambda _: "main") + monkeypatch.setattr(cli, "_load_settings_optional", lambda: (None, None)) repo_path = tmp_path / "repo" repo_path.mkdir() monkeypatch.chdir(repo_path) runner = CliRunner() - result = runner.invoke(cli.app, ["init", "z80"]) + result = runner.invoke(cli.create_app(), ["init", "z80"]) assert result.exit_code == 0 raw = read_raw_toml(config_path) diff --git a/tests/test_telegram_bridge.py b/tests/test_telegram_bridge.py index 69851da..629511a 100644 --- a/tests/test_telegram_bridge.py +++ b/tests/test_telegram_bridge.py @@ -3,15 +3,16 @@ from pathlib import Path import anyio import pytest +from takopi import commands, plugins +import takopi.telegram.bridge as bridge +from takopi.directives import parse_directives from takopi.telegram.bridge import ( TelegramBridgeConfig, TelegramTransport, _build_bot_commands, _handle_cancel, _is_cancel_command, - _resolve_message, _send_with_resume, - _strip_engine_command, run_main_loop, ) from takopi.context import RunContext @@ -20,8 +21,11 @@ from takopi.runner_bridge import ExecBridgeConfig, RunningTask from takopi.markdown import MarkdownPresenter from takopi.model import EngineId, ResumeToken from takopi.router import AutoRouter, RunnerEntry +from takopi.transport_runtime import TransportRuntime from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait -from takopi.transport import IncomingMessage, MessageRef, RenderedMessage, SendOptions +from takopi.telegram.types import TelegramIncomingMessage +from takopi.transport import MessageRef, RenderedMessage, SendOptions +from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints CODEX_ENGINE = EngineId("codex") @@ -185,59 +189,78 @@ def _make_cfg( presenter=MarkdownPresenter(), final_notify=True, ) + runtime = TransportRuntime( + router=_make_router(runner), + projects=empty_projects_config(), + ) return TelegramBridgeConfig( bot=_FakeBot(), - router=_make_router(runner), + runtime=runtime, chat_id=123, startup_msg="", exec_cfg=exec_cfg, ) -def test_strip_engine_command_inline() -> None: - text, engine = _strip_engine_command( - "/claude do it", engine_ids=("codex", "claude") +def test_parse_directives_inline_engine() -> None: + directives = parse_directives( + "/claude do it", + engine_ids=("codex", "claude"), + projects=empty_projects_config(), ) - assert engine == "claude" - assert text == "do it" + assert directives.engine == "claude" + assert directives.prompt == "do it" -def test_strip_engine_command_newline() -> None: - text, engine = _strip_engine_command( - "/codex\nhello", engine_ids=("codex", "claude") +def test_parse_directives_newline() -> None: + directives = parse_directives( + "/codex\nhello", + engine_ids=("codex", "claude"), + projects=empty_projects_config(), ) - assert engine == "codex" - assert text == "hello" + assert directives.engine == "codex" + assert directives.prompt == "hello" -def test_strip_engine_command_ignores_unknown() -> None: - text, engine = _strip_engine_command("/unknown hi", engine_ids=("codex", "claude")) - assert engine is None - assert text == "/unknown hi" - - -def test_strip_engine_command_bot_suffix() -> None: - text, engine = _strip_engine_command( - "/claude@bunny_agent_bot hi", engine_ids=("claude",) +def test_parse_directives_ignores_unknown() -> None: + directives = parse_directives( + "/unknown hi", + engine_ids=("codex", "claude"), + projects=empty_projects_config(), ) - assert engine == "claude" - assert text == "hi" + assert directives.engine is None + assert directives.prompt == "/unknown hi" -def test_strip_engine_command_only_first_non_empty_line() -> None: - text, engine = _strip_engine_command( - "hello\n/claude hi", engine_ids=("codex", "claude") +def test_parse_directives_bot_suffix() -> None: + directives = parse_directives( + "/claude@bunny_agent_bot hi", + engine_ids=("claude",), + projects=empty_projects_config(), ) - assert engine is None - assert text == "hello\n/claude hi" + assert directives.engine == "claude" + assert directives.prompt == "hi" + + +def test_parse_directives_only_first_non_empty_line() -> None: + directives = parse_directives( + "hello\n/claude hi", + engine_ids=("codex", "claude"), + projects=empty_projects_config(), + ) + assert directives.engine is None + assert directives.prompt == "hello\n/claude hi" def test_build_bot_commands_includes_cancel_and_engine() -> None: runner = ScriptRunner( [Return(answer="ok")], engine=CODEX_ENGINE, resume_value="sid" ) - router = _make_router(runner) - commands = _build_bot_commands(router, empty_projects_config()) + runtime = TransportRuntime( + router=_make_router(runner), + projects=empty_projects_config(), + ) + commands = _build_bot_commands(runtime) assert {"command": "cancel", "description": "cancel run"} in commands assert any(cmd["command"] == "codex" for cmd in commands) @@ -264,12 +287,42 @@ def test_build_bot_commands_includes_projects() -> None: default_project=None, ) - commands = _build_bot_commands(router, projects) + runtime = TransportRuntime(router=router, projects=projects) + commands = _build_bot_commands(runtime) assert any(cmd["command"] == "good" for cmd in commands) assert not any(cmd["command"] == "bad-name" for cmd in commands) +def test_build_bot_commands_includes_command_plugins(monkeypatch) -> None: + class _Command: + id = "pingcmd" + description = "ping command" + + async def handle(self, ctx): + _ = ctx + return None + + entrypoints = [ + FakeEntryPoint( + "pingcmd", + "takopi.commands.ping:BACKEND", + plugins.COMMAND_GROUP, + loader=_Command, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) + runtime = TransportRuntime( + router=_make_router(runner), + projects=empty_projects_config(), + ) + + commands_list = _build_bot_commands(runtime) + + assert {"command": "pingcmd", "description": "ping command"} in commands_list + + def test_build_bot_commands_caps_total() -> None: runner = ScriptRunner( [Return(answer="ok")], engine=CODEX_ENGINE, resume_value="sid" @@ -287,7 +340,8 @@ def test_build_bot_commands_caps_total() -> None: default_project=None, ) - commands = _build_bot_commands(router, projects) + runtime = TransportRuntime(router=router, projects=projects) + commands = _build_bot_commands(runtime) assert len(commands) == 100 assert any(cmd["command"] == "codex" for cmd in commands) @@ -410,7 +464,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: async def test_handle_cancel_without_reply_prompts_user() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) - msg = IncomingMessage( + msg = TelegramIncomingMessage( transport="telegram", chat_id=123, message_id=10, @@ -431,7 +485,7 @@ async def test_handle_cancel_without_reply_prompts_user() -> None: async def test_handle_cancel_with_no_progress_message_says_nothing_running() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) - msg = IncomingMessage( + msg = TelegramIncomingMessage( transport="telegram", chat_id=123, message_id=10, @@ -453,7 +507,7 @@ async def test_handle_cancel_with_finished_task_says_nothing_running() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) progress_id = 99 - msg = IncomingMessage( + msg = TelegramIncomingMessage( transport="telegram", chat_id=123, message_id=10, @@ -475,7 +529,7 @@ async def test_handle_cancel_cancels_running_task() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) progress_id = 42 - msg = IncomingMessage( + msg = TelegramIncomingMessage( transport="telegram", chat_id=123, message_id=10, @@ -499,7 +553,7 @@ async def test_handle_cancel_only_cancels_matching_progress_message() -> None: cfg = _make_cfg(transport) task_first = RunningTask() task_second = RunningTask() - msg = IncomingMessage( + msg = TelegramIncomingMessage( transport="telegram", chat_id=123, message_id=10, @@ -527,23 +581,22 @@ def test_cancel_command_accepts_extra_text() -> None: def test_resolve_message_accepts_backticked_ctx_line() -> None: - router = _make_router(ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)) - projects = ProjectsConfig( - projects={ - "takopi": ProjectConfig( - alias="takopi", - path=Path("."), - worktrees_dir=Path(".worktrees"), - ) - }, - default_project=None, + runtime = TransportRuntime( + router=_make_router(ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)), + projects=ProjectsConfig( + projects={ + "takopi": ProjectConfig( + alias="takopi", + path=Path("."), + worktrees_dir=Path(".worktrees"), + ) + }, + default_project=None, + ), ) - - resolved = _resolve_message( + resolved = runtime.resolve_message( text="do it", reply_text="`ctx: takopi @ feat/api`", - router=router, - projects=projects, ) assert resolved.prompt == "do it" @@ -643,16 +696,20 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None: presenter=MarkdownPresenter(), final_notify=True, ) + runtime = TransportRuntime( + router=_make_router(runner), + projects=empty_projects_config(), + ) cfg = TelegramBridgeConfig( bot=bot, - router=_make_router(runner), + runtime=runtime, chat_id=123, startup_msg="", exec_cfg=exec_cfg, ) async def poller(_cfg: TelegramBridgeConfig): - yield IncomingMessage( + yield TelegramIncomingMessage( transport="telegram", chat_id=123, message_id=1, @@ -666,7 +723,7 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None: assert isinstance(transport.progress_ref.message_id, int) reply_id = transport.progress_ref.message_id reply_ready.set() - yield IncomingMessage( + yield TelegramIncomingMessage( transport="telegram", chat_id=123, message_id=2, @@ -694,3 +751,212 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None: hold.set() stop_polling.set() tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_run_main_loop_handles_command_plugins(monkeypatch) -> None: + class _Command: + id = "echo_cmd" + description = "echo" + + async def handle(self, ctx): + return commands.CommandResult(text=f"echo:{ctx.args_text}") + + entrypoints = [ + FakeEntryPoint( + "echo_cmd", + "takopi.commands.echo:BACKEND", + plugins.COMMAND_GROUP, + loader=_Command, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + + transport = _FakeTransport() + bot = _FakeBot() + runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) + exec_cfg = ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ) + runtime = TransportRuntime( + router=_make_router(runner), + projects=empty_projects_config(), + ) + cfg = TelegramBridgeConfig( + bot=bot, + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=exec_cfg, + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=1, + text="/echo_cmd hello", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + ) + + await run_main_loop(cfg, poller) + + assert runner.calls == [] + assert transport.send_calls + assert transport.send_calls[-1]["message"].text == "echo:hello" + + +@pytest.mark.anyio +async def test_run_main_loop_command_uses_project_default_engine( + monkeypatch, +) -> None: + class _Command: + id = "use_project" + description = "use project default" + + async def handle(self, ctx): + result = await ctx.executor.run_one( + commands.RunRequest( + prompt="hello", + context=RunContext(project="proj"), + ), + mode="capture", + ) + return commands.CommandResult(text=f"ran:{result.engine}") + + entrypoints = [ + FakeEntryPoint( + "use_project", + "takopi.commands.use_project:BACKEND", + plugins.COMMAND_GROUP, + loader=_Command, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + + transport = _FakeTransport() + bot = _FakeBot() + codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) + pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi")) + router = AutoRouter( + entries=[ + RunnerEntry(engine=codex_runner.engine, runner=codex_runner), + RunnerEntry(engine=pi_runner.engine, runner=pi_runner), + ], + default_engine=codex_runner.engine, + ) + projects = ProjectsConfig( + projects={ + "proj": ProjectConfig( + alias="proj", + path=Path("."), + worktrees_dir=Path(".worktrees"), + default_engine=pi_runner.engine, + ) + }, + default_project=None, + ) + runtime = TransportRuntime( + router=router, + projects=projects, + ) + exec_cfg = ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ) + cfg = TelegramBridgeConfig( + bot=bot, + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=exec_cfg, + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=1, + text="/use_project", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + ) + + await run_main_loop(cfg, poller) + + assert codex_runner.calls == [] + assert len(pi_runner.calls) == 1 + assert transport.send_calls[-1]["message"].text == "ran:pi" + + +@pytest.mark.anyio +async def test_run_main_loop_refreshes_command_ids(monkeypatch) -> None: + class _Command: + id = "late_cmd" + description = "late command" + + async def handle(self, ctx): + return commands.CommandResult(text="late") + + entrypoints = [ + FakeEntryPoint( + "late_cmd", + "takopi.commands.late:BACKEND", + plugins.COMMAND_GROUP, + loader=_Command, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + + calls = {"count": 0} + + def _list_command_ids(*, allowlist=None): + _ = allowlist + calls["count"] += 1 + if calls["count"] == 1: + return [] + return ["late_cmd"] + + monkeypatch.setattr(bridge, "list_command_ids", _list_command_ids) + + transport = _FakeTransport() + bot = _FakeBot() + runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) + exec_cfg = ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ) + runtime = TransportRuntime( + router=_make_router(runner), + projects=empty_projects_config(), + ) + cfg = TelegramBridgeConfig( + bot=bot, + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=exec_cfg, + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=1, + text="/late_cmd hello", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + ) + + await run_main_loop(cfg, poller) + + assert calls["count"] >= 2 + assert transport.send_calls[-1]["message"].text == "late" diff --git a/tests/test_transport_registry.py b/tests/test_transport_registry.py index 286a75c..5bbc64a 100644 --- a/tests/test_transport_registry.py +++ b/tests/test_transport_registry.py @@ -1,19 +1,67 @@ import pytest -from takopi import transports +from takopi import plugins, transports from takopi.config import ConfigError +from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints -def test_transport_registry_lists_telegram() -> None: +class DummyTransport: + id = "telegram" + description = "Telegram" + + def check_setup(self, *args, **kwargs): + raise NotImplementedError + + def interactive_setup(self, *, force: bool) -> bool: + raise NotImplementedError + + def lock_token(self, *, transport_config: dict[str, object], config_path): + _ = transport_config, config_path + raise NotImplementedError + + def build_and_run( + self, + *, + transport_config: dict[str, object], + config_path, + runtime, + final_notify: bool, + default_engine_override: str | None, + ) -> None: + _ = ( + transport_config, + config_path, + runtime, + final_notify, + default_engine_override, + ) + raise NotImplementedError + + +@pytest.fixture +def transport_entrypoints(monkeypatch): + entrypoints = [ + FakeEntryPoint( + "telegram", + "takopi.telegram.backend:telegram_backend", + plugins.TRANSPORT_GROUP, + loader=DummyTransport, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + return entrypoints + + +def test_transport_registry_lists_telegram(transport_entrypoints) -> None: ids = transports.list_transports() assert "telegram" in ids -def test_transport_registry_gets_telegram() -> None: +def test_transport_registry_gets_telegram(transport_entrypoints) -> None: backend = transports.get_transport("telegram") assert backend.id == "telegram" -def test_transport_registry_unknown() -> None: +def test_transport_registry_unknown(transport_entrypoints) -> None: with pytest.raises(ConfigError, match="Unknown transport"): transports.get_transport("nope") diff --git a/tests/test_transport_runtime.py b/tests/test_transport_runtime.py new file mode 100644 index 0000000..7af0e69 --- /dev/null +++ b/tests/test_transport_runtime.py @@ -0,0 +1,45 @@ +from pathlib import Path + +from takopi.config import ProjectConfig, ProjectsConfig +from takopi.context import RunContext +from takopi.router import AutoRouter, RunnerEntry +from takopi.runners.mock import Return, ScriptRunner +from takopi.transport_runtime import TransportRuntime + + +def _make_runtime(*, project_default_engine: str | None = None) -> TransportRuntime: + codex = ScriptRunner([Return(answer="ok")], engine="codex") + pi = ScriptRunner([Return(answer="ok")], engine="pi") + router = AutoRouter( + entries=[ + RunnerEntry(engine=codex.engine, runner=codex), + RunnerEntry(engine=pi.engine, runner=pi), + ], + default_engine=codex.engine, + ) + project = ProjectConfig( + alias="proj", + path=Path("."), + worktrees_dir=Path(".worktrees"), + default_engine=project_default_engine, + ) + projects = ProjectsConfig(projects={"proj": project}, default_project=None) + return TransportRuntime(router=router, projects=projects) + + +def test_resolve_engine_uses_project_default() -> None: + runtime = _make_runtime(project_default_engine="pi") + engine = runtime.resolve_engine( + engine_override=None, + context=RunContext(project="proj"), + ) + assert engine == "pi" + + +def test_resolve_engine_prefers_override() -> None: + runtime = _make_runtime(project_default_engine="pi") + engine = runtime.resolve_engine( + engine_override="codex", + context=RunContext(project="proj"), + ) + assert engine == "codex"