feat(telegram): add per-chat/topic default agents (#109)
This commit is contained in:
@@ -2,3 +2,4 @@ after you finish work, commit with a conventional message. only commit the files
|
||||
always run `just check` before code commits.
|
||||
if you fix anything from `just check`, rerun it and confirm it passes before committing.
|
||||
when using gh to edit or create PR descriptions, prefer `--body-file` to preserve newlines.
|
||||
always include a "Manual testing" checklist section in PRs.
|
||||
|
||||
@@ -353,6 +353,15 @@ Decision (v0.4.0):
|
||||
* Users MAY prefix the first non-empty line with `/{engine}` (e.g. `/claude`, `/codex`, or `/pi`) to select the engine for a **new** thread.
|
||||
* The bridge MUST strip that directive from the prompt before invoking the runner.
|
||||
* If a ResumeToken is resolved from the message or reply, it MUST take precedence and the `/{engine}` directive MUST be ignored.
|
||||
* Bridges MAY persist default engine overrides per Telegram scope:
|
||||
* **Topic default**: forum topic (`chat_id + thread_id`)
|
||||
* **Chat default**: chat (`chat_id`)
|
||||
* When no ResumeToken is resolved, engine selection MUST follow this precedence:
|
||||
1) explicit `/{engine}` directive
|
||||
2) topic default (if any)
|
||||
3) chat default (if any)
|
||||
4) project default engine (if configured for the resolved context)
|
||||
5) global default engine
|
||||
|
||||
### 8.1 Command menu (Telegram)
|
||||
|
||||
|
||||
+19
-2
@@ -122,6 +122,22 @@ Prefix your message with an engine directive to override the default:
|
||||
|
||||
Directives are only parsed at the start of the first non-empty line.
|
||||
|
||||
### Default agent per chat or topic
|
||||
|
||||
Use `/agent` to view or set a persistent default for the current scope:
|
||||
|
||||
```
|
||||
/agent
|
||||
/agent set claude
|
||||
/agent clear
|
||||
```
|
||||
|
||||
- Inside a forum topic, `/agent set` affects that topic.
|
||||
- In normal chats, it affects the whole chat.
|
||||
- In group chats, only admins can change defaults.
|
||||
|
||||
Precedence (highest to lowest): resume token → `/engine` directive → topic default → chat default → project default → global default.
|
||||
|
||||
### Setting up engines
|
||||
|
||||
Takopi shells out to the agent CLIs. Install them and make sure they're on your `PATH`
|
||||
@@ -257,7 +273,7 @@ takopi chat-id
|
||||
|
||||
## 7. Topics
|
||||
|
||||
Topics bind Telegram forum threads to specific project/branch contexts. They also preserve resume tokens, so agents can pick up where they left off.
|
||||
Topics bind Telegram forum threads to specific project/branch contexts. They also preserve resume tokens and can store a default agent per topic.
|
||||
|
||||
### Enabling topics
|
||||
|
||||
@@ -341,7 +357,7 @@ path = "~/dev/takopi"
|
||||
chat_id = -1001111111111 # forum-enabled group
|
||||
```
|
||||
|
||||
Topic state is stored in `telegram_topics_state.json` next to your config file.
|
||||
Topic state is stored in `telegram_topics_state.json` next to your config file. Chat defaults live in `telegram_chat_prefs_state.json`.
|
||||
|
||||
---
|
||||
|
||||
@@ -496,6 +512,7 @@ worktree_base = "develop"
|
||||
| `/cancel` | Reply to the progress message to stop the current run |
|
||||
| `/file put <path>` | Upload a document into the repo/worktree |
|
||||
| `/file get <path>` | Fetch a file (directories are zipped) |
|
||||
| `/agent` | Show/set the default agent for the current scope |
|
||||
| `/topic <project> @branch` | Create/bind a topic |
|
||||
| `/ctx` | Show current context |
|
||||
| `/ctx set <project> @branch` | Update context binding |
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import msgspec
|
||||
|
||||
from ..logging import get_logger
|
||||
from .state_store import JsonStateStore
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
STATE_VERSION = 1
|
||||
STATE_FILENAME = "telegram_chat_prefs_state.json"
|
||||
|
||||
|
||||
class _ChatPrefs(msgspec.Struct, forbid_unknown_fields=False):
|
||||
default_engine: str | None = None
|
||||
|
||||
|
||||
class _ChatPrefsState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
version: int
|
||||
chats: dict[str, _ChatPrefs] = msgspec.field(default_factory=dict)
|
||||
|
||||
|
||||
def resolve_prefs_path(config_path: Path) -> Path:
|
||||
return config_path.with_name(STATE_FILENAME)
|
||||
|
||||
|
||||
def _chat_key(chat_id: int) -> str:
|
||||
return str(chat_id)
|
||||
|
||||
|
||||
def _normalize_text(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
value = value.strip()
|
||||
return value or None
|
||||
|
||||
|
||||
def _new_state() -> _ChatPrefsState:
|
||||
return _ChatPrefsState(version=STATE_VERSION, chats={})
|
||||
|
||||
|
||||
class ChatPrefsStore(JsonStateStore[_ChatPrefsState]):
|
||||
def __init__(self, path: Path) -> None:
|
||||
super().__init__(
|
||||
path,
|
||||
version=STATE_VERSION,
|
||||
state_type=_ChatPrefsState,
|
||||
state_factory=_new_state,
|
||||
log_prefix="telegram.chat_prefs",
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
async def get_default_engine(self, chat_id: int) -> str | None:
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
chat = self._get_chat_locked(chat_id)
|
||||
if chat is None:
|
||||
return None
|
||||
return _normalize_text(chat.default_engine)
|
||||
|
||||
async def set_default_engine(self, chat_id: int, engine: str | None) -> None:
|
||||
normalized = _normalize_text(engine)
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
if normalized is None:
|
||||
if self._remove_chat_locked(chat_id):
|
||||
self._save_locked()
|
||||
return
|
||||
chat = self._ensure_chat_locked(chat_id)
|
||||
chat.default_engine = normalized
|
||||
self._save_locked()
|
||||
|
||||
async def clear_default_engine(self, chat_id: int) -> None:
|
||||
await self.set_default_engine(chat_id, None)
|
||||
|
||||
def _get_chat_locked(self, chat_id: int) -> _ChatPrefs | None:
|
||||
return self._state.chats.get(_chat_key(chat_id))
|
||||
|
||||
def _ensure_chat_locked(self, chat_id: int) -> _ChatPrefs:
|
||||
key = _chat_key(chat_id)
|
||||
entry = self._state.chats.get(key)
|
||||
if entry is not None:
|
||||
return entry
|
||||
entry = _ChatPrefs()
|
||||
self._state.chats[key] = entry
|
||||
return entry
|
||||
|
||||
def _remove_chat_locked(self, chat_id: int) -> bool:
|
||||
key = _chat_key(chat_id)
|
||||
if key not in self._state.chats:
|
||||
return False
|
||||
del self._state.chats[key]
|
||||
return True
|
||||
@@ -38,6 +38,7 @@ from ..transport import MessageRef, RenderedMessage, SendOptions
|
||||
from ..transport_runtime import ResolvedMessage, TransportRuntime
|
||||
from ..utils.paths import reset_run_base_dir, set_run_base_dir
|
||||
from .bridge import TelegramBridgeConfig, send_plain
|
||||
from .chat_prefs import ChatPrefsStore
|
||||
from .chat_sessions import ChatSessionStore
|
||||
from .context import (
|
||||
_format_context,
|
||||
@@ -47,6 +48,7 @@ from .context import (
|
||||
_usage_ctx_set,
|
||||
_usage_topic,
|
||||
)
|
||||
from .engine_defaults import resolve_engine_for_message
|
||||
from .files import (
|
||||
default_upload_name,
|
||||
default_upload_path,
|
||||
@@ -79,6 +81,7 @@ __all__ = [
|
||||
"FILE_GET_USAGE",
|
||||
"FILE_PUT_USAGE",
|
||||
"_dispatch_command",
|
||||
"_handle_agent_command",
|
||||
"_handle_chat_new_command",
|
||||
"_handle_file_command",
|
||||
"_handle_file_get",
|
||||
@@ -97,6 +100,7 @@ __all__ = [
|
||||
_MAX_BOT_COMMANDS = 100
|
||||
FILE_PUT_USAGE = "usage: `/file put <path>`"
|
||||
FILE_GET_USAGE = "usage: `/file get <path>`"
|
||||
AGENT_USAGE = "usage: `/agent`, `/agent set <engine>`, or `/agent clear`"
|
||||
|
||||
|
||||
def is_cancel_command(text: str) -> bool:
|
||||
@@ -362,6 +366,29 @@ async def _check_file_permissions(
|
||||
return False
|
||||
|
||||
|
||||
async def _check_agent_permissions(
|
||||
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||
) -> bool:
|
||||
reply = _reply_sender(cfg, msg)
|
||||
sender_id = msg.sender_id
|
||||
if sender_id is None:
|
||||
await reply(text="cannot verify sender for agent defaults.")
|
||||
return False
|
||||
is_private = msg.chat_type == "private"
|
||||
if msg.chat_type is None:
|
||||
is_private = msg.chat_id > 0
|
||||
if is_private:
|
||||
return True
|
||||
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
|
||||
if member is None:
|
||||
await reply(text="failed to verify agent permissions.")
|
||||
return False
|
||||
if member.status in {"creator", "administrator"}:
|
||||
return True
|
||||
await reply(text="changing default agents is restricted to group admins.")
|
||||
return False
|
||||
|
||||
|
||||
async def _prepare_file_put_plan(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
@@ -1034,6 +1061,125 @@ async def _handle_ctx_command(
|
||||
)
|
||||
|
||||
|
||||
async def _handle_agent_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
chat_prefs: ChatPrefsStore | None,
|
||||
*,
|
||||
resolved_scope: str | None = None,
|
||||
scope_chat_ids: frozenset[int] | None = None,
|
||||
) -> None:
|
||||
reply = _reply_sender(cfg, msg)
|
||||
tkey = (
|
||||
_topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
|
||||
if topic_store is not None
|
||||
else None
|
||||
)
|
||||
tokens = split_command_args(args_text)
|
||||
action = tokens[0].lower() if tokens else "show"
|
||||
|
||||
if action in {"show", ""}:
|
||||
try:
|
||||
resolved = cfg.runtime.resolve_message(
|
||||
text="",
|
||||
reply_text=msg.reply_to_text,
|
||||
ambient_context=ambient_context,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
except DirectiveError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return
|
||||
selection = await resolve_engine_for_message(
|
||||
runtime=cfg.runtime,
|
||||
context=resolved.context,
|
||||
explicit_engine=None,
|
||||
chat_id=msg.chat_id,
|
||||
topic_key=tkey,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
)
|
||||
source_labels = {
|
||||
"directive": "directive",
|
||||
"topic_default": "topic default",
|
||||
"chat_default": "chat default",
|
||||
"project_default": "project default",
|
||||
"global_default": "global default",
|
||||
}
|
||||
agent_line = f"agent: {selection.engine} ({source_labels[selection.source]})"
|
||||
topic_default = selection.topic_default or "none"
|
||||
if tkey is None:
|
||||
topic_default = "none"
|
||||
if chat_prefs is None:
|
||||
chat_default = "unavailable"
|
||||
else:
|
||||
chat_default = selection.chat_default or "none"
|
||||
project_default = (
|
||||
selection.project_default
|
||||
if selection.project_default is not None
|
||||
else "none"
|
||||
)
|
||||
defaults_line = (
|
||||
"defaults: "
|
||||
f"topic: {topic_default}, "
|
||||
f"chat: {chat_default}, "
|
||||
f"project: {project_default}, "
|
||||
f"global: {cfg.runtime.default_engine}"
|
||||
)
|
||||
available = ", ".join(cfg.runtime.engine_ids)
|
||||
available_line = f"available: {available}"
|
||||
await reply(text="\n\n".join([agent_line, defaults_line, available_line]))
|
||||
return
|
||||
|
||||
if action == "set":
|
||||
if len(tokens) < 2:
|
||||
await reply(text=AGENT_USAGE)
|
||||
return
|
||||
if not await _check_agent_permissions(cfg, msg):
|
||||
return
|
||||
engine = tokens[1].strip().lower()
|
||||
if engine not in cfg.runtime.engine_ids:
|
||||
available = ", ".join(cfg.runtime.engine_ids)
|
||||
await reply(
|
||||
text=f"unknown engine `{engine}`.\navailable agents: `{available}`",
|
||||
)
|
||||
return
|
||||
if tkey is not None:
|
||||
if topic_store is None:
|
||||
await reply(text="topic defaults are unavailable.")
|
||||
return
|
||||
await topic_store.set_default_engine(tkey[0], tkey[1], engine)
|
||||
await reply(text=f"topic default agent set to `{engine}`")
|
||||
return
|
||||
if chat_prefs is None:
|
||||
await reply(text="chat defaults are unavailable (no config path).")
|
||||
return
|
||||
await chat_prefs.set_default_engine(msg.chat_id, engine)
|
||||
await reply(text=f"chat default agent set to `{engine}`")
|
||||
return
|
||||
|
||||
if action == "clear":
|
||||
if not await _check_agent_permissions(cfg, msg):
|
||||
return
|
||||
if tkey is not None:
|
||||
if topic_store is None:
|
||||
await reply(text="topic defaults are unavailable.")
|
||||
return
|
||||
await topic_store.clear_default_engine(tkey[0], tkey[1])
|
||||
await reply(text="topic default agent cleared.")
|
||||
return
|
||||
if chat_prefs is None:
|
||||
await reply(text="chat defaults are unavailable (no config path).")
|
||||
return
|
||||
await chat_prefs.clear_default_engine(msg.chat_id)
|
||||
await reply(text="chat default agent cleared.")
|
||||
return
|
||||
|
||||
await reply(text=AGENT_USAGE)
|
||||
|
||||
|
||||
async def _handle_new_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
@@ -1369,6 +1515,7 @@ class _TelegramCommandExecutor(CommandExecutor):
|
||||
thread_id: int | None,
|
||||
show_resume_line: bool,
|
||||
stateful_mode: bool,
|
||||
default_engine_override: EngineId | None,
|
||||
) -> None:
|
||||
self._exec_cfg = exec_cfg
|
||||
self._runtime = runtime
|
||||
@@ -1380,6 +1527,7 @@ class _TelegramCommandExecutor(CommandExecutor):
|
||||
self._thread_id = thread_id
|
||||
self._show_resume_line = show_resume_line
|
||||
self._stateful_mode = stateful_mode
|
||||
self._default_engine_override = default_engine_override
|
||||
self._reply_ref = MessageRef(
|
||||
channel_id=chat_id,
|
||||
message_id=user_msg_id,
|
||||
@@ -1398,6 +1546,15 @@ class _TelegramCommandExecutor(CommandExecutor):
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _apply_default_engine(self, request: RunRequest) -> RunRequest:
|
||||
if request.engine is not None or self._default_engine_override is None:
|
||||
return request
|
||||
return RunRequest(
|
||||
prompt=request.prompt,
|
||||
engine=self._default_engine_override,
|
||||
context=request.context,
|
||||
)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
message: RenderedMessage | str,
|
||||
@@ -1425,6 +1582,7 @@ class _TelegramCommandExecutor(CommandExecutor):
|
||||
self, request: RunRequest, *, mode: RunMode = "emit"
|
||||
) -> RunResult:
|
||||
request = self._apply_default_context(request)
|
||||
request = self._apply_default_engine(request)
|
||||
effective_show_resume_line = _should_show_resume_line(
|
||||
show_resume_line=self._show_resume_line,
|
||||
stateful_mode=self._stateful_mode,
|
||||
@@ -1511,6 +1669,7 @@ async def _dispatch_command(
|
||||
scheduler: ThreadScheduler,
|
||||
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
|
||||
stateful_mode: bool,
|
||||
default_engine_override: EngineId | None,
|
||||
) -> None:
|
||||
allowlist = cfg.runtime.allowlist
|
||||
chat_id = msg.chat_id
|
||||
@@ -1535,6 +1694,7 @@ async def _dispatch_command(
|
||||
thread_id=msg.thread_id,
|
||||
show_resume_line=cfg.show_resume_line,
|
||||
stateful_mode=stateful_mode,
|
||||
default_engine_override=default_engine_override,
|
||||
)
|
||||
message_ref = MessageRef(
|
||||
channel_id=chat_id,
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from ..context import RunContext
|
||||
from ..model import EngineId
|
||||
from ..transport_runtime import TransportRuntime
|
||||
from .chat_prefs import ChatPrefsStore
|
||||
from .topic_state import TopicStateStore
|
||||
|
||||
EngineSource = Literal[
|
||||
"directive",
|
||||
"topic_default",
|
||||
"chat_default",
|
||||
"project_default",
|
||||
"global_default",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineResolution:
|
||||
engine: EngineId
|
||||
source: EngineSource
|
||||
topic_default: EngineId | None
|
||||
chat_default: EngineId | None
|
||||
project_default: EngineId | None
|
||||
|
||||
|
||||
async def resolve_engine_for_message(
|
||||
*,
|
||||
runtime: TransportRuntime,
|
||||
context: RunContext | None,
|
||||
explicit_engine: EngineId | None,
|
||||
chat_id: int,
|
||||
topic_key: tuple[int, int] | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
chat_prefs: ChatPrefsStore | None,
|
||||
) -> EngineResolution:
|
||||
topic_default = None
|
||||
if topic_store is not None and topic_key is not None:
|
||||
topic_default = await topic_store.get_default_engine(*topic_key)
|
||||
chat_default = None
|
||||
if chat_prefs is not None:
|
||||
chat_default = await chat_prefs.get_default_engine(chat_id)
|
||||
project_default = runtime.project_default_engine(context)
|
||||
|
||||
if explicit_engine is not None:
|
||||
return EngineResolution(
|
||||
engine=explicit_engine,
|
||||
source="directive",
|
||||
topic_default=topic_default,
|
||||
chat_default=chat_default,
|
||||
project_default=project_default,
|
||||
)
|
||||
if topic_default is not None:
|
||||
return EngineResolution(
|
||||
engine=topic_default,
|
||||
source="topic_default",
|
||||
topic_default=topic_default,
|
||||
chat_default=chat_default,
|
||||
project_default=project_default,
|
||||
)
|
||||
if chat_default is not None:
|
||||
return EngineResolution(
|
||||
engine=chat_default,
|
||||
source="chat_default",
|
||||
topic_default=topic_default,
|
||||
chat_default=chat_default,
|
||||
project_default=project_default,
|
||||
)
|
||||
if project_default is not None:
|
||||
return EngineResolution(
|
||||
engine=project_default,
|
||||
source="project_default",
|
||||
topic_default=topic_default,
|
||||
chat_default=chat_default,
|
||||
project_default=project_default,
|
||||
)
|
||||
return EngineResolution(
|
||||
engine=runtime.default_engine,
|
||||
source="global_default",
|
||||
topic_default=topic_default,
|
||||
chat_default=chat_default,
|
||||
project_default=project_default,
|
||||
)
|
||||
+73
-18
@@ -23,6 +23,7 @@ from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
|
||||
from .commands import (
|
||||
FILE_PUT_USAGE,
|
||||
_dispatch_command,
|
||||
_handle_agent_command,
|
||||
_handle_chat_new_command,
|
||||
_handle_ctx_command,
|
||||
_handle_file_command,
|
||||
@@ -50,7 +51,9 @@ from .topics import (
|
||||
_validate_topics_setup,
|
||||
)
|
||||
from .client import poll_incoming
|
||||
from .chat_prefs import ChatPrefsStore, resolve_prefs_path
|
||||
from .chat_sessions import ChatSessionStore, resolve_sessions_path
|
||||
from .engine_defaults import resolve_engine_for_message
|
||||
from .topic_state import TopicStateStore, resolve_state_path
|
||||
from .types import (
|
||||
TelegramCallbackQuery,
|
||||
@@ -110,6 +113,7 @@ def _dispatch_builtin_command(
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
chat_prefs: ChatPrefsStore | None,
|
||||
resolved_scope: str | None,
|
||||
scope_chat_ids: frozenset[int],
|
||||
reply: Callable[..., Awaitable[None]],
|
||||
@@ -165,6 +169,19 @@ def _dispatch_builtin_command(
|
||||
}
|
||||
)
|
||||
|
||||
if command_id == "agent":
|
||||
handlers["agent"] = partial(
|
||||
_handle_agent_command,
|
||||
cfg,
|
||||
msg,
|
||||
args_text,
|
||||
ambient_context,
|
||||
topic_store,
|
||||
chat_prefs,
|
||||
resolved_scope=resolved_scope,
|
||||
scope_chat_ids=scope_chat_ids,
|
||||
)
|
||||
|
||||
handler = handlers.get(command_id)
|
||||
if handler is None:
|
||||
return False
|
||||
@@ -311,6 +328,7 @@ async def run_main_loop(
|
||||
)
|
||||
topic_store: TopicStateStore | None = None
|
||||
chat_session_store: ChatSessionStore | None = None
|
||||
chat_prefs: ChatPrefsStore | None = None
|
||||
media_groups: dict[tuple[int, str], _MediaGroupState] = {}
|
||||
resolved_topics_scope: str | None = None
|
||||
topics_chat_ids: frozenset[int] = frozenset()
|
||||
@@ -333,6 +351,12 @@ async def run_main_loop(
|
||||
|
||||
try:
|
||||
config_path = cfg.runtime.config_path
|
||||
if config_path is not None:
|
||||
chat_prefs = ChatPrefsStore(resolve_prefs_path(config_path))
|
||||
logger.info(
|
||||
"chat_prefs.enabled",
|
||||
state_path=str(resolve_prefs_path(config_path)),
|
||||
)
|
||||
if cfg.session_mode == "chat":
|
||||
if config_path is None:
|
||||
raise ConfigError(
|
||||
@@ -552,6 +576,23 @@ async def run_main_loop(
|
||||
return None
|
||||
return resolved
|
||||
|
||||
async def resolve_engine_defaults(
|
||||
*,
|
||||
explicit_engine: EngineId | None,
|
||||
context: RunContext | None,
|
||||
chat_id: int,
|
||||
topic_key: tuple[int, int] | None,
|
||||
):
|
||||
return await resolve_engine_for_message(
|
||||
runtime=cfg.runtime,
|
||||
context=context,
|
||||
explicit_engine=explicit_engine,
|
||||
chat_id=chat_id,
|
||||
topic_key=topic_key,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
)
|
||||
|
||||
async def run_prompt_from_upload(
|
||||
msg: TelegramIncomingMessage,
|
||||
prompt_text: str,
|
||||
@@ -570,7 +611,6 @@ async def run_main_loop(
|
||||
else None
|
||||
)
|
||||
resume_token = resolved.resume_token
|
||||
engine_override = resolved.engine_override
|
||||
context = resolved.context
|
||||
chat_session_key = _chat_session_key(msg, store=chat_session_store)
|
||||
topic_key = (
|
||||
@@ -578,6 +618,13 @@ async def run_main_loop(
|
||||
if topic_store is not None
|
||||
else None
|
||||
)
|
||||
engine_resolution = await resolve_engine_defaults(
|
||||
explicit_engine=resolved.engine_override,
|
||||
context=context,
|
||||
chat_id=chat_id,
|
||||
topic_key=topic_key,
|
||||
)
|
||||
engine_override = engine_resolution.engine
|
||||
if resume_token is None and reply_id is not None:
|
||||
running_task = running_tasks.get(
|
||||
MessageRef(channel_id=chat_id, message_id=reply_id)
|
||||
@@ -600,10 +647,7 @@ async def run_main_loop(
|
||||
and topic_store is not None
|
||||
and topic_key is not None
|
||||
):
|
||||
engine_for_session = cfg.runtime.resolve_engine(
|
||||
engine_override=engine_override,
|
||||
context=context,
|
||||
)
|
||||
engine_for_session = engine_resolution.engine
|
||||
stored = await topic_store.get_session_resume(
|
||||
topic_key[0], topic_key[1], engine_for_session
|
||||
)
|
||||
@@ -614,10 +658,7 @@ async def run_main_loop(
|
||||
and chat_session_store is not None
|
||||
and chat_session_key is not None
|
||||
):
|
||||
engine_for_session = cfg.runtime.resolve_engine(
|
||||
engine_override=engine_override,
|
||||
context=context,
|
||||
)
|
||||
engine_for_session = engine_resolution.engine
|
||||
stored = await chat_session_store.get_session_resume(
|
||||
chat_session_key[0],
|
||||
chat_session_key[1],
|
||||
@@ -815,6 +856,7 @@ async def run_main_loop(
|
||||
args_text=args_text,
|
||||
ambient_context=ambient_context,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
resolved_scope=resolved_topics_scope,
|
||||
scope_chat_ids=topics_chat_ids,
|
||||
reply=reply,
|
||||
@@ -853,6 +895,18 @@ async def run_main_loop(
|
||||
if command_id not in command_ids:
|
||||
refresh_commands()
|
||||
if command_id in command_ids:
|
||||
engine_resolution = await resolve_engine_defaults(
|
||||
explicit_engine=None,
|
||||
context=ambient_context,
|
||||
chat_id=chat_id,
|
||||
topic_key=topic_key,
|
||||
)
|
||||
default_engine_override = (
|
||||
engine_resolution.engine
|
||||
if engine_resolution.source
|
||||
in {"directive", "topic_default", "chat_default"}
|
||||
else None
|
||||
)
|
||||
tg.start_soon(
|
||||
_dispatch_command,
|
||||
cfg,
|
||||
@@ -868,6 +922,7 @@ async def run_main_loop(
|
||||
chat_session_key,
|
||||
),
|
||||
stateful_mode,
|
||||
default_engine_override,
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -885,8 +940,14 @@ async def run_main_loop(
|
||||
|
||||
text = resolved.prompt
|
||||
resume_token = resolved.resume_token
|
||||
engine_override = resolved.engine_override
|
||||
context = resolved.context
|
||||
engine_resolution = await resolve_engine_defaults(
|
||||
explicit_engine=resolved.engine_override,
|
||||
context=context,
|
||||
chat_id=chat_id,
|
||||
topic_key=topic_key,
|
||||
)
|
||||
engine_override = engine_resolution.engine
|
||||
if (
|
||||
topic_store is not None
|
||||
and topic_key is not None
|
||||
@@ -936,10 +997,7 @@ async def run_main_loop(
|
||||
and topic_store is not None
|
||||
and topic_key is not None
|
||||
):
|
||||
engine_for_session = cfg.runtime.resolve_engine(
|
||||
engine_override=engine_override,
|
||||
context=context,
|
||||
)
|
||||
engine_for_session = engine_resolution.engine
|
||||
stored = await topic_store.get_session_resume(
|
||||
topic_key[0], topic_key[1], engine_for_session
|
||||
)
|
||||
@@ -950,10 +1008,7 @@ async def run_main_loop(
|
||||
and chat_session_store is not None
|
||||
and chat_session_key is not None
|
||||
):
|
||||
engine_for_session = cfg.runtime.resolve_engine(
|
||||
engine_override=engine_override,
|
||||
context=context,
|
||||
)
|
||||
engine_for_session = engine_resolution.engine
|
||||
stored = await chat_session_store.get_session_resume(
|
||||
chat_session_key[0], chat_session_key[1], engine_for_session
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ class TopicThreadSnapshot:
|
||||
context: RunContext | None
|
||||
sessions: dict[str, str]
|
||||
topic_title: str | None
|
||||
default_engine: str | None
|
||||
|
||||
|
||||
class _ContextState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
@@ -38,6 +39,7 @@ class _ThreadState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
context: _ContextState | None = None
|
||||
sessions: dict[str, _SessionState] = msgspec.field(default_factory=dict)
|
||||
topic_title: str | None = None
|
||||
default_engine: str | None = None
|
||||
|
||||
|
||||
class _TopicState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
@@ -151,6 +153,27 @@ class TopicStateStore(JsonStateStore[_TopicState]):
|
||||
return None
|
||||
return ResumeToken(engine=engine, value=entry.resume)
|
||||
|
||||
async def get_default_engine(self, chat_id: int, thread_id: int) -> str | None:
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
thread = self._get_thread_locked(chat_id, thread_id)
|
||||
if thread is None:
|
||||
return None
|
||||
return _normalize_text(thread.default_engine)
|
||||
|
||||
async def set_default_engine(
|
||||
self, chat_id: int, thread_id: int, engine: str | None
|
||||
) -> None:
|
||||
normalized = _normalize_text(engine)
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
thread = self._ensure_thread_locked(chat_id, thread_id)
|
||||
thread.default_engine = normalized
|
||||
self._save_locked()
|
||||
|
||||
async def clear_default_engine(self, chat_id: int, thread_id: int) -> None:
|
||||
await self.set_default_engine(chat_id, thread_id, None)
|
||||
|
||||
async def set_session_resume(
|
||||
self, chat_id: int, thread_id: int, token: ResumeToken
|
||||
) -> None:
|
||||
@@ -205,6 +228,7 @@ class TopicStateStore(JsonStateStore[_TopicState]):
|
||||
context=_context_from_state(thread.context),
|
||||
sessions=sessions,
|
||||
topic_title=thread.topic_title,
|
||||
default_engine=_normalize_text(thread.default_engine),
|
||||
)
|
||||
|
||||
def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None:
|
||||
|
||||
@@ -198,7 +198,6 @@ class TransportRuntime:
|
||||
)
|
||||
engine_override = self._resolve_engine_override(
|
||||
directives_engine=directives.engine,
|
||||
context=context,
|
||||
)
|
||||
|
||||
return ResolvedMessage(
|
||||
@@ -209,6 +208,14 @@ class TransportRuntime:
|
||||
context_source=context_source,
|
||||
)
|
||||
|
||||
def project_default_engine(self, context: RunContext | None) -> EngineId | None:
|
||||
if context is None or context.project is None:
|
||||
return None
|
||||
project = self._projects.projects.get(context.project)
|
||||
if project is None:
|
||||
return None
|
||||
return project.default_engine
|
||||
|
||||
def _resolve_context(
|
||||
self,
|
||||
*,
|
||||
@@ -253,19 +260,9 @@ class TransportRuntime:
|
||||
self,
|
||||
*,
|
||||
directives_engine: EngineId | None,
|
||||
context: RunContext | None,
|
||||
) -> EngineId | None:
|
||||
if directives_engine is not None:
|
||||
return directives_engine
|
||||
if context is None:
|
||||
return None
|
||||
project = (
|
||||
self._projects.projects.get(context.project)
|
||||
if context.project is not None
|
||||
else None
|
||||
)
|
||||
if project is not None and project.default_engine is not None:
|
||||
return project.default_engine
|
||||
return None
|
||||
|
||||
@property
|
||||
|
||||
@@ -1445,6 +1445,86 @@ async def test_run_main_loop_persists_topic_sessions_in_project_scope(
|
||||
assert stored == ResumeToken(engine=CODEX_ENGINE, value=resume_value)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_main_loop_auto_resumes_topic_default_engine(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
state_path = tmp_path / "takopi.toml"
|
||||
topic_path = resolve_state_path(state_path)
|
||||
store = TopicStateStore(topic_path)
|
||||
await store.set_session_resume(
|
||||
123, 77, ResumeToken(engine=CODEX_ENGINE, value="resume-codex")
|
||||
)
|
||||
await store.set_session_resume(
|
||||
123, 77, ResumeToken(engine=EngineId("claude"), value="resume-claude")
|
||||
)
|
||||
await store.set_default_engine(123, 77, "claude")
|
||||
|
||||
transport = _FakeTransport()
|
||||
bot = _FakeBot()
|
||||
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
|
||||
claude_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("claude"))
|
||||
router = AutoRouter(
|
||||
entries=[
|
||||
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
|
||||
RunnerEntry(engine=claude_runner.engine, runner=claude_runner),
|
||||
],
|
||||
default_engine=codex_runner.engine,
|
||||
)
|
||||
projects = ProjectsConfig(
|
||||
projects={
|
||||
"proj": ProjectConfig(
|
||||
alias="proj",
|
||||
path=tmp_path,
|
||||
worktrees_dir=Path(".worktrees"),
|
||||
chat_id=123,
|
||||
)
|
||||
},
|
||||
default_project=None,
|
||||
chat_map={123: "proj"},
|
||||
)
|
||||
runtime = TransportRuntime(
|
||||
router=router,
|
||||
projects=projects,
|
||||
config_path=state_path,
|
||||
)
|
||||
cfg = TelegramBridgeConfig(
|
||||
bot=bot,
|
||||
runtime=runtime,
|
||||
chat_id=123,
|
||||
startup_msg="",
|
||||
exec_cfg=ExecBridgeConfig(
|
||||
transport=transport,
|
||||
presenter=MarkdownPresenter(),
|
||||
final_notify=True,
|
||||
),
|
||||
topics=TelegramTopicsSettings(
|
||||
enabled=True,
|
||||
scope="main",
|
||||
),
|
||||
)
|
||||
|
||||
async def poller(_cfg: TelegramBridgeConfig):
|
||||
yield TelegramIncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=1,
|
||||
text="hello",
|
||||
reply_to_message_id=None,
|
||||
reply_to_text=None,
|
||||
sender_id=123,
|
||||
thread_id=77,
|
||||
)
|
||||
|
||||
await run_main_loop(cfg, poller)
|
||||
|
||||
assert codex_runner.calls == []
|
||||
assert len(claude_runner.calls) == 1
|
||||
assert claude_runner.calls[0][1] == ResumeToken(
|
||||
engine=EngineId("claude"), value="resume-claude"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_main_loop_auto_resumes_chat_sessions(tmp_path: Path) -> None:
|
||||
resume_value = "resume-123"
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
import pytest
|
||||
|
||||
from takopi.telegram.chat_prefs import ChatPrefsStore
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_chat_prefs_store_roundtrip(tmp_path) -> None:
|
||||
path = tmp_path / "telegram_chat_prefs_state.json"
|
||||
store = ChatPrefsStore(path)
|
||||
await store.set_default_engine(123, "codex")
|
||||
await store.set_default_engine(123, "codex")
|
||||
await store.clear_default_engine(456)
|
||||
|
||||
assert await store.get_default_engine(123) == "codex"
|
||||
|
||||
store2 = ChatPrefsStore(path)
|
||||
assert await store2.get_default_engine(123) == "codex"
|
||||
|
||||
await store2.clear_default_engine(123)
|
||||
assert await store2.get_default_engine(123) is None
|
||||
@@ -0,0 +1,78 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from takopi.config import ProjectConfig, ProjectsConfig
|
||||
from takopi.context import RunContext
|
||||
from takopi.model import EngineId
|
||||
from takopi.router import AutoRouter, RunnerEntry
|
||||
from takopi.runners.mock import Return, ScriptRunner
|
||||
from takopi.telegram.chat_prefs import ChatPrefsStore
|
||||
from takopi.telegram.engine_defaults import resolve_engine_for_message
|
||||
from takopi.telegram.topic_state import TopicStateStore
|
||||
from takopi.transport_runtime import TransportRuntime
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_resolve_engine_for_message_sources(tmp_path) -> None:
|
||||
codex = ScriptRunner([Return(answer="ok")], engine=EngineId("codex"))
|
||||
pi = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
|
||||
router = AutoRouter(
|
||||
entries=[
|
||||
RunnerEntry(engine=codex.engine, runner=codex),
|
||||
RunnerEntry(engine=pi.engine, runner=pi),
|
||||
],
|
||||
default_engine=codex.engine,
|
||||
)
|
||||
project = ProjectConfig(
|
||||
alias="proj",
|
||||
path=tmp_path,
|
||||
worktrees_dir=Path(".worktrees"),
|
||||
default_engine=pi.engine,
|
||||
)
|
||||
runtime = TransportRuntime(
|
||||
router=router,
|
||||
projects=ProjectsConfig(projects={"proj": project}, default_project=None),
|
||||
)
|
||||
chat_prefs = ChatPrefsStore(tmp_path / "telegram_chat_prefs_state.json")
|
||||
topic_store = TopicStateStore(tmp_path / "telegram_topics_state.json")
|
||||
await chat_prefs.set_default_engine(1, "pi")
|
||||
await topic_store.set_default_engine(1, 10, "codex")
|
||||
|
||||
resolved = await resolve_engine_for_message(
|
||||
runtime=runtime,
|
||||
context=RunContext(project="proj"),
|
||||
explicit_engine=EngineId("codex"),
|
||||
chat_id=1,
|
||||
topic_key=(1, 10),
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
)
|
||||
assert resolved.source == "directive"
|
||||
assert resolved.engine == "codex"
|
||||
|
||||
await topic_store.clear_default_engine(1, 10)
|
||||
resolved = await resolve_engine_for_message(
|
||||
runtime=runtime,
|
||||
context=RunContext(project="proj"),
|
||||
explicit_engine=None,
|
||||
chat_id=1,
|
||||
topic_key=(1, 10),
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
)
|
||||
assert resolved.source == "chat_default"
|
||||
assert resolved.engine == "pi"
|
||||
|
||||
await chat_prefs.clear_default_engine(1)
|
||||
resolved = await resolve_engine_for_message(
|
||||
runtime=runtime,
|
||||
context=RunContext(project="proj"),
|
||||
explicit_engine=None,
|
||||
chat_id=1,
|
||||
topic_key=(1, 10),
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
)
|
||||
assert resolved.source == "project_default"
|
||||
assert resolved.engine == "pi"
|
||||
@@ -11,18 +11,21 @@ async def test_topic_state_store_roundtrip(tmp_path) -> None:
|
||||
store = TopicStateStore(path)
|
||||
context = RunContext(project="proj", branch="feat/topic")
|
||||
await store.set_context(1, 10, context)
|
||||
await store.set_default_engine(1, 10, "claude")
|
||||
await store.set_session_resume(1, 10, ResumeToken(engine="codex", value="abc123"))
|
||||
|
||||
snapshot = await store.get_thread(1, 10)
|
||||
assert snapshot is not None
|
||||
assert snapshot.context == context
|
||||
assert snapshot.sessions == {"codex": "abc123"}
|
||||
assert snapshot.default_engine == "claude"
|
||||
|
||||
store2 = TopicStateStore(path)
|
||||
snapshot2 = await store2.get_thread(1, 10)
|
||||
assert snapshot2 is not None
|
||||
assert snapshot2.context == context
|
||||
assert snapshot2.sessions == {"codex": "abc123"}
|
||||
assert snapshot2.default_engine == "claude"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -47,3 +50,7 @@ async def test_topic_state_store_clear_and_find(tmp_path) -> None:
|
||||
snapshot = await store.get_thread(2, 20)
|
||||
assert snapshot is not None
|
||||
assert snapshot.context is None
|
||||
await store.clear_default_engine(2, 20)
|
||||
snapshot = await store.get_thread(2, 20)
|
||||
assert snapshot is not None
|
||||
assert snapshot.default_engine is None
|
||||
|
||||
Reference in New Issue
Block a user