feat: add per-project chat routing (#76)

This commit is contained in:
banteg
2026-01-10 01:22:20 +04:00
committed by GitHub
parent 7ffb99d779
commit 81618e48e4
12 changed files with 299 additions and 14 deletions
+1 -1
View File
@@ -325,7 +325,7 @@ See `docs/adding-a-runner.md` for the full guide and a worked example.
```
Telegram Update
telegram/bridge.poll_updates() drains backlog, long-polls, filters chat_id == cfg.chat_id
telegram/bridge.poll_updates() drains backlog, long-polls, filters allowed chat ids
telegram/bridge.run_main_loop() spawns tasks in TaskGroup
+6
View File
@@ -31,6 +31,7 @@ path = "~/dev/z80" # required (repo root)
worktrees_dir = ".worktrees" # optional, default ".worktrees"
default_engine = "codex" # optional, per-project override
worktree_base = "master" # optional, base for new branches
chat_id = -123 # optional, project chat id
```
Legacy config note: top-level `bot_token` / `chat_id` are auto-migrated into
@@ -52,6 +53,7 @@ Validation rules:
- `default_project` must match a configured project alias.
- Project aliases cannot collide with engine ids or reserved commands (`/cancel`).
- `default_engine` and per-project `default_engine` must be valid engine ids.
- `projects.<alias>.chat_id` must be unique and must not match `transports.telegram.chat_id`.
- `transport` defaults to `"telegram"` when omitted; override per-run with `--transport`.
## `takopi init`
@@ -95,6 +97,10 @@ code (backticked):
The `ctx:` line is parsed from replies and takes precedence over new directives.
When a message arrives in a chat whose `chat_id` matches `projects.<alias>.chat_id`,
Takopi defaults the project context to that alias unless a reply `ctx:` or explicit
`/project` directive is present.
## Worktree resolution
When `@branch` is present:
+2 -1
View File
@@ -111,6 +111,7 @@ path = "~/dev/z80"
worktrees_dir = ".worktrees"
default_engine = "codex"
worktree_base = "master"
chat_id = -123456789 # optional, project chat id
```
note: the default `worktrees_dir` lives inside the repo, so `.worktrees/` will
@@ -163,7 +164,7 @@ see:
## notes
* the bot only responds to the configured `chat_id` (private or group)
* the bot only responds to the primary `chat_id` plus any per-project `chat_id`
* run only one takopi instance per bot token: multiple instances will race telegram's `getUpdates` offsets and cause missed updates
## development
+41 -2
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
import tomllib
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Iterable
@@ -41,6 +41,7 @@ class ProjectConfig:
worktrees_dir: Path
default_engine: str | None = None
worktree_base: str | None = None
chat_id: int | None = None
@property
def worktrees_root(self) -> Path:
@@ -53,6 +54,7 @@ class ProjectConfig:
class ProjectsConfig:
projects: dict[str, ProjectConfig]
default_project: str | None = None
chat_map: dict[int, str] = field(default_factory=dict)
def resolve(self, alias: str | None) -> ProjectConfig | None:
if alias is None:
@@ -61,6 +63,14 @@ class ProjectsConfig:
return self.projects.get(self.default_project)
return self.projects.get(alias.lower())
def project_for_chat(self, chat_id: int | None) -> str | None:
if chat_id is None:
return None
return self.chat_map.get(chat_id)
def project_chat_ids(self) -> tuple[int, ...]:
return tuple(self.chat_map.keys())
def empty_projects_config() -> ProjectsConfig:
return ProjectsConfig(projects={}, default_project=None)
@@ -99,6 +109,7 @@ def parse_projects_config(
config_path: Path,
engine_ids: Iterable[str],
reserved: Iterable[str] = ("cancel",),
default_chat_id: int | None = None,
) -> ProjectsConfig:
default_project_raw = config.get("default_project")
default_project = None
@@ -116,6 +127,7 @@ def parse_projects_config(
reserved_lower = {value.lower() for value in reserved}
engine_lower = {value.lower() for value in engine_ids}
projects: dict[str, ProjectConfig] = {}
chat_map: dict[int, str] = {}
for raw_alias, raw_entry in projects_raw.items():
if not isinstance(raw_alias, str) or not raw_alias.strip():
@@ -173,12 +185,35 @@ def parse_projects_config(
)
worktree_base = worktree_base_raw.strip()
chat_id_raw = raw_entry.get("chat_id")
chat_id = None
if chat_id_raw is not None:
if isinstance(chat_id_raw, bool) or not isinstance(chat_id_raw, int):
raise ConfigError(
f"Invalid `projects.{alias}.chat_id` in {config_path}; "
"expected an integer."
)
chat_id = chat_id_raw
if default_chat_id is not None and chat_id == default_chat_id:
raise ConfigError(
f"Invalid `projects.{alias}.chat_id` in {config_path}; "
"must not match transports.telegram.chat_id."
)
if chat_id in chat_map:
existing = chat_map[chat_id]
raise ConfigError(
f"Duplicate `projects.*.chat_id` {chat_id} in {config_path}; "
f"already used by {existing!r}."
)
chat_map[chat_id] = alias_key
projects[alias_key] = ProjectConfig(
alias=alias,
path=path,
worktrees_dir=worktrees_dir,
default_engine=default_engine,
worktree_base=worktree_base,
chat_id=chat_id,
)
if default_project is not None:
@@ -190,7 +225,11 @@ def parse_projects_config(
)
default_project = default_key
return ProjectsConfig(projects=projects, default_project=default_project)
return ProjectsConfig(
projects=projects,
default_project=default_project,
chat_map=chat_map,
)
def _toml_escape(value: str) -> str:
+38 -1
View File
@@ -70,6 +70,7 @@ class ProjectSettings(BaseModel):
worktrees_dir: str = ".worktrees"
default_engine: str | None = None
worktree_base: str | None = None
chat_id: int | None = None
model_config = ConfigDict(extra="allow")
@@ -91,6 +92,15 @@ class ProjectSettings(BaseModel):
raise ValueError(f"{info.field_name} must be a non-empty string")
return cleaned
@field_validator("chat_id", mode="before")
@classmethod
def _validate_chat_id(cls, value: Any) -> Any:
if value is None:
return None
if isinstance(value, bool) or not isinstance(value, int):
raise ValueError("chat_id must be an integer")
return value
class TakopiSettings(BaseSettings):
model_config = SettingsConfigDict(
@@ -194,10 +204,12 @@ class TakopiSettings(BaseSettings):
reserved: Iterable[str] = ("cancel",),
) -> ProjectsConfig:
default_project = self.default_project
default_chat_id = self.transports.telegram.chat_id
reserved_lower = {value.lower() for value in reserved}
engine_map = {engine.lower(): engine for engine in engine_ids}
projects: dict[str, ProjectConfig] = {}
chat_map: dict[int, str] = {}
for raw_alias, entry in self.projects.items():
if not isinstance(raw_alias, str) or not raw_alias.strip():
@@ -258,12 +270,33 @@ class TakopiSettings(BaseSettings):
)
worktree_base = worktree_base_raw.strip()
chat_id = entry.chat_id
if chat_id is not None:
if isinstance(chat_id, bool) or not isinstance(chat_id, int):
raise ConfigError(
f"Invalid `projects.{alias}.chat_id` in {config_path}; "
"expected an integer."
)
if default_chat_id is not None and chat_id == default_chat_id:
raise ConfigError(
f"Invalid `projects.{alias}.chat_id` in {config_path}; "
"must not match transports.telegram.chat_id."
)
if chat_id in chat_map:
existing = chat_map[chat_id]
raise ConfigError(
f"Duplicate `projects.*.chat_id` {chat_id} in {config_path}; "
f"already used by {existing!r}."
)
chat_map[chat_id] = alias_key
projects[alias_key] = ProjectConfig(
alias=alias,
path=path,
worktrees_dir=worktrees_dir,
default_engine=default_engine,
worktree_base=worktree_base,
chat_id=chat_id,
)
if default_project is not None:
@@ -275,7 +308,11 @@ class TakopiSettings(BaseSettings):
)
default_project = default_key
return ProjectsConfig(projects=projects, default_project=default_project)
return ProjectsConfig(
projects=projects,
default_project=default_project,
chat_map=chat_map,
)
def load_settings(path: str | Path | None = None) -> tuple[TakopiSettings, Path]:
+2
View File
@@ -97,10 +97,12 @@ class TelegramBackend(TransportBackend):
final_notify=final_notify,
)
voice_transcription = _build_voice_transcription_config(transport_config)
chat_ids = (chat_id, *runtime.project_chat_ids())
cfg = TelegramBridgeConfig(
bot=bot,
runtime=runtime,
chat_id=chat_id,
chat_ids=chat_ids,
startup_msg=startup_msg,
exec_cfg=exec_cfg,
voice_transcription=voice_transcription,
+26 -1
View File
@@ -297,6 +297,13 @@ class TelegramBridgeConfig:
startup_msg: str
exec_cfg: ExecBridgeConfig
voice_transcription: TelegramVoiceTranscriptionConfig | None = None
chat_ids: tuple[int, ...] | None = None
def _allowed_chat_ids(cfg: TelegramBridgeConfig) -> set[int]:
allowed = set(cfg.chat_ids or ())
allowed.add(cfg.chat_id)
return allowed
async def _send_plain(
@@ -353,7 +360,11 @@ async def poll_updates(
offset = await _drain_backlog(cfg, offset)
await _send_startup(cfg)
async for msg in poll_incoming(cfg.bot, chat_id=cfg.chat_id, offset=offset):
async for msg in poll_incoming(
cfg.bot,
chat_ids=_allowed_chat_ids(cfg),
offset=offset,
):
yield msg
@@ -746,6 +757,18 @@ class _TelegramCommandExecutor(CommandExecutor):
self._user_msg_id = user_msg_id
self._reply_ref = MessageRef(channel_id=chat_id, message_id=user_msg_id)
def _apply_default_context(self, request: RunRequest) -> RunRequest:
if request.context is not None:
return request
context = self._runtime.default_context_for_chat(self._chat_id)
if context is None:
return request
return RunRequest(
prompt=request.prompt,
engine=request.engine,
context=context,
)
async def send(
self,
message: RenderedMessage | str,
@@ -768,6 +791,7 @@ class _TelegramCommandExecutor(CommandExecutor):
async def run_one(
self, request: RunRequest, *, mode: RunMode = "emit"
) -> RunResult:
request = self._apply_default_context(request)
engine = self._runtime.resolve_engine(
engine_override=request.engine,
context=request.context,
@@ -999,6 +1023,7 @@ async def run_main_loop(
resolved = cfg.runtime.resolve_message(
text=text,
reply_text=reply_text,
chat_id=chat_id,
)
except DirectiveError as exc:
await _send_plain(
+17 -4
View File
@@ -9,6 +9,7 @@ from typing import (
Awaitable,
Callable,
Hashable,
Iterable,
Protocol,
TYPE_CHECKING,
)
@@ -44,7 +45,10 @@ def is_group_chat_id(chat_id: int) -> bool:
def parse_incoming_update(
update: dict[str, Any], *, chat_id: int
update: dict[str, Any],
*,
chat_id: int | None = None,
chat_ids: set[int] | None = None,
) -> TelegramIncomingMessage | None:
msg = update.get("message")
if not isinstance(msg, dict):
@@ -78,7 +82,12 @@ def parse_incoming_update(
if not isinstance(chat, dict):
return None
msg_chat_id = chat.get("id")
if not isinstance(msg_chat_id, int) or msg_chat_id != chat_id:
if not isinstance(msg_chat_id, int):
return None
allowed = chat_ids
if allowed is None and chat_id is not None:
allowed = {chat_id}
if allowed is not None and msg_chat_id not in allowed:
return None
message_id = msg.get("message_id")
if not isinstance(message_id, int):
@@ -117,9 +126,13 @@ def parse_incoming_update(
async def poll_incoming(
bot: BotClient,
*,
chat_id: int,
chat_id: int | None = None,
chat_ids: Iterable[int] | None = None,
offset: int | None = None,
) -> AsyncIterator[TelegramIncomingMessage]:
allowed = set(chat_ids) if chat_ids is not None else None
if allowed is None and chat_id is not None:
allowed = {chat_id}
while True:
updates = await bot.get_updates(
offset=offset, timeout_s=50, allowed_updates=["message"]
@@ -131,7 +144,7 @@ async def poll_incoming(
logger.debug("loop.updates", updates=updates)
for upd in updates:
offset = upd["update_id"] + 1
msg = parse_incoming_update(upd, chat_id=chat_id)
msg = parse_incoming_update(upd, chat_ids=allowed)
if msg is not None:
yield msg
+23 -4
View File
@@ -110,7 +110,13 @@ class TransportRuntime:
)
return dict(raw)
def resolve_message(self, *, text: str, reply_text: str | None) -> ResolvedMessage:
def resolve_message(
self,
*,
text: str,
reply_text: str | None,
chat_id: int | None = None,
) -> ResolvedMessage:
directives = parse_directives(
text,
engine_ids=self._router.engine_ids,
@@ -118,13 +124,17 @@ class TransportRuntime:
)
reply_ctx = parse_context_line(reply_text, projects=self._projects)
resume_token = self._router.resolve_resume(directives.prompt, reply_text)
chat_project = self._projects.project_for_chat(chat_id)
if resume_token is not None:
context = reply_ctx
if context is None and chat_project is not None:
context = RunContext(project=chat_project, branch=None)
return ResolvedMessage(
prompt=directives.prompt,
resume_token=resume_token,
engine_override=None,
context=reply_ctx,
context=context,
)
if reply_ctx is not None:
@@ -141,8 +151,8 @@ class TransportRuntime:
)
project_key = directives.project
if project_key is None and self._projects.default_project is not None:
project_key = self._projects.default_project
if project_key is None:
project_key = chat_project or self._projects.default_project
context = None
if project_key is not None or directives.branch is not None:
@@ -161,6 +171,15 @@ class TransportRuntime:
context=context,
)
def default_context_for_chat(self, chat_id: int | None) -> RunContext | None:
project_key = self._projects.project_for_chat(chat_id)
if project_key is None:
return None
return RunContext(project=project_key, branch=None)
def project_chat_ids(self) -> tuple[int, ...]:
return self._projects.project_chat_ids()
def resolve_runner(
self,
*,
+31
View File
@@ -87,6 +87,37 @@ def test_projects_default_engine_unknown() -> None:
)
def test_projects_chat_id_cannot_match_transport_chat_id() -> None:
config = {
"transports": {"telegram": {"bot_token": "token", "chat_id": 123}},
"projects": {"z80": {"path": "/tmp/repo", "chat_id": 123}},
}
settings = TakopiSettings.model_validate(config)
with pytest.raises(ConfigError, match="chat_id"):
settings.to_projects_config(
config_path=Path("takopi.toml"),
engine_ids=["codex"],
reserved=("cancel",),
)
def test_projects_chat_id_must_be_unique() -> None:
config = {
"transports": {"telegram": {"bot_token": "token", "chat_id": 123}},
"projects": {
"a": {"path": "/tmp/a", "chat_id": -10},
"b": {"path": "/tmp/b", "chat_id": -10},
},
}
settings = TakopiSettings.model_validate(config)
with pytest.raises(ConfigError, match="chat_id"):
settings.to_projects_config(
config_path=Path("takopi.toml"),
engine_ids=["codex"],
reserved=("cancel",),
)
def test_projects_relative_path_resolves(tmp_path: Path) -> None:
config_path = tmp_path / "takopi.toml"
settings = TakopiSettings.model_validate({"projects": {"z80": {"path": "repo"}}})
+84
View File
@@ -911,6 +911,90 @@ async def test_run_main_loop_command_uses_project_default_engine(
assert transport.send_calls[-1]["message"].text == "ran:pi"
@pytest.mark.anyio
async def test_run_main_loop_command_defaults_to_chat_project(
monkeypatch,
) -> None:
class _Command:
id = "auto_ctx"
description = "auto context"
async def handle(self, ctx):
result = await ctx.executor.run_one(
commands.RunRequest(prompt="hello"),
mode="capture",
)
return commands.CommandResult(text=f"ran:{result.engine}")
entrypoints = [
FakeEntryPoint(
"auto_ctx",
"takopi.commands.auto_ctx:BACKEND",
plugins.COMMAND_GROUP,
loader=_Command,
)
]
install_entrypoints(monkeypatch, entrypoints)
transport = _FakeTransport()
bot = _FakeBot()
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
router = AutoRouter(
entries=[
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
RunnerEntry(engine=pi_runner.engine, runner=pi_runner),
],
default_engine=codex_runner.engine,
)
projects = ProjectsConfig(
projects={
"proj": ProjectConfig(
alias="proj",
path=Path("."),
worktrees_dir=Path(".worktrees"),
default_engine=pi_runner.engine,
chat_id=-42,
)
},
default_project=None,
chat_map={-42: "proj"},
)
runtime = TransportRuntime(
router=router,
projects=projects,
)
exec_cfg = ExecBridgeConfig(
transport=transport,
presenter=MarkdownPresenter(),
final_notify=True,
)
cfg = TelegramBridgeConfig(
bot=bot,
runtime=runtime,
chat_id=123,
startup_msg="",
exec_cfg=exec_cfg,
)
async def poller(_cfg: TelegramBridgeConfig):
yield TelegramIncomingMessage(
transport="telegram",
chat_id=-42,
message_id=1,
text="/auto_ctx",
reply_to_message_id=None,
reply_to_text=None,
sender_id=123,
)
await run_main_loop(cfg, poller)
assert codex_runner.calls == []
assert len(pi_runner.calls) == 1
assert transport.send_calls[-1]["message"].text == "ran:pi"
@pytest.mark.anyio
async def test_run_main_loop_refreshes_command_ids(monkeypatch) -> None:
class _Command:
+28
View File
@@ -43,3 +43,31 @@ def test_resolve_engine_prefers_override() -> None:
context=RunContext(project="proj"),
)
assert engine == "codex"
def test_resolve_message_defaults_to_chat_project() -> None:
codex = ScriptRunner([Return(answer="ok")], engine="codex")
router = AutoRouter(
entries=[RunnerEntry(engine=codex.engine, runner=codex)],
default_engine=codex.engine,
)
project = ProjectConfig(
alias="proj",
path=Path("."),
worktrees_dir=Path(".worktrees"),
chat_id=-42,
)
projects = ProjectsConfig(
projects={"proj": project},
default_project=None,
chat_map={-42: "proj"},
)
runtime = TransportRuntime(router=router, projects=projects)
resolved = runtime.resolve_message(
text="hello",
reply_text=None,
chat_id=-42,
)
assert resolved.context == RunContext(project="proj", branch=None)