feat: add per-project chat routing (#76)

This commit is contained in:
banteg
2026-01-10 01:22:20 +04:00
committed by GitHub
parent 7ffb99d779
commit 81618e48e4
12 changed files with 299 additions and 14 deletions
+1 -1
View File
@@ -325,7 +325,7 @@ See `docs/adding-a-runner.md` for the full guide and a worked example.
``` ```
Telegram Update Telegram Update
telegram/bridge.poll_updates() drains backlog, long-polls, filters chat_id == cfg.chat_id telegram/bridge.poll_updates() drains backlog, long-polls, filters allowed chat ids
telegram/bridge.run_main_loop() spawns tasks in TaskGroup telegram/bridge.run_main_loop() spawns tasks in TaskGroup
+6
View File
@@ -31,6 +31,7 @@ path = "~/dev/z80" # required (repo root)
worktrees_dir = ".worktrees" # optional, default ".worktrees" worktrees_dir = ".worktrees" # optional, default ".worktrees"
default_engine = "codex" # optional, per-project override default_engine = "codex" # optional, per-project override
worktree_base = "master" # optional, base for new branches worktree_base = "master" # optional, base for new branches
chat_id = -123 # optional, project chat id
``` ```
Legacy config note: top-level `bot_token` / `chat_id` are auto-migrated into Legacy config note: top-level `bot_token` / `chat_id` are auto-migrated into
@@ -52,6 +53,7 @@ Validation rules:
- `default_project` must match a configured project alias. - `default_project` must match a configured project alias.
- Project aliases cannot collide with engine ids or reserved commands (`/cancel`). - Project aliases cannot collide with engine ids or reserved commands (`/cancel`).
- `default_engine` and per-project `default_engine` must be valid engine ids. - `default_engine` and per-project `default_engine` must be valid engine ids.
- `projects.<alias>.chat_id` must be unique and must not match `transports.telegram.chat_id`.
- `transport` defaults to `"telegram"` when omitted; override per-run with `--transport`. - `transport` defaults to `"telegram"` when omitted; override per-run with `--transport`.
## `takopi init` ## `takopi init`
@@ -95,6 +97,10 @@ code (backticked):
The `ctx:` line is parsed from replies and takes precedence over new directives. The `ctx:` line is parsed from replies and takes precedence over new directives.
When a message arrives in a chat whose `chat_id` matches `projects.<alias>.chat_id`,
Takopi defaults the project context to that alias unless a reply `ctx:` or explicit
`/project` directive is present.
## Worktree resolution ## Worktree resolution
When `@branch` is present: When `@branch` is present:
+2 -1
View File
@@ -111,6 +111,7 @@ path = "~/dev/z80"
worktrees_dir = ".worktrees" worktrees_dir = ".worktrees"
default_engine = "codex" default_engine = "codex"
worktree_base = "master" worktree_base = "master"
chat_id = -123456789 # optional, project chat id
``` ```
note: the default `worktrees_dir` lives inside the repo, so `.worktrees/` will note: the default `worktrees_dir` lives inside the repo, so `.worktrees/` will
@@ -163,7 +164,7 @@ see:
## notes ## notes
* the bot only responds to the configured `chat_id` (private or group) * the bot only responds to the primary `chat_id` plus any per-project `chat_id`
* run only one takopi instance per bot token: multiple instances will race telegram's `getUpdates` offsets and cause missed updates * run only one takopi instance per bot token: multiple instances will race telegram's `getUpdates` offsets and cause missed updates
## development ## development
+41 -2
View File
@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import tomllib import tomllib
from dataclasses import dataclass from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Iterable from typing import Any, Iterable
@@ -41,6 +41,7 @@ class ProjectConfig:
worktrees_dir: Path worktrees_dir: Path
default_engine: str | None = None default_engine: str | None = None
worktree_base: str | None = None worktree_base: str | None = None
chat_id: int | None = None
@property @property
def worktrees_root(self) -> Path: def worktrees_root(self) -> Path:
@@ -53,6 +54,7 @@ class ProjectConfig:
class ProjectsConfig: class ProjectsConfig:
projects: dict[str, ProjectConfig] projects: dict[str, ProjectConfig]
default_project: str | None = None default_project: str | None = None
chat_map: dict[int, str] = field(default_factory=dict)
def resolve(self, alias: str | None) -> ProjectConfig | None: def resolve(self, alias: str | None) -> ProjectConfig | None:
if alias is None: if alias is None:
@@ -61,6 +63,14 @@ class ProjectsConfig:
return self.projects.get(self.default_project) return self.projects.get(self.default_project)
return self.projects.get(alias.lower()) return self.projects.get(alias.lower())
def project_for_chat(self, chat_id: int | None) -> str | None:
if chat_id is None:
return None
return self.chat_map.get(chat_id)
def project_chat_ids(self) -> tuple[int, ...]:
return tuple(self.chat_map.keys())
def empty_projects_config() -> ProjectsConfig: def empty_projects_config() -> ProjectsConfig:
return ProjectsConfig(projects={}, default_project=None) return ProjectsConfig(projects={}, default_project=None)
@@ -99,6 +109,7 @@ def parse_projects_config(
config_path: Path, config_path: Path,
engine_ids: Iterable[str], engine_ids: Iterable[str],
reserved: Iterable[str] = ("cancel",), reserved: Iterable[str] = ("cancel",),
default_chat_id: int | None = None,
) -> ProjectsConfig: ) -> ProjectsConfig:
default_project_raw = config.get("default_project") default_project_raw = config.get("default_project")
default_project = None default_project = None
@@ -116,6 +127,7 @@ def parse_projects_config(
reserved_lower = {value.lower() for value in reserved} reserved_lower = {value.lower() for value in reserved}
engine_lower = {value.lower() for value in engine_ids} engine_lower = {value.lower() for value in engine_ids}
projects: dict[str, ProjectConfig] = {} projects: dict[str, ProjectConfig] = {}
chat_map: dict[int, str] = {}
for raw_alias, raw_entry in projects_raw.items(): for raw_alias, raw_entry in projects_raw.items():
if not isinstance(raw_alias, str) or not raw_alias.strip(): if not isinstance(raw_alias, str) or not raw_alias.strip():
@@ -173,12 +185,35 @@ def parse_projects_config(
) )
worktree_base = worktree_base_raw.strip() worktree_base = worktree_base_raw.strip()
chat_id_raw = raw_entry.get("chat_id")
chat_id = None
if chat_id_raw is not None:
if isinstance(chat_id_raw, bool) or not isinstance(chat_id_raw, int):
raise ConfigError(
f"Invalid `projects.{alias}.chat_id` in {config_path}; "
"expected an integer."
)
chat_id = chat_id_raw
if default_chat_id is not None and chat_id == default_chat_id:
raise ConfigError(
f"Invalid `projects.{alias}.chat_id` in {config_path}; "
"must not match transports.telegram.chat_id."
)
if chat_id in chat_map:
existing = chat_map[chat_id]
raise ConfigError(
f"Duplicate `projects.*.chat_id` {chat_id} in {config_path}; "
f"already used by {existing!r}."
)
chat_map[chat_id] = alias_key
projects[alias_key] = ProjectConfig( projects[alias_key] = ProjectConfig(
alias=alias, alias=alias,
path=path, path=path,
worktrees_dir=worktrees_dir, worktrees_dir=worktrees_dir,
default_engine=default_engine, default_engine=default_engine,
worktree_base=worktree_base, worktree_base=worktree_base,
chat_id=chat_id,
) )
if default_project is not None: if default_project is not None:
@@ -190,7 +225,11 @@ def parse_projects_config(
) )
default_project = default_key default_project = default_key
return ProjectsConfig(projects=projects, default_project=default_project) return ProjectsConfig(
projects=projects,
default_project=default_project,
chat_map=chat_map,
)
def _toml_escape(value: str) -> str: def _toml_escape(value: str) -> str:
+38 -1
View File
@@ -70,6 +70,7 @@ class ProjectSettings(BaseModel):
worktrees_dir: str = ".worktrees" worktrees_dir: str = ".worktrees"
default_engine: str | None = None default_engine: str | None = None
worktree_base: str | None = None worktree_base: str | None = None
chat_id: int | None = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@@ -91,6 +92,15 @@ class ProjectSettings(BaseModel):
raise ValueError(f"{info.field_name} must be a non-empty string") raise ValueError(f"{info.field_name} must be a non-empty string")
return cleaned return cleaned
@field_validator("chat_id", mode="before")
@classmethod
def _validate_chat_id(cls, value: Any) -> Any:
if value is None:
return None
if isinstance(value, bool) or not isinstance(value, int):
raise ValueError("chat_id must be an integer")
return value
class TakopiSettings(BaseSettings): class TakopiSettings(BaseSettings):
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
@@ -194,10 +204,12 @@ class TakopiSettings(BaseSettings):
reserved: Iterable[str] = ("cancel",), reserved: Iterable[str] = ("cancel",),
) -> ProjectsConfig: ) -> ProjectsConfig:
default_project = self.default_project default_project = self.default_project
default_chat_id = self.transports.telegram.chat_id
reserved_lower = {value.lower() for value in reserved} reserved_lower = {value.lower() for value in reserved}
engine_map = {engine.lower(): engine for engine in engine_ids} engine_map = {engine.lower(): engine for engine in engine_ids}
projects: dict[str, ProjectConfig] = {} projects: dict[str, ProjectConfig] = {}
chat_map: dict[int, str] = {}
for raw_alias, entry in self.projects.items(): for raw_alias, entry in self.projects.items():
if not isinstance(raw_alias, str) or not raw_alias.strip(): if not isinstance(raw_alias, str) or not raw_alias.strip():
@@ -258,12 +270,33 @@ class TakopiSettings(BaseSettings):
) )
worktree_base = worktree_base_raw.strip() worktree_base = worktree_base_raw.strip()
chat_id = entry.chat_id
if chat_id is not None:
if isinstance(chat_id, bool) or not isinstance(chat_id, int):
raise ConfigError(
f"Invalid `projects.{alias}.chat_id` in {config_path}; "
"expected an integer."
)
if default_chat_id is not None and chat_id == default_chat_id:
raise ConfigError(
f"Invalid `projects.{alias}.chat_id` in {config_path}; "
"must not match transports.telegram.chat_id."
)
if chat_id in chat_map:
existing = chat_map[chat_id]
raise ConfigError(
f"Duplicate `projects.*.chat_id` {chat_id} in {config_path}; "
f"already used by {existing!r}."
)
chat_map[chat_id] = alias_key
projects[alias_key] = ProjectConfig( projects[alias_key] = ProjectConfig(
alias=alias, alias=alias,
path=path, path=path,
worktrees_dir=worktrees_dir, worktrees_dir=worktrees_dir,
default_engine=default_engine, default_engine=default_engine,
worktree_base=worktree_base, worktree_base=worktree_base,
chat_id=chat_id,
) )
if default_project is not None: if default_project is not None:
@@ -275,7 +308,11 @@ class TakopiSettings(BaseSettings):
) )
default_project = default_key default_project = default_key
return ProjectsConfig(projects=projects, default_project=default_project) return ProjectsConfig(
projects=projects,
default_project=default_project,
chat_map=chat_map,
)
def load_settings(path: str | Path | None = None) -> tuple[TakopiSettings, Path]: def load_settings(path: str | Path | None = None) -> tuple[TakopiSettings, Path]:
+2
View File
@@ -97,10 +97,12 @@ class TelegramBackend(TransportBackend):
final_notify=final_notify, final_notify=final_notify,
) )
voice_transcription = _build_voice_transcription_config(transport_config) voice_transcription = _build_voice_transcription_config(transport_config)
chat_ids = (chat_id, *runtime.project_chat_ids())
cfg = TelegramBridgeConfig( cfg = TelegramBridgeConfig(
bot=bot, bot=bot,
runtime=runtime, runtime=runtime,
chat_id=chat_id, chat_id=chat_id,
chat_ids=chat_ids,
startup_msg=startup_msg, startup_msg=startup_msg,
exec_cfg=exec_cfg, exec_cfg=exec_cfg,
voice_transcription=voice_transcription, voice_transcription=voice_transcription,
+26 -1
View File
@@ -297,6 +297,13 @@ class TelegramBridgeConfig:
startup_msg: str startup_msg: str
exec_cfg: ExecBridgeConfig exec_cfg: ExecBridgeConfig
voice_transcription: TelegramVoiceTranscriptionConfig | None = None voice_transcription: TelegramVoiceTranscriptionConfig | None = None
chat_ids: tuple[int, ...] | None = None
def _allowed_chat_ids(cfg: TelegramBridgeConfig) -> set[int]:
allowed = set(cfg.chat_ids or ())
allowed.add(cfg.chat_id)
return allowed
async def _send_plain( async def _send_plain(
@@ -353,7 +360,11 @@ async def poll_updates(
offset = await _drain_backlog(cfg, offset) offset = await _drain_backlog(cfg, offset)
await _send_startup(cfg) await _send_startup(cfg)
async for msg in poll_incoming(cfg.bot, chat_id=cfg.chat_id, offset=offset): async for msg in poll_incoming(
cfg.bot,
chat_ids=_allowed_chat_ids(cfg),
offset=offset,
):
yield msg yield msg
@@ -746,6 +757,18 @@ class _TelegramCommandExecutor(CommandExecutor):
self._user_msg_id = user_msg_id self._user_msg_id = user_msg_id
self._reply_ref = MessageRef(channel_id=chat_id, message_id=user_msg_id) self._reply_ref = MessageRef(channel_id=chat_id, message_id=user_msg_id)
def _apply_default_context(self, request: RunRequest) -> RunRequest:
if request.context is not None:
return request
context = self._runtime.default_context_for_chat(self._chat_id)
if context is None:
return request
return RunRequest(
prompt=request.prompt,
engine=request.engine,
context=context,
)
async def send( async def send(
self, self,
message: RenderedMessage | str, message: RenderedMessage | str,
@@ -768,6 +791,7 @@ class _TelegramCommandExecutor(CommandExecutor):
async def run_one( async def run_one(
self, request: RunRequest, *, mode: RunMode = "emit" self, request: RunRequest, *, mode: RunMode = "emit"
) -> RunResult: ) -> RunResult:
request = self._apply_default_context(request)
engine = self._runtime.resolve_engine( engine = self._runtime.resolve_engine(
engine_override=request.engine, engine_override=request.engine,
context=request.context, context=request.context,
@@ -999,6 +1023,7 @@ async def run_main_loop(
resolved = cfg.runtime.resolve_message( resolved = cfg.runtime.resolve_message(
text=text, text=text,
reply_text=reply_text, reply_text=reply_text,
chat_id=chat_id,
) )
except DirectiveError as exc: except DirectiveError as exc:
await _send_plain( await _send_plain(
+17 -4
View File
@@ -9,6 +9,7 @@ from typing import (
Awaitable, Awaitable,
Callable, Callable,
Hashable, Hashable,
Iterable,
Protocol, Protocol,
TYPE_CHECKING, TYPE_CHECKING,
) )
@@ -44,7 +45,10 @@ def is_group_chat_id(chat_id: int) -> bool:
def parse_incoming_update( def parse_incoming_update(
update: dict[str, Any], *, chat_id: int update: dict[str, Any],
*,
chat_id: int | None = None,
chat_ids: set[int] | None = None,
) -> TelegramIncomingMessage | None: ) -> TelegramIncomingMessage | None:
msg = update.get("message") msg = update.get("message")
if not isinstance(msg, dict): if not isinstance(msg, dict):
@@ -78,7 +82,12 @@ def parse_incoming_update(
if not isinstance(chat, dict): if not isinstance(chat, dict):
return None return None
msg_chat_id = chat.get("id") msg_chat_id = chat.get("id")
if not isinstance(msg_chat_id, int) or msg_chat_id != chat_id: if not isinstance(msg_chat_id, int):
return None
allowed = chat_ids
if allowed is None and chat_id is not None:
allowed = {chat_id}
if allowed is not None and msg_chat_id not in allowed:
return None return None
message_id = msg.get("message_id") message_id = msg.get("message_id")
if not isinstance(message_id, int): if not isinstance(message_id, int):
@@ -117,9 +126,13 @@ def parse_incoming_update(
async def poll_incoming( async def poll_incoming(
bot: BotClient, bot: BotClient,
*, *,
chat_id: int, chat_id: int | None = None,
chat_ids: Iterable[int] | None = None,
offset: int | None = None, offset: int | None = None,
) -> AsyncIterator[TelegramIncomingMessage]: ) -> AsyncIterator[TelegramIncomingMessage]:
allowed = set(chat_ids) if chat_ids is not None else None
if allowed is None and chat_id is not None:
allowed = {chat_id}
while True: while True:
updates = await bot.get_updates( updates = await bot.get_updates(
offset=offset, timeout_s=50, allowed_updates=["message"] offset=offset, timeout_s=50, allowed_updates=["message"]
@@ -131,7 +144,7 @@ async def poll_incoming(
logger.debug("loop.updates", updates=updates) logger.debug("loop.updates", updates=updates)
for upd in updates: for upd in updates:
offset = upd["update_id"] + 1 offset = upd["update_id"] + 1
msg = parse_incoming_update(upd, chat_id=chat_id) msg = parse_incoming_update(upd, chat_ids=allowed)
if msg is not None: if msg is not None:
yield msg yield msg
+23 -4
View File
@@ -110,7 +110,13 @@ class TransportRuntime:
) )
return dict(raw) return dict(raw)
def resolve_message(self, *, text: str, reply_text: str | None) -> ResolvedMessage: def resolve_message(
self,
*,
text: str,
reply_text: str | None,
chat_id: int | None = None,
) -> ResolvedMessage:
directives = parse_directives( directives = parse_directives(
text, text,
engine_ids=self._router.engine_ids, engine_ids=self._router.engine_ids,
@@ -118,13 +124,17 @@ class TransportRuntime:
) )
reply_ctx = parse_context_line(reply_text, projects=self._projects) reply_ctx = parse_context_line(reply_text, projects=self._projects)
resume_token = self._router.resolve_resume(directives.prompt, reply_text) resume_token = self._router.resolve_resume(directives.prompt, reply_text)
chat_project = self._projects.project_for_chat(chat_id)
if resume_token is not None: if resume_token is not None:
context = reply_ctx
if context is None and chat_project is not None:
context = RunContext(project=chat_project, branch=None)
return ResolvedMessage( return ResolvedMessage(
prompt=directives.prompt, prompt=directives.prompt,
resume_token=resume_token, resume_token=resume_token,
engine_override=None, engine_override=None,
context=reply_ctx, context=context,
) )
if reply_ctx is not None: if reply_ctx is not None:
@@ -141,8 +151,8 @@ class TransportRuntime:
) )
project_key = directives.project project_key = directives.project
if project_key is None and self._projects.default_project is not None: if project_key is None:
project_key = self._projects.default_project project_key = chat_project or self._projects.default_project
context = None context = None
if project_key is not None or directives.branch is not None: if project_key is not None or directives.branch is not None:
@@ -161,6 +171,15 @@ class TransportRuntime:
context=context, context=context,
) )
def default_context_for_chat(self, chat_id: int | None) -> RunContext | None:
project_key = self._projects.project_for_chat(chat_id)
if project_key is None:
return None
return RunContext(project=project_key, branch=None)
def project_chat_ids(self) -> tuple[int, ...]:
return self._projects.project_chat_ids()
def resolve_runner( def resolve_runner(
self, self,
*, *,
+31
View File
@@ -87,6 +87,37 @@ def test_projects_default_engine_unknown() -> None:
) )
def test_projects_chat_id_cannot_match_transport_chat_id() -> None:
config = {
"transports": {"telegram": {"bot_token": "token", "chat_id": 123}},
"projects": {"z80": {"path": "/tmp/repo", "chat_id": 123}},
}
settings = TakopiSettings.model_validate(config)
with pytest.raises(ConfigError, match="chat_id"):
settings.to_projects_config(
config_path=Path("takopi.toml"),
engine_ids=["codex"],
reserved=("cancel",),
)
def test_projects_chat_id_must_be_unique() -> None:
config = {
"transports": {"telegram": {"bot_token": "token", "chat_id": 123}},
"projects": {
"a": {"path": "/tmp/a", "chat_id": -10},
"b": {"path": "/tmp/b", "chat_id": -10},
},
}
settings = TakopiSettings.model_validate(config)
with pytest.raises(ConfigError, match="chat_id"):
settings.to_projects_config(
config_path=Path("takopi.toml"),
engine_ids=["codex"],
reserved=("cancel",),
)
def test_projects_relative_path_resolves(tmp_path: Path) -> None: def test_projects_relative_path_resolves(tmp_path: Path) -> None:
config_path = tmp_path / "takopi.toml" config_path = tmp_path / "takopi.toml"
settings = TakopiSettings.model_validate({"projects": {"z80": {"path": "repo"}}}) settings = TakopiSettings.model_validate({"projects": {"z80": {"path": "repo"}}})
+84
View File
@@ -911,6 +911,90 @@ async def test_run_main_loop_command_uses_project_default_engine(
assert transport.send_calls[-1]["message"].text == "ran:pi" assert transport.send_calls[-1]["message"].text == "ran:pi"
@pytest.mark.anyio
async def test_run_main_loop_command_defaults_to_chat_project(
monkeypatch,
) -> None:
class _Command:
id = "auto_ctx"
description = "auto context"
async def handle(self, ctx):
result = await ctx.executor.run_one(
commands.RunRequest(prompt="hello"),
mode="capture",
)
return commands.CommandResult(text=f"ran:{result.engine}")
entrypoints = [
FakeEntryPoint(
"auto_ctx",
"takopi.commands.auto_ctx:BACKEND",
plugins.COMMAND_GROUP,
loader=_Command,
)
]
install_entrypoints(monkeypatch, entrypoints)
transport = _FakeTransport()
bot = _FakeBot()
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
router = AutoRouter(
entries=[
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
RunnerEntry(engine=pi_runner.engine, runner=pi_runner),
],
default_engine=codex_runner.engine,
)
projects = ProjectsConfig(
projects={
"proj": ProjectConfig(
alias="proj",
path=Path("."),
worktrees_dir=Path(".worktrees"),
default_engine=pi_runner.engine,
chat_id=-42,
)
},
default_project=None,
chat_map={-42: "proj"},
)
runtime = TransportRuntime(
router=router,
projects=projects,
)
exec_cfg = ExecBridgeConfig(
transport=transport,
presenter=MarkdownPresenter(),
final_notify=True,
)
cfg = TelegramBridgeConfig(
bot=bot,
runtime=runtime,
chat_id=123,
startup_msg="",
exec_cfg=exec_cfg,
)
async def poller(_cfg: TelegramBridgeConfig):
yield TelegramIncomingMessage(
transport="telegram",
chat_id=-42,
message_id=1,
text="/auto_ctx",
reply_to_message_id=None,
reply_to_text=None,
sender_id=123,
)
await run_main_loop(cfg, poller)
assert codex_runner.calls == []
assert len(pi_runner.calls) == 1
assert transport.send_calls[-1]["message"].text == "ran:pi"
@pytest.mark.anyio @pytest.mark.anyio
async def test_run_main_loop_refreshes_command_ids(monkeypatch) -> None: async def test_run_main_loop_refreshes_command_ids(monkeypatch) -> None:
class _Command: class _Command:
+28
View File
@@ -43,3 +43,31 @@ def test_resolve_engine_prefers_override() -> None:
context=RunContext(project="proj"), context=RunContext(project="proj"),
) )
assert engine == "codex" assert engine == "codex"
def test_resolve_message_defaults_to_chat_project() -> None:
codex = ScriptRunner([Return(answer="ok")], engine="codex")
router = AutoRouter(
entries=[RunnerEntry(engine=codex.engine, runner=codex)],
default_engine=codex.engine,
)
project = ProjectConfig(
alias="proj",
path=Path("."),
worktrees_dir=Path(".worktrees"),
chat_id=-42,
)
projects = ProjectsConfig(
projects={"proj": project},
default_project=None,
chat_map={-42: "proj"},
)
runtime = TransportRuntime(router=router, projects=projects)
resolved = runtime.resolve_message(
text="hello",
reply_text=None,
chat_id=-42,
)
assert resolved.context == RunContext(project="proj", branch=None)