feat(telegram): add per-chat/topic default agents (#109)

This commit is contained in:
banteg
2026-01-13 04:11:16 +04:00
committed by GitHub
parent 6ce08ee602
commit f060d3b59c
13 changed files with 660 additions and 31 deletions
+1
View File
@@ -2,3 +2,4 @@ after you finish work, commit with a conventional message. only commit the files
always run `just check` before code commits. always run `just check` before code commits.
if you fix anything from `just check`, rerun it and confirm it passes before committing. if you fix anything from `just check`, rerun it and confirm it passes before committing.
when using gh to edit or create PR descriptions, prefer `--body-file` to preserve newlines. when using gh to edit or create PR descriptions, prefer `--body-file` to preserve newlines.
always include a "Manual testing" checklist section in PRs.
+9
View File
@@ -353,6 +353,15 @@ Decision (v0.4.0):
* Users MAY prefix the first non-empty line with `/{engine}` (e.g. `/claude`, `/codex`, or `/pi`) to select the engine for a **new** thread. * Users MAY prefix the first non-empty line with `/{engine}` (e.g. `/claude`, `/codex`, or `/pi`) to select the engine for a **new** thread.
* The bridge MUST strip that directive from the prompt before invoking the runner. * The bridge MUST strip that directive from the prompt before invoking the runner.
* If a ResumeToken is resolved from the message or reply, it MUST take precedence and the `/{engine}` directive MUST be ignored. * If a ResumeToken is resolved from the message or reply, it MUST take precedence and the `/{engine}` directive MUST be ignored.
* Bridges MAY persist default engine overrides per Telegram scope:
* **Topic default**: forum topic (`chat_id + thread_id`)
* **Chat default**: chat (`chat_id`)
* When no ResumeToken is resolved, engine selection MUST follow this precedence:
1) explicit `/{engine}` directive
2) topic default (if any)
3) chat default (if any)
4) project default engine (if configured for the resolved context)
5) global default engine
### 8.1 Command menu (Telegram) ### 8.1 Command menu (Telegram)
+19 -2
View File
@@ -122,6 +122,22 @@ Prefix your message with an engine directive to override the default:
Directives are only parsed at the start of the first non-empty line. Directives are only parsed at the start of the first non-empty line.
### Default agent per chat or topic
Use `/agent` to view or set a persistent default for the current scope:
```
/agent
/agent set claude
/agent clear
```
- Inside a forum topic, `/agent set` affects that topic.
- In normal chats, it affects the whole chat.
- In group chats, only admins can change defaults.
Precedence (highest to lowest): resume token → `/engine` directive → topic default → chat default → project default → global default.
### Setting up engines ### Setting up engines
Takopi shells out to the agent CLIs. Install them and make sure they're on your `PATH` Takopi shells out to the agent CLIs. Install them and make sure they're on your `PATH`
@@ -257,7 +273,7 @@ takopi chat-id
## 7. Topics ## 7. Topics
Topics bind Telegram forum threads to specific project/branch contexts. They also preserve resume tokens, so agents can pick up where they left off. Topics bind Telegram forum threads to specific project/branch contexts. They also preserve resume tokens and can store a default agent per topic.
### Enabling topics ### Enabling topics
@@ -341,7 +357,7 @@ path = "~/dev/takopi"
chat_id = -1001111111111 # forum-enabled group chat_id = -1001111111111 # forum-enabled group
``` ```
Topic state is stored in `telegram_topics_state.json` next to your config file. Topic state is stored in `telegram_topics_state.json` next to your config file. Chat defaults live in `telegram_chat_prefs_state.json`.
--- ---
@@ -496,6 +512,7 @@ worktree_base = "develop"
| `/cancel` | Reply to the progress message to stop the current run | | `/cancel` | Reply to the progress message to stop the current run |
| `/file put <path>` | Upload a document into the repo/worktree | | `/file put <path>` | Upload a document into the repo/worktree |
| `/file get <path>` | Fetch a file (directories are zipped) | | `/file get <path>` | Fetch a file (directories are zipped) |
| `/agent` | Show/set the default agent for the current scope |
| `/topic <project> @branch` | Create/bind a topic | | `/topic <project> @branch` | Create/bind a topic |
| `/ctx` | Show current context | | `/ctx` | Show current context |
| `/ctx set <project> @branch` | Update context binding | | `/ctx set <project> @branch` | Update context binding |
+95
View File
@@ -0,0 +1,95 @@
from __future__ import annotations
from pathlib import Path
import msgspec
from ..logging import get_logger
from .state_store import JsonStateStore
logger = get_logger(__name__)
STATE_VERSION = 1
STATE_FILENAME = "telegram_chat_prefs_state.json"
class _ChatPrefs(msgspec.Struct, forbid_unknown_fields=False):
default_engine: str | None = None
class _ChatPrefsState(msgspec.Struct, forbid_unknown_fields=False):
version: int
chats: dict[str, _ChatPrefs] = msgspec.field(default_factory=dict)
def resolve_prefs_path(config_path: Path) -> Path:
return config_path.with_name(STATE_FILENAME)
def _chat_key(chat_id: int) -> str:
return str(chat_id)
def _normalize_text(value: str | None) -> str | None:
if value is None:
return None
value = value.strip()
return value or None
def _new_state() -> _ChatPrefsState:
return _ChatPrefsState(version=STATE_VERSION, chats={})
class ChatPrefsStore(JsonStateStore[_ChatPrefsState]):
def __init__(self, path: Path) -> None:
super().__init__(
path,
version=STATE_VERSION,
state_type=_ChatPrefsState,
state_factory=_new_state,
log_prefix="telegram.chat_prefs",
logger=logger,
)
async def get_default_engine(self, chat_id: int) -> str | None:
async with self._lock:
self._reload_locked_if_needed()
chat = self._get_chat_locked(chat_id)
if chat is None:
return None
return _normalize_text(chat.default_engine)
async def set_default_engine(self, chat_id: int, engine: str | None) -> None:
normalized = _normalize_text(engine)
async with self._lock:
self._reload_locked_if_needed()
if normalized is None:
if self._remove_chat_locked(chat_id):
self._save_locked()
return
chat = self._ensure_chat_locked(chat_id)
chat.default_engine = normalized
self._save_locked()
async def clear_default_engine(self, chat_id: int) -> None:
await self.set_default_engine(chat_id, None)
def _get_chat_locked(self, chat_id: int) -> _ChatPrefs | None:
return self._state.chats.get(_chat_key(chat_id))
def _ensure_chat_locked(self, chat_id: int) -> _ChatPrefs:
key = _chat_key(chat_id)
entry = self._state.chats.get(key)
if entry is not None:
return entry
entry = _ChatPrefs()
self._state.chats[key] = entry
return entry
def _remove_chat_locked(self, chat_id: int) -> bool:
key = _chat_key(chat_id)
if key not in self._state.chats:
return False
del self._state.chats[key]
return True
+160
View File
@@ -38,6 +38,7 @@ from ..transport import MessageRef, RenderedMessage, SendOptions
from ..transport_runtime import ResolvedMessage, TransportRuntime from ..transport_runtime import ResolvedMessage, TransportRuntime
from ..utils.paths import reset_run_base_dir, set_run_base_dir from ..utils.paths import reset_run_base_dir, set_run_base_dir
from .bridge import TelegramBridgeConfig, send_plain from .bridge import TelegramBridgeConfig, send_plain
from .chat_prefs import ChatPrefsStore
from .chat_sessions import ChatSessionStore from .chat_sessions import ChatSessionStore
from .context import ( from .context import (
_format_context, _format_context,
@@ -47,6 +48,7 @@ from .context import (
_usage_ctx_set, _usage_ctx_set,
_usage_topic, _usage_topic,
) )
from .engine_defaults import resolve_engine_for_message
from .files import ( from .files import (
default_upload_name, default_upload_name,
default_upload_path, default_upload_path,
@@ -79,6 +81,7 @@ __all__ = [
"FILE_GET_USAGE", "FILE_GET_USAGE",
"FILE_PUT_USAGE", "FILE_PUT_USAGE",
"_dispatch_command", "_dispatch_command",
"_handle_agent_command",
"_handle_chat_new_command", "_handle_chat_new_command",
"_handle_file_command", "_handle_file_command",
"_handle_file_get", "_handle_file_get",
@@ -97,6 +100,7 @@ __all__ = [
_MAX_BOT_COMMANDS = 100 _MAX_BOT_COMMANDS = 100
FILE_PUT_USAGE = "usage: `/file put <path>`" FILE_PUT_USAGE = "usage: `/file put <path>`"
FILE_GET_USAGE = "usage: `/file get <path>`" FILE_GET_USAGE = "usage: `/file get <path>`"
AGENT_USAGE = "usage: `/agent`, `/agent set <engine>`, or `/agent clear`"
def is_cancel_command(text: str) -> bool: def is_cancel_command(text: str) -> bool:
@@ -362,6 +366,29 @@ async def _check_file_permissions(
return False return False
async def _check_agent_permissions(
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
) -> bool:
reply = _reply_sender(cfg, msg)
sender_id = msg.sender_id
if sender_id is None:
await reply(text="cannot verify sender for agent defaults.")
return False
is_private = msg.chat_type == "private"
if msg.chat_type is None:
is_private = msg.chat_id > 0
if is_private:
return True
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
if member is None:
await reply(text="failed to verify agent permissions.")
return False
if member.status in {"creator", "administrator"}:
return True
await reply(text="changing default agents is restricted to group admins.")
return False
async def _prepare_file_put_plan( async def _prepare_file_put_plan(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
@@ -1034,6 +1061,125 @@ async def _handle_ctx_command(
) )
async def _handle_agent_command(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
chat_prefs: ChatPrefsStore | None,
*,
resolved_scope: str | None = None,
scope_chat_ids: frozenset[int] | None = None,
) -> None:
reply = _reply_sender(cfg, msg)
tkey = (
_topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
if topic_store is not None
else None
)
tokens = split_command_args(args_text)
action = tokens[0].lower() if tokens else "show"
if action in {"show", ""}:
try:
resolved = cfg.runtime.resolve_message(
text="",
reply_text=msg.reply_to_text,
ambient_context=ambient_context,
chat_id=msg.chat_id,
)
except DirectiveError as exc:
await reply(text=f"error:\n{exc}")
return
selection = await resolve_engine_for_message(
runtime=cfg.runtime,
context=resolved.context,
explicit_engine=None,
chat_id=msg.chat_id,
topic_key=tkey,
topic_store=topic_store,
chat_prefs=chat_prefs,
)
source_labels = {
"directive": "directive",
"topic_default": "topic default",
"chat_default": "chat default",
"project_default": "project default",
"global_default": "global default",
}
agent_line = f"agent: {selection.engine} ({source_labels[selection.source]})"
topic_default = selection.topic_default or "none"
if tkey is None:
topic_default = "none"
if chat_prefs is None:
chat_default = "unavailable"
else:
chat_default = selection.chat_default or "none"
project_default = (
selection.project_default
if selection.project_default is not None
else "none"
)
defaults_line = (
"defaults: "
f"topic: {topic_default}, "
f"chat: {chat_default}, "
f"project: {project_default}, "
f"global: {cfg.runtime.default_engine}"
)
available = ", ".join(cfg.runtime.engine_ids)
available_line = f"available: {available}"
await reply(text="\n\n".join([agent_line, defaults_line, available_line]))
return
if action == "set":
if len(tokens) < 2:
await reply(text=AGENT_USAGE)
return
if not await _check_agent_permissions(cfg, msg):
return
engine = tokens[1].strip().lower()
if engine not in cfg.runtime.engine_ids:
available = ", ".join(cfg.runtime.engine_ids)
await reply(
text=f"unknown engine `{engine}`.\navailable agents: `{available}`",
)
return
if tkey is not None:
if topic_store is None:
await reply(text="topic defaults are unavailable.")
return
await topic_store.set_default_engine(tkey[0], tkey[1], engine)
await reply(text=f"topic default agent set to `{engine}`")
return
if chat_prefs is None:
await reply(text="chat defaults are unavailable (no config path).")
return
await chat_prefs.set_default_engine(msg.chat_id, engine)
await reply(text=f"chat default agent set to `{engine}`")
return
if action == "clear":
if not await _check_agent_permissions(cfg, msg):
return
if tkey is not None:
if topic_store is None:
await reply(text="topic defaults are unavailable.")
return
await topic_store.clear_default_engine(tkey[0], tkey[1])
await reply(text="topic default agent cleared.")
return
if chat_prefs is None:
await reply(text="chat defaults are unavailable (no config path).")
return
await chat_prefs.clear_default_engine(msg.chat_id)
await reply(text="chat default agent cleared.")
return
await reply(text=AGENT_USAGE)
async def _handle_new_command( async def _handle_new_command(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
@@ -1369,6 +1515,7 @@ class _TelegramCommandExecutor(CommandExecutor):
thread_id: int | None, thread_id: int | None,
show_resume_line: bool, show_resume_line: bool,
stateful_mode: bool, stateful_mode: bool,
default_engine_override: EngineId | None,
) -> None: ) -> None:
self._exec_cfg = exec_cfg self._exec_cfg = exec_cfg
self._runtime = runtime self._runtime = runtime
@@ -1380,6 +1527,7 @@ class _TelegramCommandExecutor(CommandExecutor):
self._thread_id = thread_id self._thread_id = thread_id
self._show_resume_line = show_resume_line self._show_resume_line = show_resume_line
self._stateful_mode = stateful_mode self._stateful_mode = stateful_mode
self._default_engine_override = default_engine_override
self._reply_ref = MessageRef( self._reply_ref = MessageRef(
channel_id=chat_id, channel_id=chat_id,
message_id=user_msg_id, message_id=user_msg_id,
@@ -1398,6 +1546,15 @@ class _TelegramCommandExecutor(CommandExecutor):
context=context, context=context,
) )
def _apply_default_engine(self, request: RunRequest) -> RunRequest:
if request.engine is not None or self._default_engine_override is None:
return request
return RunRequest(
prompt=request.prompt,
engine=self._default_engine_override,
context=request.context,
)
async def send( async def send(
self, self,
message: RenderedMessage | str, message: RenderedMessage | str,
@@ -1425,6 +1582,7 @@ class _TelegramCommandExecutor(CommandExecutor):
self, request: RunRequest, *, mode: RunMode = "emit" self, request: RunRequest, *, mode: RunMode = "emit"
) -> RunResult: ) -> RunResult:
request = self._apply_default_context(request) request = self._apply_default_context(request)
request = self._apply_default_engine(request)
effective_show_resume_line = _should_show_resume_line( effective_show_resume_line = _should_show_resume_line(
show_resume_line=self._show_resume_line, show_resume_line=self._show_resume_line,
stateful_mode=self._stateful_mode, stateful_mode=self._stateful_mode,
@@ -1511,6 +1669,7 @@ async def _dispatch_command(
scheduler: ThreadScheduler, scheduler: ThreadScheduler,
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None, on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
stateful_mode: bool, stateful_mode: bool,
default_engine_override: EngineId | None,
) -> None: ) -> None:
allowlist = cfg.runtime.allowlist allowlist = cfg.runtime.allowlist
chat_id = msg.chat_id chat_id = msg.chat_id
@@ -1535,6 +1694,7 @@ async def _dispatch_command(
thread_id=msg.thread_id, thread_id=msg.thread_id,
show_resume_line=cfg.show_resume_line, show_resume_line=cfg.show_resume_line,
stateful_mode=stateful_mode, stateful_mode=stateful_mode,
default_engine_override=default_engine_override,
) )
message_ref = MessageRef( message_ref = MessageRef(
channel_id=chat_id, channel_id=chat_id,
+86
View File
@@ -0,0 +1,86 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
from ..context import RunContext
from ..model import EngineId
from ..transport_runtime import TransportRuntime
from .chat_prefs import ChatPrefsStore
from .topic_state import TopicStateStore
EngineSource = Literal[
"directive",
"topic_default",
"chat_default",
"project_default",
"global_default",
]
@dataclass(frozen=True, slots=True)
class EngineResolution:
engine: EngineId
source: EngineSource
topic_default: EngineId | None
chat_default: EngineId | None
project_default: EngineId | None
async def resolve_engine_for_message(
*,
runtime: TransportRuntime,
context: RunContext | None,
explicit_engine: EngineId | None,
chat_id: int,
topic_key: tuple[int, int] | None,
topic_store: TopicStateStore | None,
chat_prefs: ChatPrefsStore | None,
) -> EngineResolution:
topic_default = None
if topic_store is not None and topic_key is not None:
topic_default = await topic_store.get_default_engine(*topic_key)
chat_default = None
if chat_prefs is not None:
chat_default = await chat_prefs.get_default_engine(chat_id)
project_default = runtime.project_default_engine(context)
if explicit_engine is not None:
return EngineResolution(
engine=explicit_engine,
source="directive",
topic_default=topic_default,
chat_default=chat_default,
project_default=project_default,
)
if topic_default is not None:
return EngineResolution(
engine=topic_default,
source="topic_default",
topic_default=topic_default,
chat_default=chat_default,
project_default=project_default,
)
if chat_default is not None:
return EngineResolution(
engine=chat_default,
source="chat_default",
topic_default=topic_default,
chat_default=chat_default,
project_default=project_default,
)
if project_default is not None:
return EngineResolution(
engine=project_default,
source="project_default",
topic_default=topic_default,
chat_default=chat_default,
project_default=project_default,
)
return EngineResolution(
engine=runtime.default_engine,
source="global_default",
topic_default=topic_default,
chat_default=chat_default,
project_default=project_default,
)
+73 -18
View File
@@ -23,6 +23,7 @@ from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
from .commands import ( from .commands import (
FILE_PUT_USAGE, FILE_PUT_USAGE,
_dispatch_command, _dispatch_command,
_handle_agent_command,
_handle_chat_new_command, _handle_chat_new_command,
_handle_ctx_command, _handle_ctx_command,
_handle_file_command, _handle_file_command,
@@ -50,7 +51,9 @@ from .topics import (
_validate_topics_setup, _validate_topics_setup,
) )
from .client import poll_incoming from .client import poll_incoming
from .chat_prefs import ChatPrefsStore, resolve_prefs_path
from .chat_sessions import ChatSessionStore, resolve_sessions_path from .chat_sessions import ChatSessionStore, resolve_sessions_path
from .engine_defaults import resolve_engine_for_message
from .topic_state import TopicStateStore, resolve_state_path from .topic_state import TopicStateStore, resolve_state_path
from .types import ( from .types import (
TelegramCallbackQuery, TelegramCallbackQuery,
@@ -110,6 +113,7 @@ def _dispatch_builtin_command(
args_text: str, args_text: str,
ambient_context: RunContext | None, ambient_context: RunContext | None,
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
chat_prefs: ChatPrefsStore | None,
resolved_scope: str | None, resolved_scope: str | None,
scope_chat_ids: frozenset[int], scope_chat_ids: frozenset[int],
reply: Callable[..., Awaitable[None]], reply: Callable[..., Awaitable[None]],
@@ -165,6 +169,19 @@ def _dispatch_builtin_command(
} }
) )
if command_id == "agent":
handlers["agent"] = partial(
_handle_agent_command,
cfg,
msg,
args_text,
ambient_context,
topic_store,
chat_prefs,
resolved_scope=resolved_scope,
scope_chat_ids=scope_chat_ids,
)
handler = handlers.get(command_id) handler = handlers.get(command_id)
if handler is None: if handler is None:
return False return False
@@ -311,6 +328,7 @@ async def run_main_loop(
) )
topic_store: TopicStateStore | None = None topic_store: TopicStateStore | None = None
chat_session_store: ChatSessionStore | None = None chat_session_store: ChatSessionStore | None = None
chat_prefs: ChatPrefsStore | None = None
media_groups: dict[tuple[int, str], _MediaGroupState] = {} media_groups: dict[tuple[int, str], _MediaGroupState] = {}
resolved_topics_scope: str | None = None resolved_topics_scope: str | None = None
topics_chat_ids: frozenset[int] = frozenset() topics_chat_ids: frozenset[int] = frozenset()
@@ -333,6 +351,12 @@ async def run_main_loop(
try: try:
config_path = cfg.runtime.config_path config_path = cfg.runtime.config_path
if config_path is not None:
chat_prefs = ChatPrefsStore(resolve_prefs_path(config_path))
logger.info(
"chat_prefs.enabled",
state_path=str(resolve_prefs_path(config_path)),
)
if cfg.session_mode == "chat": if cfg.session_mode == "chat":
if config_path is None: if config_path is None:
raise ConfigError( raise ConfigError(
@@ -552,6 +576,23 @@ async def run_main_loop(
return None return None
return resolved return resolved
async def resolve_engine_defaults(
*,
explicit_engine: EngineId | None,
context: RunContext | None,
chat_id: int,
topic_key: tuple[int, int] | None,
):
return await resolve_engine_for_message(
runtime=cfg.runtime,
context=context,
explicit_engine=explicit_engine,
chat_id=chat_id,
topic_key=topic_key,
topic_store=topic_store,
chat_prefs=chat_prefs,
)
async def run_prompt_from_upload( async def run_prompt_from_upload(
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
prompt_text: str, prompt_text: str,
@@ -570,7 +611,6 @@ async def run_main_loop(
else None else None
) )
resume_token = resolved.resume_token resume_token = resolved.resume_token
engine_override = resolved.engine_override
context = resolved.context context = resolved.context
chat_session_key = _chat_session_key(msg, store=chat_session_store) chat_session_key = _chat_session_key(msg, store=chat_session_store)
topic_key = ( topic_key = (
@@ -578,6 +618,13 @@ async def run_main_loop(
if topic_store is not None if topic_store is not None
else None else None
) )
engine_resolution = await resolve_engine_defaults(
explicit_engine=resolved.engine_override,
context=context,
chat_id=chat_id,
topic_key=topic_key,
)
engine_override = engine_resolution.engine
if resume_token is None and reply_id is not None: if resume_token is None and reply_id is not None:
running_task = running_tasks.get( running_task = running_tasks.get(
MessageRef(channel_id=chat_id, message_id=reply_id) MessageRef(channel_id=chat_id, message_id=reply_id)
@@ -600,10 +647,7 @@ async def run_main_loop(
and topic_store is not None and topic_store is not None
and topic_key is not None and topic_key is not None
): ):
engine_for_session = cfg.runtime.resolve_engine( engine_for_session = engine_resolution.engine
engine_override=engine_override,
context=context,
)
stored = await topic_store.get_session_resume( stored = await topic_store.get_session_resume(
topic_key[0], topic_key[1], engine_for_session topic_key[0], topic_key[1], engine_for_session
) )
@@ -614,10 +658,7 @@ async def run_main_loop(
and chat_session_store is not None and chat_session_store is not None
and chat_session_key is not None and chat_session_key is not None
): ):
engine_for_session = cfg.runtime.resolve_engine( engine_for_session = engine_resolution.engine
engine_override=engine_override,
context=context,
)
stored = await chat_session_store.get_session_resume( stored = await chat_session_store.get_session_resume(
chat_session_key[0], chat_session_key[0],
chat_session_key[1], chat_session_key[1],
@@ -815,6 +856,7 @@ async def run_main_loop(
args_text=args_text, args_text=args_text,
ambient_context=ambient_context, ambient_context=ambient_context,
topic_store=topic_store, topic_store=topic_store,
chat_prefs=chat_prefs,
resolved_scope=resolved_topics_scope, resolved_scope=resolved_topics_scope,
scope_chat_ids=topics_chat_ids, scope_chat_ids=topics_chat_ids,
reply=reply, reply=reply,
@@ -853,6 +895,18 @@ async def run_main_loop(
if command_id not in command_ids: if command_id not in command_ids:
refresh_commands() refresh_commands()
if command_id in command_ids: if command_id in command_ids:
engine_resolution = await resolve_engine_defaults(
explicit_engine=None,
context=ambient_context,
chat_id=chat_id,
topic_key=topic_key,
)
default_engine_override = (
engine_resolution.engine
if engine_resolution.source
in {"directive", "topic_default", "chat_default"}
else None
)
tg.start_soon( tg.start_soon(
_dispatch_command, _dispatch_command,
cfg, cfg,
@@ -868,6 +922,7 @@ async def run_main_loop(
chat_session_key, chat_session_key,
), ),
stateful_mode, stateful_mode,
default_engine_override,
) )
continue continue
@@ -885,8 +940,14 @@ async def run_main_loop(
text = resolved.prompt text = resolved.prompt
resume_token = resolved.resume_token resume_token = resolved.resume_token
engine_override = resolved.engine_override
context = resolved.context context = resolved.context
engine_resolution = await resolve_engine_defaults(
explicit_engine=resolved.engine_override,
context=context,
chat_id=chat_id,
topic_key=topic_key,
)
engine_override = engine_resolution.engine
if ( if (
topic_store is not None topic_store is not None
and topic_key is not None and topic_key is not None
@@ -936,10 +997,7 @@ async def run_main_loop(
and topic_store is not None and topic_store is not None
and topic_key is not None and topic_key is not None
): ):
engine_for_session = cfg.runtime.resolve_engine( engine_for_session = engine_resolution.engine
engine_override=engine_override,
context=context,
)
stored = await topic_store.get_session_resume( stored = await topic_store.get_session_resume(
topic_key[0], topic_key[1], engine_for_session topic_key[0], topic_key[1], engine_for_session
) )
@@ -950,10 +1008,7 @@ async def run_main_loop(
and chat_session_store is not None and chat_session_store is not None
and chat_session_key is not None and chat_session_key is not None
): ):
engine_for_session = cfg.runtime.resolve_engine( engine_for_session = engine_resolution.engine
engine_override=engine_override,
context=context,
)
stored = await chat_session_store.get_session_resume( stored = await chat_session_store.get_session_resume(
chat_session_key[0], chat_session_key[1], engine_for_session chat_session_key[0], chat_session_key[1], engine_for_session
) )
+24
View File
@@ -23,6 +23,7 @@ class TopicThreadSnapshot:
context: RunContext | None context: RunContext | None
sessions: dict[str, str] sessions: dict[str, str]
topic_title: str | None topic_title: str | None
default_engine: str | None
class _ContextState(msgspec.Struct, forbid_unknown_fields=False): class _ContextState(msgspec.Struct, forbid_unknown_fields=False):
@@ -38,6 +39,7 @@ class _ThreadState(msgspec.Struct, forbid_unknown_fields=False):
context: _ContextState | None = None context: _ContextState | None = None
sessions: dict[str, _SessionState] = msgspec.field(default_factory=dict) sessions: dict[str, _SessionState] = msgspec.field(default_factory=dict)
topic_title: str | None = None topic_title: str | None = None
default_engine: str | None = None
class _TopicState(msgspec.Struct, forbid_unknown_fields=False): class _TopicState(msgspec.Struct, forbid_unknown_fields=False):
@@ -151,6 +153,27 @@ class TopicStateStore(JsonStateStore[_TopicState]):
return None return None
return ResumeToken(engine=engine, value=entry.resume) return ResumeToken(engine=engine, value=entry.resume)
async def get_default_engine(self, chat_id: int, thread_id: int) -> str | None:
async with self._lock:
self._reload_locked_if_needed()
thread = self._get_thread_locked(chat_id, thread_id)
if thread is None:
return None
return _normalize_text(thread.default_engine)
async def set_default_engine(
self, chat_id: int, thread_id: int, engine: str | None
) -> None:
normalized = _normalize_text(engine)
async with self._lock:
self._reload_locked_if_needed()
thread = self._ensure_thread_locked(chat_id, thread_id)
thread.default_engine = normalized
self._save_locked()
async def clear_default_engine(self, chat_id: int, thread_id: int) -> None:
await self.set_default_engine(chat_id, thread_id, None)
async def set_session_resume( async def set_session_resume(
self, chat_id: int, thread_id: int, token: ResumeToken self, chat_id: int, thread_id: int, token: ResumeToken
) -> None: ) -> None:
@@ -205,6 +228,7 @@ class TopicStateStore(JsonStateStore[_TopicState]):
context=_context_from_state(thread.context), context=_context_from_state(thread.context),
sessions=sessions, sessions=sessions,
topic_title=thread.topic_title, topic_title=thread.topic_title,
default_engine=_normalize_text(thread.default_engine),
) )
def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None: def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None:
+8 -11
View File
@@ -198,7 +198,6 @@ class TransportRuntime:
) )
engine_override = self._resolve_engine_override( engine_override = self._resolve_engine_override(
directives_engine=directives.engine, directives_engine=directives.engine,
context=context,
) )
return ResolvedMessage( return ResolvedMessage(
@@ -209,6 +208,14 @@ class TransportRuntime:
context_source=context_source, context_source=context_source,
) )
def project_default_engine(self, context: RunContext | None) -> EngineId | None:
if context is None or context.project is None:
return None
project = self._projects.projects.get(context.project)
if project is None:
return None
return project.default_engine
def _resolve_context( def _resolve_context(
self, self,
*, *,
@@ -253,19 +260,9 @@ class TransportRuntime:
self, self,
*, *,
directives_engine: EngineId | None, directives_engine: EngineId | None,
context: RunContext | None,
) -> EngineId | None: ) -> EngineId | None:
if directives_engine is not None: if directives_engine is not None:
return directives_engine return directives_engine
if context is None:
return None
project = (
self._projects.projects.get(context.project)
if context.project is not None
else None
)
if project is not None and project.default_engine is not None:
return project.default_engine
return None return None
@property @property
+80
View File
@@ -1445,6 +1445,86 @@ async def test_run_main_loop_persists_topic_sessions_in_project_scope(
assert stored == ResumeToken(engine=CODEX_ENGINE, value=resume_value) assert stored == ResumeToken(engine=CODEX_ENGINE, value=resume_value)
@pytest.mark.anyio
async def test_run_main_loop_auto_resumes_topic_default_engine(
tmp_path: Path,
) -> None:
state_path = tmp_path / "takopi.toml"
topic_path = resolve_state_path(state_path)
store = TopicStateStore(topic_path)
await store.set_session_resume(
123, 77, ResumeToken(engine=CODEX_ENGINE, value="resume-codex")
)
await store.set_session_resume(
123, 77, ResumeToken(engine=EngineId("claude"), value="resume-claude")
)
await store.set_default_engine(123, 77, "claude")
transport = _FakeTransport()
bot = _FakeBot()
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
claude_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("claude"))
router = AutoRouter(
entries=[
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
RunnerEntry(engine=claude_runner.engine, runner=claude_runner),
],
default_engine=codex_runner.engine,
)
projects = ProjectsConfig(
projects={
"proj": ProjectConfig(
alias="proj",
path=tmp_path,
worktrees_dir=Path(".worktrees"),
chat_id=123,
)
},
default_project=None,
chat_map={123: "proj"},
)
runtime = TransportRuntime(
router=router,
projects=projects,
config_path=state_path,
)
cfg = TelegramBridgeConfig(
bot=bot,
runtime=runtime,
chat_id=123,
startup_msg="",
exec_cfg=ExecBridgeConfig(
transport=transport,
presenter=MarkdownPresenter(),
final_notify=True,
),
topics=TelegramTopicsSettings(
enabled=True,
scope="main",
),
)
async def poller(_cfg: TelegramBridgeConfig):
yield TelegramIncomingMessage(
transport="telegram",
chat_id=123,
message_id=1,
text="hello",
reply_to_message_id=None,
reply_to_text=None,
sender_id=123,
thread_id=77,
)
await run_main_loop(cfg, poller)
assert codex_runner.calls == []
assert len(claude_runner.calls) == 1
assert claude_runner.calls[0][1] == ResumeToken(
engine=EngineId("claude"), value="resume-claude"
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_run_main_loop_auto_resumes_chat_sessions(tmp_path: Path) -> None: async def test_run_main_loop_auto_resumes_chat_sessions(tmp_path: Path) -> None:
resume_value = "resume-123" resume_value = "resume-123"
+20
View File
@@ -0,0 +1,20 @@
import pytest
from takopi.telegram.chat_prefs import ChatPrefsStore
@pytest.mark.anyio
async def test_chat_prefs_store_roundtrip(tmp_path) -> None:
path = tmp_path / "telegram_chat_prefs_state.json"
store = ChatPrefsStore(path)
await store.set_default_engine(123, "codex")
await store.set_default_engine(123, "codex")
await store.clear_default_engine(456)
assert await store.get_default_engine(123) == "codex"
store2 = ChatPrefsStore(path)
assert await store2.get_default_engine(123) == "codex"
await store2.clear_default_engine(123)
assert await store2.get_default_engine(123) is None
+78
View File
@@ -0,0 +1,78 @@
from pathlib import Path
import pytest
from takopi.config import ProjectConfig, ProjectsConfig
from takopi.context import RunContext
from takopi.model import EngineId
from takopi.router import AutoRouter, RunnerEntry
from takopi.runners.mock import Return, ScriptRunner
from takopi.telegram.chat_prefs import ChatPrefsStore
from takopi.telegram.engine_defaults import resolve_engine_for_message
from takopi.telegram.topic_state import TopicStateStore
from takopi.transport_runtime import TransportRuntime
@pytest.mark.anyio
async def test_resolve_engine_for_message_sources(tmp_path) -> None:
codex = ScriptRunner([Return(answer="ok")], engine=EngineId("codex"))
pi = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
router = AutoRouter(
entries=[
RunnerEntry(engine=codex.engine, runner=codex),
RunnerEntry(engine=pi.engine, runner=pi),
],
default_engine=codex.engine,
)
project = ProjectConfig(
alias="proj",
path=tmp_path,
worktrees_dir=Path(".worktrees"),
default_engine=pi.engine,
)
runtime = TransportRuntime(
router=router,
projects=ProjectsConfig(projects={"proj": project}, default_project=None),
)
chat_prefs = ChatPrefsStore(tmp_path / "telegram_chat_prefs_state.json")
topic_store = TopicStateStore(tmp_path / "telegram_topics_state.json")
await chat_prefs.set_default_engine(1, "pi")
await topic_store.set_default_engine(1, 10, "codex")
resolved = await resolve_engine_for_message(
runtime=runtime,
context=RunContext(project="proj"),
explicit_engine=EngineId("codex"),
chat_id=1,
topic_key=(1, 10),
topic_store=topic_store,
chat_prefs=chat_prefs,
)
assert resolved.source == "directive"
assert resolved.engine == "codex"
await topic_store.clear_default_engine(1, 10)
resolved = await resolve_engine_for_message(
runtime=runtime,
context=RunContext(project="proj"),
explicit_engine=None,
chat_id=1,
topic_key=(1, 10),
topic_store=topic_store,
chat_prefs=chat_prefs,
)
assert resolved.source == "chat_default"
assert resolved.engine == "pi"
await chat_prefs.clear_default_engine(1)
resolved = await resolve_engine_for_message(
runtime=runtime,
context=RunContext(project="proj"),
explicit_engine=None,
chat_id=1,
topic_key=(1, 10),
topic_store=topic_store,
chat_prefs=chat_prefs,
)
assert resolved.source == "project_default"
assert resolved.engine == "pi"
+7
View File
@@ -11,18 +11,21 @@ async def test_topic_state_store_roundtrip(tmp_path) -> None:
store = TopicStateStore(path) store = TopicStateStore(path)
context = RunContext(project="proj", branch="feat/topic") context = RunContext(project="proj", branch="feat/topic")
await store.set_context(1, 10, context) await store.set_context(1, 10, context)
await store.set_default_engine(1, 10, "claude")
await store.set_session_resume(1, 10, ResumeToken(engine="codex", value="abc123")) await store.set_session_resume(1, 10, ResumeToken(engine="codex", value="abc123"))
snapshot = await store.get_thread(1, 10) snapshot = await store.get_thread(1, 10)
assert snapshot is not None assert snapshot is not None
assert snapshot.context == context assert snapshot.context == context
assert snapshot.sessions == {"codex": "abc123"} assert snapshot.sessions == {"codex": "abc123"}
assert snapshot.default_engine == "claude"
store2 = TopicStateStore(path) store2 = TopicStateStore(path)
snapshot2 = await store2.get_thread(1, 10) snapshot2 = await store2.get_thread(1, 10)
assert snapshot2 is not None assert snapshot2 is not None
assert snapshot2.context == context assert snapshot2.context == context
assert snapshot2.sessions == {"codex": "abc123"} assert snapshot2.sessions == {"codex": "abc123"}
assert snapshot2.default_engine == "claude"
@pytest.mark.anyio @pytest.mark.anyio
@@ -47,3 +50,7 @@ async def test_topic_state_store_clear_and_find(tmp_path) -> None:
snapshot = await store.get_thread(2, 20) snapshot = await store.get_thread(2, 20)
assert snapshot is not None assert snapshot is not None
assert snapshot.context is None assert snapshot.context is None
await store.clear_default_engine(2, 20)
snapshot = await store.get_thread(2, 20)
assert snapshot is not None
assert snapshot.default_engine is None