From 81618e48e4d3769f31070c8702f10002eb46a2e3 Mon Sep 17 00:00:00 2001 From: banteg <4562643+banteg@users.noreply.github.com> Date: Sat, 10 Jan 2026 01:22:20 +0400 Subject: [PATCH] feat: add per-project chat routing (#76) --- docs/developing.md | 2 +- docs/projects.md | 6 +++ readme.md | 3 +- src/takopi/config.py | 43 ++++++++++++++++- src/takopi/settings.py | 39 ++++++++++++++- src/takopi/telegram/backend.py | 2 + src/takopi/telegram/bridge.py | 27 ++++++++++- src/takopi/telegram/client.py | 21 +++++++-- src/takopi/transport_runtime.py | 27 +++++++++-- tests/test_projects_config.py | 31 ++++++++++++ tests/test_telegram_bridge.py | 84 +++++++++++++++++++++++++++++++++ tests/test_transport_runtime.py | 28 +++++++++++ 12 files changed, 299 insertions(+), 14 deletions(-) diff --git a/docs/developing.md b/docs/developing.md index 6211e33..984bcf8 100644 --- a/docs/developing.md +++ b/docs/developing.md @@ -325,7 +325,7 @@ See `docs/adding-a-runner.md` for the full guide and a worked example. ``` Telegram Update ↓ -telegram/bridge.poll_updates() drains backlog, long-polls, filters chat_id == cfg.chat_id +telegram/bridge.poll_updates() drains backlog, long-polls, filters allowed chat ids ↓ telegram/bridge.run_main_loop() spawns tasks in TaskGroup ↓ diff --git a/docs/projects.md b/docs/projects.md index a94d982..5a8d521 100644 --- a/docs/projects.md +++ b/docs/projects.md @@ -31,6 +31,7 @@ path = "~/dev/z80" # required (repo root) worktrees_dir = ".worktrees" # optional, default ".worktrees" default_engine = "codex" # optional, per-project override worktree_base = "master" # optional, base for new branches +chat_id = -123 # optional, project chat id ``` Legacy config note: top-level `bot_token` / `chat_id` are auto-migrated into @@ -52,6 +53,7 @@ Validation rules: - `default_project` must match a configured project alias. - Project aliases cannot collide with engine ids or reserved commands (`/cancel`). - `default_engine` and per-project `default_engine` must be valid engine ids. +- `projects..chat_id` must be unique and must not match `transports.telegram.chat_id`. - `transport` defaults to `"telegram"` when omitted; override per-run with `--transport`. ## `takopi init` @@ -95,6 +97,10 @@ code (backticked): The `ctx:` line is parsed from replies and takes precedence over new directives. +When a message arrives in a chat whose `chat_id` matches `projects..chat_id`, +Takopi defaults the project context to that alias unless a reply `ctx:` or explicit +`/project` directive is present. + ## Worktree resolution When `@branch` is present: diff --git a/readme.md b/readme.md index fa4023b..6d399b5 100644 --- a/readme.md +++ b/readme.md @@ -111,6 +111,7 @@ path = "~/dev/z80" worktrees_dir = ".worktrees" default_engine = "codex" worktree_base = "master" +chat_id = -123456789 # optional, project chat id ``` note: the default `worktrees_dir` lives inside the repo, so `.worktrees/` will @@ -163,7 +164,7 @@ see: ## notes -* the bot only responds to the configured `chat_id` (private or group) +* the bot only responds to the primary `chat_id` plus any per-project `chat_id` * run only one takopi instance per bot token: multiple instances will race telegram's `getUpdates` offsets and cause missed updates ## development diff --git a/src/takopi/config.py b/src/takopi/config.py index ee36f13..b882c81 100644 --- a/src/takopi/config.py +++ b/src/takopi/config.py @@ -1,7 +1,7 @@ from __future__ import annotations import tomllib -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Any, Iterable @@ -41,6 +41,7 @@ class ProjectConfig: worktrees_dir: Path default_engine: str | None = None worktree_base: str | None = None + chat_id: int | None = None @property def worktrees_root(self) -> Path: @@ -53,6 +54,7 @@ class ProjectConfig: class ProjectsConfig: projects: dict[str, ProjectConfig] default_project: str | None = None + chat_map: dict[int, str] = field(default_factory=dict) def resolve(self, alias: str | None) -> ProjectConfig | None: if alias is None: @@ -61,6 +63,14 @@ class ProjectsConfig: return self.projects.get(self.default_project) return self.projects.get(alias.lower()) + def project_for_chat(self, chat_id: int | None) -> str | None: + if chat_id is None: + return None + return self.chat_map.get(chat_id) + + def project_chat_ids(self) -> tuple[int, ...]: + return tuple(self.chat_map.keys()) + def empty_projects_config() -> ProjectsConfig: return ProjectsConfig(projects={}, default_project=None) @@ -99,6 +109,7 @@ def parse_projects_config( config_path: Path, engine_ids: Iterable[str], reserved: Iterable[str] = ("cancel",), + default_chat_id: int | None = None, ) -> ProjectsConfig: default_project_raw = config.get("default_project") default_project = None @@ -116,6 +127,7 @@ def parse_projects_config( reserved_lower = {value.lower() for value in reserved} engine_lower = {value.lower() for value in engine_ids} projects: dict[str, ProjectConfig] = {} + chat_map: dict[int, str] = {} for raw_alias, raw_entry in projects_raw.items(): if not isinstance(raw_alias, str) or not raw_alias.strip(): @@ -173,12 +185,35 @@ def parse_projects_config( ) worktree_base = worktree_base_raw.strip() + chat_id_raw = raw_entry.get("chat_id") + chat_id = None + if chat_id_raw is not None: + if isinstance(chat_id_raw, bool) or not isinstance(chat_id_raw, int): + raise ConfigError( + f"Invalid `projects.{alias}.chat_id` in {config_path}; " + "expected an integer." + ) + chat_id = chat_id_raw + if default_chat_id is not None and chat_id == default_chat_id: + raise ConfigError( + f"Invalid `projects.{alias}.chat_id` in {config_path}; " + "must not match transports.telegram.chat_id." + ) + if chat_id in chat_map: + existing = chat_map[chat_id] + raise ConfigError( + f"Duplicate `projects.*.chat_id` {chat_id} in {config_path}; " + f"already used by {existing!r}." + ) + chat_map[chat_id] = alias_key + projects[alias_key] = ProjectConfig( alias=alias, path=path, worktrees_dir=worktrees_dir, default_engine=default_engine, worktree_base=worktree_base, + chat_id=chat_id, ) if default_project is not None: @@ -190,7 +225,11 @@ def parse_projects_config( ) default_project = default_key - return ProjectsConfig(projects=projects, default_project=default_project) + return ProjectsConfig( + projects=projects, + default_project=default_project, + chat_map=chat_map, + ) def _toml_escape(value: str) -> str: diff --git a/src/takopi/settings.py b/src/takopi/settings.py index 36b4b7a..74a4641 100644 --- a/src/takopi/settings.py +++ b/src/takopi/settings.py @@ -70,6 +70,7 @@ class ProjectSettings(BaseModel): worktrees_dir: str = ".worktrees" default_engine: str | None = None worktree_base: str | None = None + chat_id: int | None = None model_config = ConfigDict(extra="allow") @@ -91,6 +92,15 @@ class ProjectSettings(BaseModel): raise ValueError(f"{info.field_name} must be a non-empty string") return cleaned + @field_validator("chat_id", mode="before") + @classmethod + def _validate_chat_id(cls, value: Any) -> Any: + if value is None: + return None + if isinstance(value, bool) or not isinstance(value, int): + raise ValueError("chat_id must be an integer") + return value + class TakopiSettings(BaseSettings): model_config = SettingsConfigDict( @@ -194,10 +204,12 @@ class TakopiSettings(BaseSettings): reserved: Iterable[str] = ("cancel",), ) -> ProjectsConfig: default_project = self.default_project + default_chat_id = self.transports.telegram.chat_id reserved_lower = {value.lower() for value in reserved} engine_map = {engine.lower(): engine for engine in engine_ids} projects: dict[str, ProjectConfig] = {} + chat_map: dict[int, str] = {} for raw_alias, entry in self.projects.items(): if not isinstance(raw_alias, str) or not raw_alias.strip(): @@ -258,12 +270,33 @@ class TakopiSettings(BaseSettings): ) worktree_base = worktree_base_raw.strip() + chat_id = entry.chat_id + if chat_id is not None: + if isinstance(chat_id, bool) or not isinstance(chat_id, int): + raise ConfigError( + f"Invalid `projects.{alias}.chat_id` in {config_path}; " + "expected an integer." + ) + if default_chat_id is not None and chat_id == default_chat_id: + raise ConfigError( + f"Invalid `projects.{alias}.chat_id` in {config_path}; " + "must not match transports.telegram.chat_id." + ) + if chat_id in chat_map: + existing = chat_map[chat_id] + raise ConfigError( + f"Duplicate `projects.*.chat_id` {chat_id} in {config_path}; " + f"already used by {existing!r}." + ) + chat_map[chat_id] = alias_key + projects[alias_key] = ProjectConfig( alias=alias, path=path, worktrees_dir=worktrees_dir, default_engine=default_engine, worktree_base=worktree_base, + chat_id=chat_id, ) if default_project is not None: @@ -275,7 +308,11 @@ class TakopiSettings(BaseSettings): ) default_project = default_key - return ProjectsConfig(projects=projects, default_project=default_project) + return ProjectsConfig( + projects=projects, + default_project=default_project, + chat_map=chat_map, + ) def load_settings(path: str | Path | None = None) -> tuple[TakopiSettings, Path]: diff --git a/src/takopi/telegram/backend.py b/src/takopi/telegram/backend.py index 5a52515..fa1a002 100644 --- a/src/takopi/telegram/backend.py +++ b/src/takopi/telegram/backend.py @@ -97,10 +97,12 @@ class TelegramBackend(TransportBackend): final_notify=final_notify, ) voice_transcription = _build_voice_transcription_config(transport_config) + chat_ids = (chat_id, *runtime.project_chat_ids()) cfg = TelegramBridgeConfig( bot=bot, runtime=runtime, chat_id=chat_id, + chat_ids=chat_ids, startup_msg=startup_msg, exec_cfg=exec_cfg, voice_transcription=voice_transcription, diff --git a/src/takopi/telegram/bridge.py b/src/takopi/telegram/bridge.py index d1693de..7157f5e 100644 --- a/src/takopi/telegram/bridge.py +++ b/src/takopi/telegram/bridge.py @@ -297,6 +297,13 @@ class TelegramBridgeConfig: startup_msg: str exec_cfg: ExecBridgeConfig voice_transcription: TelegramVoiceTranscriptionConfig | None = None + chat_ids: tuple[int, ...] | None = None + + +def _allowed_chat_ids(cfg: TelegramBridgeConfig) -> set[int]: + allowed = set(cfg.chat_ids or ()) + allowed.add(cfg.chat_id) + return allowed async def _send_plain( @@ -353,7 +360,11 @@ async def poll_updates( offset = await _drain_backlog(cfg, offset) await _send_startup(cfg) - async for msg in poll_incoming(cfg.bot, chat_id=cfg.chat_id, offset=offset): + async for msg in poll_incoming( + cfg.bot, + chat_ids=_allowed_chat_ids(cfg), + offset=offset, + ): yield msg @@ -746,6 +757,18 @@ class _TelegramCommandExecutor(CommandExecutor): self._user_msg_id = user_msg_id self._reply_ref = MessageRef(channel_id=chat_id, message_id=user_msg_id) + def _apply_default_context(self, request: RunRequest) -> RunRequest: + if request.context is not None: + return request + context = self._runtime.default_context_for_chat(self._chat_id) + if context is None: + return request + return RunRequest( + prompt=request.prompt, + engine=request.engine, + context=context, + ) + async def send( self, message: RenderedMessage | str, @@ -768,6 +791,7 @@ class _TelegramCommandExecutor(CommandExecutor): async def run_one( self, request: RunRequest, *, mode: RunMode = "emit" ) -> RunResult: + request = self._apply_default_context(request) engine = self._runtime.resolve_engine( engine_override=request.engine, context=request.context, @@ -999,6 +1023,7 @@ async def run_main_loop( resolved = cfg.runtime.resolve_message( text=text, reply_text=reply_text, + chat_id=chat_id, ) except DirectiveError as exc: await _send_plain( diff --git a/src/takopi/telegram/client.py b/src/takopi/telegram/client.py index 137abe1..37e0a28 100644 --- a/src/takopi/telegram/client.py +++ b/src/takopi/telegram/client.py @@ -9,6 +9,7 @@ from typing import ( Awaitable, Callable, Hashable, + Iterable, Protocol, TYPE_CHECKING, ) @@ -44,7 +45,10 @@ def is_group_chat_id(chat_id: int) -> bool: def parse_incoming_update( - update: dict[str, Any], *, chat_id: int + update: dict[str, Any], + *, + chat_id: int | None = None, + chat_ids: set[int] | None = None, ) -> TelegramIncomingMessage | None: msg = update.get("message") if not isinstance(msg, dict): @@ -78,7 +82,12 @@ def parse_incoming_update( if not isinstance(chat, dict): return None msg_chat_id = chat.get("id") - if not isinstance(msg_chat_id, int) or msg_chat_id != chat_id: + if not isinstance(msg_chat_id, int): + return None + allowed = chat_ids + if allowed is None and chat_id is not None: + allowed = {chat_id} + if allowed is not None and msg_chat_id not in allowed: return None message_id = msg.get("message_id") if not isinstance(message_id, int): @@ -117,9 +126,13 @@ def parse_incoming_update( async def poll_incoming( bot: BotClient, *, - chat_id: int, + chat_id: int | None = None, + chat_ids: Iterable[int] | None = None, offset: int | None = None, ) -> AsyncIterator[TelegramIncomingMessage]: + allowed = set(chat_ids) if chat_ids is not None else None + if allowed is None and chat_id is not None: + allowed = {chat_id} while True: updates = await bot.get_updates( offset=offset, timeout_s=50, allowed_updates=["message"] @@ -131,7 +144,7 @@ async def poll_incoming( logger.debug("loop.updates", updates=updates) for upd in updates: offset = upd["update_id"] + 1 - msg = parse_incoming_update(upd, chat_id=chat_id) + msg = parse_incoming_update(upd, chat_ids=allowed) if msg is not None: yield msg diff --git a/src/takopi/transport_runtime.py b/src/takopi/transport_runtime.py index 455299e..ac94e54 100644 --- a/src/takopi/transport_runtime.py +++ b/src/takopi/transport_runtime.py @@ -110,7 +110,13 @@ class TransportRuntime: ) return dict(raw) - def resolve_message(self, *, text: str, reply_text: str | None) -> ResolvedMessage: + def resolve_message( + self, + *, + text: str, + reply_text: str | None, + chat_id: int | None = None, + ) -> ResolvedMessage: directives = parse_directives( text, engine_ids=self._router.engine_ids, @@ -118,13 +124,17 @@ class TransportRuntime: ) reply_ctx = parse_context_line(reply_text, projects=self._projects) resume_token = self._router.resolve_resume(directives.prompt, reply_text) + chat_project = self._projects.project_for_chat(chat_id) if resume_token is not None: + context = reply_ctx + if context is None and chat_project is not None: + context = RunContext(project=chat_project, branch=None) return ResolvedMessage( prompt=directives.prompt, resume_token=resume_token, engine_override=None, - context=reply_ctx, + context=context, ) if reply_ctx is not None: @@ -141,8 +151,8 @@ class TransportRuntime: ) project_key = directives.project - if project_key is None and self._projects.default_project is not None: - project_key = self._projects.default_project + if project_key is None: + project_key = chat_project or self._projects.default_project context = None if project_key is not None or directives.branch is not None: @@ -161,6 +171,15 @@ class TransportRuntime: context=context, ) + def default_context_for_chat(self, chat_id: int | None) -> RunContext | None: + project_key = self._projects.project_for_chat(chat_id) + if project_key is None: + return None + return RunContext(project=project_key, branch=None) + + def project_chat_ids(self) -> tuple[int, ...]: + return self._projects.project_chat_ids() + def resolve_runner( self, *, diff --git a/tests/test_projects_config.py b/tests/test_projects_config.py index b718683..6a3bb54 100644 --- a/tests/test_projects_config.py +++ b/tests/test_projects_config.py @@ -87,6 +87,37 @@ def test_projects_default_engine_unknown() -> None: ) +def test_projects_chat_id_cannot_match_transport_chat_id() -> None: + config = { + "transports": {"telegram": {"bot_token": "token", "chat_id": 123}}, + "projects": {"z80": {"path": "/tmp/repo", "chat_id": 123}}, + } + settings = TakopiSettings.model_validate(config) + with pytest.raises(ConfigError, match="chat_id"): + settings.to_projects_config( + config_path=Path("takopi.toml"), + engine_ids=["codex"], + reserved=("cancel",), + ) + + +def test_projects_chat_id_must_be_unique() -> None: + config = { + "transports": {"telegram": {"bot_token": "token", "chat_id": 123}}, + "projects": { + "a": {"path": "/tmp/a", "chat_id": -10}, + "b": {"path": "/tmp/b", "chat_id": -10}, + }, + } + settings = TakopiSettings.model_validate(config) + with pytest.raises(ConfigError, match="chat_id"): + settings.to_projects_config( + config_path=Path("takopi.toml"), + engine_ids=["codex"], + reserved=("cancel",), + ) + + def test_projects_relative_path_resolves(tmp_path: Path) -> None: config_path = tmp_path / "takopi.toml" settings = TakopiSettings.model_validate({"projects": {"z80": {"path": "repo"}}}) diff --git a/tests/test_telegram_bridge.py b/tests/test_telegram_bridge.py index 56fb058..8d4786e 100644 --- a/tests/test_telegram_bridge.py +++ b/tests/test_telegram_bridge.py @@ -911,6 +911,90 @@ async def test_run_main_loop_command_uses_project_default_engine( assert transport.send_calls[-1]["message"].text == "ran:pi" +@pytest.mark.anyio +async def test_run_main_loop_command_defaults_to_chat_project( + monkeypatch, +) -> None: + class _Command: + id = "auto_ctx" + description = "auto context" + + async def handle(self, ctx): + result = await ctx.executor.run_one( + commands.RunRequest(prompt="hello"), + mode="capture", + ) + return commands.CommandResult(text=f"ran:{result.engine}") + + entrypoints = [ + FakeEntryPoint( + "auto_ctx", + "takopi.commands.auto_ctx:BACKEND", + plugins.COMMAND_GROUP, + loader=_Command, + ) + ] + install_entrypoints(monkeypatch, entrypoints) + + transport = _FakeTransport() + bot = _FakeBot() + codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) + pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi")) + router = AutoRouter( + entries=[ + RunnerEntry(engine=codex_runner.engine, runner=codex_runner), + RunnerEntry(engine=pi_runner.engine, runner=pi_runner), + ], + default_engine=codex_runner.engine, + ) + projects = ProjectsConfig( + projects={ + "proj": ProjectConfig( + alias="proj", + path=Path("."), + worktrees_dir=Path(".worktrees"), + default_engine=pi_runner.engine, + chat_id=-42, + ) + }, + default_project=None, + chat_map={-42: "proj"}, + ) + runtime = TransportRuntime( + router=router, + projects=projects, + ) + exec_cfg = ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ) + cfg = TelegramBridgeConfig( + bot=bot, + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=exec_cfg, + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=-42, + message_id=1, + text="/auto_ctx", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + ) + + await run_main_loop(cfg, poller) + + assert codex_runner.calls == [] + assert len(pi_runner.calls) == 1 + assert transport.send_calls[-1]["message"].text == "ran:pi" + + @pytest.mark.anyio async def test_run_main_loop_refreshes_command_ids(monkeypatch) -> None: class _Command: diff --git a/tests/test_transport_runtime.py b/tests/test_transport_runtime.py index 7af0e69..8cd6db6 100644 --- a/tests/test_transport_runtime.py +++ b/tests/test_transport_runtime.py @@ -43,3 +43,31 @@ def test_resolve_engine_prefers_override() -> None: context=RunContext(project="proj"), ) assert engine == "codex" + + +def test_resolve_message_defaults_to_chat_project() -> None: + codex = ScriptRunner([Return(answer="ok")], engine="codex") + router = AutoRouter( + entries=[RunnerEntry(engine=codex.engine, runner=codex)], + default_engine=codex.engine, + ) + project = ProjectConfig( + alias="proj", + path=Path("."), + worktrees_dir=Path(".worktrees"), + chat_id=-42, + ) + projects = ProjectsConfig( + projects={"proj": project}, + default_project=None, + chat_map={-42: "proj"}, + ) + runtime = TransportRuntime(router=router, projects=projects) + + resolved = runtime.resolve_message( + text="hello", + reply_text=None, + chat_id=-42, + ) + + assert resolved.context == RunContext(project="proj", branch=None)