feat: add telegram /model and /reasoning overrides (#147)
This commit is contained in:
@@ -36,6 +36,8 @@ This line is parsed from replies and takes precedence over new directives.
|
||||
|---------|-------------|
|
||||
| `/cancel` | Reply to the progress message to stop the current run. |
|
||||
| `/agent` | Show/set the default agent for the current scope. |
|
||||
| `/model` | Show/set the model override for the current scope. |
|
||||
| `/reasoning` | Show/set the reasoning override for the current scope. |
|
||||
| `/file put <path>` | Upload a document into the repo/worktree (requires file transfer enabled). |
|
||||
| `/file get <path>` | Fetch a file or directory back into Telegram. |
|
||||
| `/topic <project> @branch` | Create/bind a topic (topics enabled). |
|
||||
|
||||
@@ -58,7 +58,7 @@ Explicit invocation includes any of:
|
||||
- `@botname` mention in the message.
|
||||
- `/engine` or `/project_alias` as the first token.
|
||||
- Replying to a bot message.
|
||||
- Built-in or plugin slash commands (for example `/agent`, `/file`, `/trigger`).
|
||||
- Built-in or plugin slash commands (for example `/agent`, `/model`, `/reasoning`, `/file`, `/trigger`).
|
||||
|
||||
Commands:
|
||||
|
||||
|
||||
+1
-1
@@ -7,7 +7,7 @@ _ID_RE = re.compile(ID_PATTERN)
|
||||
|
||||
RESERVED_CLI_COMMANDS = frozenset({"init", "plugins", "doctor"})
|
||||
RESERVED_CHAT_COMMANDS = frozenset(
|
||||
{"cancel", "file", "new", "agent", "trigger", "topic", "ctx"}
|
||||
{"cancel", "file", "new", "agent", "model", "reasoning", "trigger", "topic", "ctx"}
|
||||
)
|
||||
RESERVED_ENGINE_IDS = RESERVED_CLI_COMMANDS | RESERVED_CHAT_COMMANDS
|
||||
RESERVED_COMMAND_IDS = RESERVED_CLI_COMMANDS | RESERVED_CHAT_COMMANDS
|
||||
|
||||
@@ -14,6 +14,7 @@ from ..events import EventFactory
|
||||
from ..logging import get_logger
|
||||
from ..model import Action, ActionKind, EngineId, ResumeToken, TakopiEvent
|
||||
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
||||
from .run_options import get_run_options
|
||||
from ..schemas import claude as claude_schema
|
||||
from .tool_actions import tool_input_path, tool_kind_and_title
|
||||
|
||||
@@ -296,11 +297,15 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
return f"`claude --resume {token.value}`"
|
||||
|
||||
def _build_args(self, prompt: str, resume: ResumeToken | None) -> list[str]:
|
||||
run_options = get_run_options()
|
||||
args: list[str] = ["-p", "--output-format", "stream-json", "--verbose"]
|
||||
if resume is not None:
|
||||
args.extend(["--resume", resume.value])
|
||||
if self.model is not None:
|
||||
args.extend(["--model", str(self.model)])
|
||||
model = self.model
|
||||
if run_options is not None and run_options.model:
|
||||
model = run_options.model
|
||||
if model is not None:
|
||||
args.extend(["--model", str(model)])
|
||||
allowed_tools = _coerce_comma_list(self.allowed_tools)
|
||||
if allowed_tools is not None:
|
||||
args.extend(["--allowedTools", allowed_tools])
|
||||
|
||||
@@ -13,6 +13,7 @@ from ..events import EventFactory
|
||||
from ..logging import get_logger
|
||||
from ..model import ActionPhase, EngineId, ResumeToken, TakopiEvent
|
||||
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
||||
from .run_options import get_run_options
|
||||
from ..schemas import codex as codex_schema
|
||||
from ..utils.paths import relativize_command
|
||||
|
||||
@@ -426,13 +427,26 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: Any,
|
||||
) -> list[str]:
|
||||
args = [
|
||||
*self.extra_args,
|
||||
run_options = get_run_options()
|
||||
args = [*self.extra_args]
|
||||
if run_options is not None:
|
||||
if run_options.model:
|
||||
args.extend(["--model", str(run_options.model)])
|
||||
if run_options.reasoning:
|
||||
args.extend(
|
||||
[
|
||||
"-c",
|
||||
f"model_reasoning_effort={run_options.reasoning}",
|
||||
]
|
||||
)
|
||||
args.extend(
|
||||
[
|
||||
"exec",
|
||||
"--json",
|
||||
"--skip-git-repo-check",
|
||||
"--color=never",
|
||||
]
|
||||
)
|
||||
if resume:
|
||||
args.extend(["resume", resume.value, "-"])
|
||||
else:
|
||||
|
||||
@@ -34,6 +34,7 @@ from ..model import (
|
||||
TakopiEvent,
|
||||
)
|
||||
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
||||
from .run_options import get_run_options
|
||||
from ..schemas import opencode as opencode_schema
|
||||
from ..utils.paths import relativize_path
|
||||
from .tool_actions import tool_input_path, tool_kind_and_title
|
||||
@@ -325,11 +326,15 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: Any,
|
||||
) -> list[str]:
|
||||
run_options = get_run_options()
|
||||
args = ["run", "--format", "json"]
|
||||
if resume is not None:
|
||||
args.extend(["--session", resume.value])
|
||||
if self.model is not None:
|
||||
args.extend(["--model", str(self.model)])
|
||||
model = self.model
|
||||
if run_options is not None and run_options.model:
|
||||
model = run_options.model
|
||||
if model is not None:
|
||||
args.extend(["--model", str(model)])
|
||||
args.extend(["--", prompt])
|
||||
return args
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from ..model import (
|
||||
TakopiEvent,
|
||||
)
|
||||
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
||||
from .run_options import get_run_options
|
||||
from ..schemas import pi as pi_schema
|
||||
from ..utils.paths import get_run_base_dir
|
||||
from .tool_actions import tool_kind_and_title
|
||||
@@ -322,11 +323,15 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: PiStreamState,
|
||||
) -> list[str]:
|
||||
run_options = get_run_options()
|
||||
args: list[str] = [*self.extra_args, "--print", "--mode", "json"]
|
||||
if self.provider:
|
||||
args.extend(["--provider", self.provider])
|
||||
if self.model:
|
||||
args.extend(["--model", self.model])
|
||||
model = self.model
|
||||
if run_options is not None and run_options.model:
|
||||
model = run_options.model
|
||||
if model:
|
||||
args.extend(["--model", model])
|
||||
args.extend(["--session", state.resume.value])
|
||||
args.append(self._sanitize_prompt(prompt))
|
||||
return args
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar, Token
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class EngineRunOptions:
|
||||
model: str | None = None
|
||||
reasoning: str | None = None
|
||||
|
||||
|
||||
_RUN_OPTIONS: ContextVar[EngineRunOptions | None] = ContextVar(
|
||||
"takopi.engine_run_options", default=None
|
||||
)
|
||||
|
||||
|
||||
def get_run_options() -> EngineRunOptions | None:
|
||||
return _RUN_OPTIONS.get()
|
||||
|
||||
|
||||
def set_run_options(options: EngineRunOptions | None) -> Token:
|
||||
return _RUN_OPTIONS.set(options)
|
||||
|
||||
|
||||
def reset_run_options(token: Token) -> None:
|
||||
_RUN_OPTIONS.reset(token)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def apply_run_options(options: EngineRunOptions | None) -> Iterator[None]:
|
||||
token = set_run_options(options)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
reset_run_options(token)
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
import msgspec
|
||||
|
||||
from ..logging import get_logger
|
||||
from .engine_overrides import EngineOverrides, normalize_overrides
|
||||
from .state_store import JsonStateStore
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -16,6 +17,7 @@ STATE_FILENAME = "telegram_chat_prefs_state.json"
|
||||
class _ChatPrefs(msgspec.Struct, forbid_unknown_fields=False):
|
||||
default_engine: str | None = None
|
||||
trigger_mode: str | None = None
|
||||
engine_overrides: dict[str, EngineOverrides] = msgspec.field(default_factory=dict)
|
||||
|
||||
|
||||
class _ChatPrefsState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
@@ -49,6 +51,13 @@ def _normalize_trigger_mode(value: str | None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_engine_id(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
value = value.strip().lower()
|
||||
return value or None
|
||||
|
||||
|
||||
def _new_state() -> _ChatPrefsState:
|
||||
return _ChatPrefsState(version=STATE_VERSION, chats={})
|
||||
|
||||
@@ -120,6 +129,45 @@ class ChatPrefsStore(JsonStateStore[_ChatPrefsState]):
|
||||
async def clear_trigger_mode(self, chat_id: int) -> None:
|
||||
await self.set_trigger_mode(chat_id, None)
|
||||
|
||||
async def get_engine_override(
|
||||
self, chat_id: int, engine: str
|
||||
) -> EngineOverrides | None:
|
||||
engine_key = _normalize_engine_id(engine)
|
||||
if engine_key is None:
|
||||
return None
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
chat = self._get_chat_locked(chat_id)
|
||||
if chat is None:
|
||||
return None
|
||||
override = chat.engine_overrides.get(engine_key)
|
||||
return normalize_overrides(override)
|
||||
|
||||
async def set_engine_override(
|
||||
self, chat_id: int, engine: str, override: EngineOverrides | None
|
||||
) -> None:
|
||||
engine_key = _normalize_engine_id(engine)
|
||||
if engine_key is None:
|
||||
return
|
||||
normalized = normalize_overrides(override)
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
chat = self._get_chat_locked(chat_id)
|
||||
if normalized is None:
|
||||
if chat is None:
|
||||
return
|
||||
chat.engine_overrides.pop(engine_key, None)
|
||||
if self._chat_is_empty(chat):
|
||||
self._remove_chat_locked(chat_id)
|
||||
self._save_locked()
|
||||
return
|
||||
chat = self._ensure_chat_locked(chat_id)
|
||||
chat.engine_overrides[engine_key] = normalized
|
||||
self._save_locked()
|
||||
|
||||
async def clear_engine_override(self, chat_id: int, engine: str) -> None:
|
||||
await self.set_engine_override(chat_id, engine, None)
|
||||
|
||||
def _get_chat_locked(self, chat_id: int) -> _ChatPrefs | None:
|
||||
return self._state.chats.get(_chat_key(chat_id))
|
||||
|
||||
@@ -136,8 +184,16 @@ class ChatPrefsStore(JsonStateStore[_ChatPrefsState]):
|
||||
return (
|
||||
_normalize_text(chat.default_engine) is None
|
||||
and _normalize_trigger_mode(chat.trigger_mode) is None
|
||||
and not self._has_engine_overrides(chat.engine_overrides)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _has_engine_overrides(overrides: dict[str, EngineOverrides]) -> bool:
|
||||
for override in overrides.values():
|
||||
if normalize_overrides(override) is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _remove_chat_locked(self, chat_id: int) -> bool:
|
||||
key = _chat_key(chat_id)
|
||||
if key not in self._state.chats:
|
||||
|
||||
@@ -9,6 +9,7 @@ from ...commands import CommandContext, get_command
|
||||
from ...config import ConfigError
|
||||
from ...logging import get_logger
|
||||
from ...model import EngineId, ResumeToken
|
||||
from ...runners.run_options import EngineRunOptions
|
||||
from ...runner_bridge import RunningTasks
|
||||
from ...scheduler import ThreadScheduler
|
||||
from ...transport import MessageRef
|
||||
@@ -33,6 +34,8 @@ async def _dispatch_command(
|
||||
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
|
||||
stateful_mode: bool,
|
||||
default_engine_override: EngineId | None,
|
||||
engine_overrides_resolver: Callable[[EngineId], Awaitable[EngineRunOptions | None]]
|
||||
| None,
|
||||
) -> None:
|
||||
allowlist = cfg.runtime.allowlist
|
||||
chat_id = msg.chat_id
|
||||
@@ -52,6 +55,7 @@ async def _dispatch_command(
|
||||
running_tasks=running_tasks,
|
||||
scheduler=scheduler,
|
||||
on_thread_known=on_thread_known,
|
||||
engine_overrides_resolver=engine_overrides_resolver,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
thread_id=msg.thread_id,
|
||||
|
||||
@@ -11,10 +11,11 @@ from ...commands import CommandExecutor, RunMode, RunRequest, RunResult
|
||||
from ...config import ConfigError
|
||||
from ...context import RunContext
|
||||
from ...logging import bind_run_context, clear_context, get_logger
|
||||
from ...model import EngineId, ResumeToken, TakopiEvent
|
||||
from ...model import Action, ActionEvent, EngineId, ResumeToken, TakopiEvent
|
||||
from ...progress import ProgressTracker
|
||||
from ...router import RunnerUnavailableError
|
||||
from ...runner import Runner
|
||||
from ...runners.run_options import EngineRunOptions, apply_run_options
|
||||
from ...runner_bridge import (
|
||||
ExecBridgeConfig,
|
||||
IncomingMessage as RunnerIncomingMessage,
|
||||
@@ -26,6 +27,7 @@ from ...transport import MessageRef, RenderedMessage, SendOptions
|
||||
from ...transport_runtime import TransportRuntime
|
||||
from ...utils.paths import reset_run_base_dir, set_run_base_dir
|
||||
from ..bridge import send_plain
|
||||
from ..engine_overrides import supports_reasoning
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -53,6 +55,54 @@ class _ResumeLineProxy:
|
||||
return self.runner.run(prompt, resume)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _PreludeRunner:
|
||||
runner: Runner
|
||||
prelude_events: Sequence[TakopiEvent]
|
||||
|
||||
@property
|
||||
def engine(self) -> str:
|
||||
return self.runner.engine
|
||||
|
||||
def is_resume_line(self, line: str) -> bool:
|
||||
return self.runner.is_resume_line(line)
|
||||
|
||||
def format_resume(self, token: ResumeToken) -> str:
|
||||
return self.runner.format_resume(token)
|
||||
|
||||
def extract_resume(self, text: str | None) -> ResumeToken | None:
|
||||
return self.runner.extract_resume(text)
|
||||
|
||||
async def run(
|
||||
self, prompt: str, resume: ResumeToken | None
|
||||
) -> AsyncIterator[TakopiEvent]:
|
||||
for event in self.prelude_events:
|
||||
yield event
|
||||
async for event in self.runner.run(prompt, resume):
|
||||
yield event
|
||||
|
||||
|
||||
def _reasoning_warning(
|
||||
*, engine: str, run_options: EngineRunOptions | None
|
||||
) -> ActionEvent | None:
|
||||
if run_options is None or not run_options.reasoning:
|
||||
return None
|
||||
if supports_reasoning(engine):
|
||||
return None
|
||||
message = f"reasoning override is not supported for `{engine}`; ignoring."
|
||||
return ActionEvent(
|
||||
engine=engine,
|
||||
action=Action(
|
||||
id=f"{engine}.override.reasoning",
|
||||
kind="note",
|
||||
title=message,
|
||||
detail={},
|
||||
),
|
||||
phase="completed",
|
||||
ok=True,
|
||||
)
|
||||
|
||||
|
||||
def _should_show_resume_line(
|
||||
*,
|
||||
show_resume_line: bool,
|
||||
@@ -108,6 +158,7 @@ async def _run_engine(
|
||||
thread_id: int | None = None,
|
||||
show_resume_line: bool = True,
|
||||
progress_ref: MessageRef | None = None,
|
||||
run_options: EngineRunOptions | None = None,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
@@ -128,6 +179,9 @@ async def _run_engine(
|
||||
runner: Runner = entry.runner
|
||||
if not show_resume_line:
|
||||
runner = cast(Runner, _ResumeLineProxy(runner))
|
||||
warning = _reasoning_warning(engine=runner.engine, run_options=run_options)
|
||||
if warning is not None:
|
||||
runner = cast(Runner, _PreludeRunner(runner, [warning]))
|
||||
if not entry.available:
|
||||
reason = entry.issue or "engine unavailable"
|
||||
await _send_runner_unavailable(
|
||||
@@ -167,6 +221,7 @@ async def _run_engine(
|
||||
reply_to=reply_ref,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
with apply_run_options(run_options):
|
||||
await handle_message(
|
||||
exec_cfg,
|
||||
runner=runner,
|
||||
@@ -235,6 +290,10 @@ class _TelegramCommandExecutor(CommandExecutor):
|
||||
running_tasks: RunningTasks,
|
||||
scheduler: ThreadScheduler,
|
||||
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
|
||||
engine_overrides_resolver: Callable[
|
||||
[EngineId], Awaitable[EngineRunOptions | None]
|
||||
]
|
||||
| None,
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
thread_id: int | None,
|
||||
@@ -247,6 +306,7 @@ class _TelegramCommandExecutor(CommandExecutor):
|
||||
self._running_tasks = running_tasks
|
||||
self._scheduler = scheduler
|
||||
self._on_thread_known = on_thread_known
|
||||
self._engine_overrides_resolver = engine_overrides_resolver
|
||||
self._chat_id = chat_id
|
||||
self._user_msg_id = user_msg_id
|
||||
self._thread_id = thread_id
|
||||
@@ -317,6 +377,9 @@ class _TelegramCommandExecutor(CommandExecutor):
|
||||
engine_override=request.engine,
|
||||
context=request.context,
|
||||
)
|
||||
run_options = None
|
||||
if self._engine_overrides_resolver is not None:
|
||||
run_options = await self._engine_overrides_resolver(engine)
|
||||
on_thread_known = (
|
||||
self._scheduler.note_thread_known
|
||||
if self._on_thread_known is None
|
||||
@@ -343,6 +406,7 @@ class _TelegramCommandExecutor(CommandExecutor):
|
||||
engine_override=engine,
|
||||
thread_id=self._thread_id,
|
||||
show_resume_line=effective_show_resume_line,
|
||||
run_options=run_options,
|
||||
)
|
||||
return RunResult(engine=engine, message=capture.last_message)
|
||||
await _run_engine(
|
||||
@@ -359,6 +423,7 @@ class _TelegramCommandExecutor(CommandExecutor):
|
||||
engine_override=engine,
|
||||
thread_id=self._thread_id,
|
||||
show_resume_line=effective_show_resume_line,
|
||||
run_options=run_options,
|
||||
)
|
||||
return RunResult(engine=engine, message=None)
|
||||
|
||||
|
||||
@@ -73,6 +73,8 @@ def build_bot_commands(
|
||||
for cmd, description in [
|
||||
("new", "start a new thread"),
|
||||
("agent", "set default agent"),
|
||||
("model", "set model override"),
|
||||
("reasoning", "set reasoning override"),
|
||||
("trigger", "set trigger mode"),
|
||||
]:
|
||||
if cmd in seen:
|
||||
|
||||
@@ -0,0 +1,284 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...context import RunContext
|
||||
from ...directives import DirectiveError
|
||||
from ..chat_prefs import ChatPrefsStore
|
||||
from ..engine_defaults import resolve_engine_for_message
|
||||
from ..engine_overrides import EngineOverrides, resolve_override_value
|
||||
from ..files import split_command_args
|
||||
from ..topic_state import TopicStateStore
|
||||
from ..topics import _topic_key
|
||||
from ..types import TelegramIncomingMessage
|
||||
from .reply import make_reply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bridge import TelegramBridgeConfig
|
||||
|
||||
MODEL_USAGE = (
|
||||
"usage: `/model`, `/model set <model>`, "
|
||||
"`/model set <engine> <model>`, or `/model clear [engine]`"
|
||||
)
|
||||
|
||||
|
||||
async def _check_model_permissions(
|
||||
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||
) -> bool:
|
||||
reply = make_reply(cfg, msg)
|
||||
sender_id = msg.sender_id
|
||||
if sender_id is None:
|
||||
await reply(text="cannot verify sender for model overrides.")
|
||||
return False
|
||||
is_private = msg.chat_type == "private"
|
||||
if msg.chat_type is None:
|
||||
is_private = msg.chat_id > 0
|
||||
if is_private:
|
||||
return True
|
||||
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
|
||||
if member is None:
|
||||
await reply(text="failed to verify model override permissions.")
|
||||
return False
|
||||
if member.status in {"creator", "administrator"}:
|
||||
return True
|
||||
await reply(text="changing model overrides is restricted to group admins.")
|
||||
return False
|
||||
|
||||
|
||||
async def _resolve_engine_selection(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
*,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
chat_prefs: ChatPrefsStore | None,
|
||||
topic_key: tuple[int, int] | None,
|
||||
) -> tuple[str, str] | None:
|
||||
reply = make_reply(cfg, msg)
|
||||
try:
|
||||
resolved = cfg.runtime.resolve_message(
|
||||
text="",
|
||||
reply_text=msg.reply_to_text,
|
||||
ambient_context=ambient_context,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
except DirectiveError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return None
|
||||
selection = await resolve_engine_for_message(
|
||||
runtime=cfg.runtime,
|
||||
context=resolved.context,
|
||||
explicit_engine=None,
|
||||
chat_id=msg.chat_id,
|
||||
topic_key=topic_key,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
)
|
||||
return selection.engine, selection.source
|
||||
|
||||
|
||||
def _parse_set_args(
|
||||
tokens: tuple[str, ...], *, engine_ids: set[str]
|
||||
) -> tuple[str | None, str | None]:
|
||||
if len(tokens) < 2:
|
||||
return None, None
|
||||
if len(tokens) == 2:
|
||||
maybe_engine = tokens[1].strip().lower()
|
||||
if maybe_engine in engine_ids:
|
||||
return None, None
|
||||
return None, tokens[1].strip()
|
||||
maybe_engine = tokens[1].strip().lower()
|
||||
if maybe_engine in engine_ids:
|
||||
model = " ".join(tokens[2:]).strip()
|
||||
return maybe_engine, model or None
|
||||
model = " ".join(tokens[1:]).strip()
|
||||
return None, model or None
|
||||
|
||||
|
||||
async def _handle_model_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
chat_prefs: ChatPrefsStore | None,
|
||||
*,
|
||||
resolved_scope: str | None = None,
|
||||
scope_chat_ids: frozenset[int] | None = None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
tkey = (
|
||||
_topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
|
||||
if topic_store is not None
|
||||
else None
|
||||
)
|
||||
tokens = split_command_args(args_text)
|
||||
action = tokens[0].lower() if tokens else "show"
|
||||
engine_ids = {engine.lower() for engine in cfg.runtime.engine_ids}
|
||||
|
||||
if action in {"show", ""}:
|
||||
selection = await _resolve_engine_selection(
|
||||
cfg,
|
||||
msg,
|
||||
ambient_context=ambient_context,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
topic_key=tkey,
|
||||
)
|
||||
if selection is None:
|
||||
return
|
||||
engine, engine_source = selection
|
||||
topic_override = None
|
||||
if tkey is not None and topic_store is not None:
|
||||
topic_override = await topic_store.get_engine_override(
|
||||
tkey[0], tkey[1], engine
|
||||
)
|
||||
chat_override = None
|
||||
if chat_prefs is not None:
|
||||
chat_override = await chat_prefs.get_engine_override(msg.chat_id, engine)
|
||||
resolution = resolve_override_value(
|
||||
topic_override=topic_override,
|
||||
chat_override=chat_override,
|
||||
field="model",
|
||||
)
|
||||
source_labels = {
|
||||
"directive": "directive",
|
||||
"topic_default": "topic default",
|
||||
"chat_default": "chat default",
|
||||
"project_default": "project default",
|
||||
"global_default": "global default",
|
||||
}
|
||||
override_labels = {
|
||||
"topic_override": "topic override",
|
||||
"chat_default": "chat default",
|
||||
"default": "no override",
|
||||
}
|
||||
engine_line = f"engine: {engine} ({source_labels[engine_source]})"
|
||||
model_value = resolution.value or "default"
|
||||
model_line = f"model: {model_value} ({override_labels[resolution.source]})"
|
||||
topic_label = resolution.topic_value or "none"
|
||||
if tkey is None:
|
||||
topic_label = "none"
|
||||
chat_label = (
|
||||
"unavailable" if chat_prefs is None else resolution.chat_value or "none"
|
||||
)
|
||||
defaults_line = f"defaults: topic: {topic_label}, chat: {chat_label}"
|
||||
available_line = f"available engines: {', '.join(cfg.runtime.engine_ids)}"
|
||||
await reply(
|
||||
text="\n\n".join([engine_line, model_line, defaults_line, available_line])
|
||||
)
|
||||
return
|
||||
|
||||
if action == "set":
|
||||
engine_arg, model = _parse_set_args(tokens, engine_ids=engine_ids)
|
||||
if model is None:
|
||||
await reply(text=MODEL_USAGE)
|
||||
return
|
||||
if not await _check_model_permissions(cfg, msg):
|
||||
return
|
||||
if engine_arg is None:
|
||||
selection = await _resolve_engine_selection(
|
||||
cfg,
|
||||
msg,
|
||||
ambient_context=ambient_context,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
topic_key=tkey,
|
||||
)
|
||||
if selection is None:
|
||||
return
|
||||
engine, _ = selection
|
||||
else:
|
||||
engine = engine_arg
|
||||
if engine not in engine_ids:
|
||||
available = ", ".join(cfg.runtime.engine_ids)
|
||||
await reply(
|
||||
text=f"unknown engine `{engine}`.\navailable agents: `{available}`"
|
||||
)
|
||||
return
|
||||
if tkey is not None:
|
||||
if topic_store is None:
|
||||
await reply(text="topic model overrides are unavailable.")
|
||||
return
|
||||
current = await topic_store.get_engine_override(tkey[0], tkey[1], engine)
|
||||
updated = EngineOverrides(
|
||||
model=model,
|
||||
reasoning=current.reasoning if current is not None else None,
|
||||
)
|
||||
await topic_store.set_engine_override(tkey[0], tkey[1], engine, updated)
|
||||
await reply(
|
||||
text=(
|
||||
f"topic model override set to `{model}` for `{engine}`.\n"
|
||||
"If you want a clean start on the new model, run `/new`."
|
||||
)
|
||||
)
|
||||
return
|
||||
if chat_prefs is None:
|
||||
await reply(text="chat model overrides are unavailable (no config path).")
|
||||
return
|
||||
current = await chat_prefs.get_engine_override(msg.chat_id, engine)
|
||||
updated = EngineOverrides(
|
||||
model=model,
|
||||
reasoning=current.reasoning if current is not None else None,
|
||||
)
|
||||
await chat_prefs.set_engine_override(msg.chat_id, engine, updated)
|
||||
await reply(
|
||||
text=(
|
||||
f"chat model override set to `{model}` for `{engine}`.\n"
|
||||
"If you want a clean start on the new model, run `/new`."
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if action == "clear":
|
||||
engine = None
|
||||
if len(tokens) > 2:
|
||||
await reply(text=MODEL_USAGE)
|
||||
return
|
||||
if len(tokens) == 2:
|
||||
engine = tokens[1].strip().lower() or None
|
||||
if not await _check_model_permissions(cfg, msg):
|
||||
return
|
||||
if engine is None:
|
||||
selection = await _resolve_engine_selection(
|
||||
cfg,
|
||||
msg,
|
||||
ambient_context=ambient_context,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
topic_key=tkey,
|
||||
)
|
||||
if selection is None:
|
||||
return
|
||||
engine, _ = selection
|
||||
if engine not in engine_ids:
|
||||
available = ", ".join(cfg.runtime.engine_ids)
|
||||
await reply(
|
||||
text=f"unknown engine `{engine}`.\navailable agents: `{available}`"
|
||||
)
|
||||
return
|
||||
if tkey is not None:
|
||||
if topic_store is None:
|
||||
await reply(text="topic model overrides are unavailable.")
|
||||
return
|
||||
current = await topic_store.get_engine_override(tkey[0], tkey[1], engine)
|
||||
updated = EngineOverrides(
|
||||
model=None,
|
||||
reasoning=current.reasoning if current is not None else None,
|
||||
)
|
||||
await topic_store.set_engine_override(tkey[0], tkey[1], engine, updated)
|
||||
await reply(text="topic model override cleared (using chat default).")
|
||||
return
|
||||
if chat_prefs is None:
|
||||
await reply(text="chat model overrides are unavailable (no config path).")
|
||||
return
|
||||
current = await chat_prefs.get_engine_override(msg.chat_id, engine)
|
||||
updated = EngineOverrides(
|
||||
model=None,
|
||||
reasoning=current.reasoning if current is not None else None,
|
||||
)
|
||||
await chat_prefs.set_engine_override(msg.chat_id, engine, updated)
|
||||
await reply(text="chat model override cleared.")
|
||||
return
|
||||
|
||||
await reply(text=MODEL_USAGE)
|
||||
@@ -0,0 +1,308 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...context import RunContext
|
||||
from ...directives import DirectiveError
|
||||
from ..chat_prefs import ChatPrefsStore
|
||||
from ..engine_defaults import resolve_engine_for_message
|
||||
from ..engine_overrides import (
|
||||
EngineOverrides,
|
||||
allowed_reasoning_levels,
|
||||
resolve_override_value,
|
||||
)
|
||||
from ..files import split_command_args
|
||||
from ..topic_state import TopicStateStore
|
||||
from ..topics import _topic_key
|
||||
from ..types import TelegramIncomingMessage
|
||||
from .reply import make_reply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bridge import TelegramBridgeConfig
|
||||
|
||||
REASONING_USAGE = (
|
||||
"usage: `/reasoning`, `/reasoning set <level>`, "
|
||||
"`/reasoning set <engine> <level>`, or `/reasoning clear [engine]`"
|
||||
)
|
||||
|
||||
|
||||
async def _check_reasoning_permissions(
|
||||
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||
) -> bool:
|
||||
reply = make_reply(cfg, msg)
|
||||
sender_id = msg.sender_id
|
||||
if sender_id is None:
|
||||
await reply(text="cannot verify sender for reasoning overrides.")
|
||||
return False
|
||||
is_private = msg.chat_type == "private"
|
||||
if msg.chat_type is None:
|
||||
is_private = msg.chat_id > 0
|
||||
if is_private:
|
||||
return True
|
||||
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
|
||||
if member is None:
|
||||
await reply(text="failed to verify reasoning override permissions.")
|
||||
return False
|
||||
if member.status in {"creator", "administrator"}:
|
||||
return True
|
||||
await reply(text="changing reasoning overrides is restricted to group admins.")
|
||||
return False
|
||||
|
||||
|
||||
async def _resolve_engine_selection(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
*,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
chat_prefs: ChatPrefsStore | None,
|
||||
topic_key: tuple[int, int] | None,
|
||||
) -> tuple[str, str] | None:
|
||||
reply = make_reply(cfg, msg)
|
||||
try:
|
||||
resolved = cfg.runtime.resolve_message(
|
||||
text="",
|
||||
reply_text=msg.reply_to_text,
|
||||
ambient_context=ambient_context,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
except DirectiveError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return None
|
||||
selection = await resolve_engine_for_message(
|
||||
runtime=cfg.runtime,
|
||||
context=resolved.context,
|
||||
explicit_engine=None,
|
||||
chat_id=msg.chat_id,
|
||||
topic_key=topic_key,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
)
|
||||
return selection.engine, selection.source
|
||||
|
||||
|
||||
def _parse_set_args(
|
||||
tokens: tuple[str, ...], *, engine_ids: set[str]
|
||||
) -> tuple[str | None, str | None]:
|
||||
if len(tokens) < 2:
|
||||
return None, None
|
||||
if len(tokens) == 2:
|
||||
maybe_engine = tokens[1].strip().lower()
|
||||
if maybe_engine in engine_ids:
|
||||
return None, None
|
||||
return None, tokens[1].strip()
|
||||
maybe_engine = tokens[1].strip().lower()
|
||||
if maybe_engine in engine_ids:
|
||||
level = " ".join(tokens[2:]).strip()
|
||||
return maybe_engine, level or None
|
||||
level = " ".join(tokens[1:]).strip()
|
||||
return None, level or None
|
||||
|
||||
|
||||
async def _handle_reasoning_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
chat_prefs: ChatPrefsStore | None,
|
||||
*,
|
||||
resolved_scope: str | None = None,
|
||||
scope_chat_ids: frozenset[int] | None = None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
tkey = (
|
||||
_topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
|
||||
if topic_store is not None
|
||||
else None
|
||||
)
|
||||
tokens = split_command_args(args_text)
|
||||
action = tokens[0].lower() if tokens else "show"
|
||||
engine_ids = {engine.lower() for engine in cfg.runtime.engine_ids}
|
||||
|
||||
if action in {"show", ""}:
|
||||
selection = await _resolve_engine_selection(
|
||||
cfg,
|
||||
msg,
|
||||
ambient_context=ambient_context,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
topic_key=tkey,
|
||||
)
|
||||
if selection is None:
|
||||
return
|
||||
engine, engine_source = selection
|
||||
topic_override = None
|
||||
if tkey is not None and topic_store is not None:
|
||||
topic_override = await topic_store.get_engine_override(
|
||||
tkey[0], tkey[1], engine
|
||||
)
|
||||
chat_override = None
|
||||
if chat_prefs is not None:
|
||||
chat_override = await chat_prefs.get_engine_override(msg.chat_id, engine)
|
||||
resolution = resolve_override_value(
|
||||
topic_override=topic_override,
|
||||
chat_override=chat_override,
|
||||
field="reasoning",
|
||||
)
|
||||
source_labels = {
|
||||
"directive": "directive",
|
||||
"topic_default": "topic default",
|
||||
"chat_default": "chat default",
|
||||
"project_default": "project default",
|
||||
"global_default": "global default",
|
||||
}
|
||||
override_labels = {
|
||||
"topic_override": "topic override",
|
||||
"chat_default": "chat default",
|
||||
"default": "no override",
|
||||
}
|
||||
engine_line = f"engine: {engine} ({source_labels[engine_source]})"
|
||||
reasoning_value = resolution.value or "default"
|
||||
reasoning_line = (
|
||||
f"reasoning: {reasoning_value} ({override_labels[resolution.source]})"
|
||||
)
|
||||
topic_label = resolution.topic_value or "none"
|
||||
if tkey is None:
|
||||
topic_label = "none"
|
||||
chat_label = (
|
||||
"unavailable" if chat_prefs is None else resolution.chat_value or "none"
|
||||
)
|
||||
defaults_line = f"defaults: topic: {topic_label}, chat: {chat_label}"
|
||||
available_levels = ", ".join(allowed_reasoning_levels(engine))
|
||||
available_line = f"available levels: {available_levels}"
|
||||
await reply(
|
||||
text="\n\n".join(
|
||||
[engine_line, reasoning_line, defaults_line, available_line]
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if action == "set":
|
||||
engine_arg, level = _parse_set_args(tokens, engine_ids=engine_ids)
|
||||
if level is None:
|
||||
await reply(text=REASONING_USAGE)
|
||||
return
|
||||
if not await _check_reasoning_permissions(cfg, msg):
|
||||
return
|
||||
if engine_arg is None:
|
||||
selection = await _resolve_engine_selection(
|
||||
cfg,
|
||||
msg,
|
||||
ambient_context=ambient_context,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
topic_key=tkey,
|
||||
)
|
||||
if selection is None:
|
||||
return
|
||||
engine, _ = selection
|
||||
else:
|
||||
engine = engine_arg
|
||||
if engine not in engine_ids:
|
||||
available = ", ".join(cfg.runtime.engine_ids)
|
||||
await reply(
|
||||
text=f"unknown engine `{engine}`.\navailable agents: `{available}`"
|
||||
)
|
||||
return
|
||||
normalized_level = level.strip().lower()
|
||||
allowed = allowed_reasoning_levels(engine)
|
||||
if normalized_level not in allowed:
|
||||
await reply(
|
||||
text=(
|
||||
f"unknown reasoning level `{level}`.\n"
|
||||
f"available levels: {', '.join(allowed)}"
|
||||
)
|
||||
)
|
||||
return
|
||||
if tkey is not None:
|
||||
if topic_store is None:
|
||||
await reply(text="topic reasoning overrides are unavailable.")
|
||||
return
|
||||
current = await topic_store.get_engine_override(tkey[0], tkey[1], engine)
|
||||
updated = EngineOverrides(
|
||||
model=current.model if current is not None else None,
|
||||
reasoning=normalized_level,
|
||||
)
|
||||
await topic_store.set_engine_override(tkey[0], tkey[1], engine, updated)
|
||||
await reply(
|
||||
text=(
|
||||
f"topic reasoning override set to `{normalized_level}` "
|
||||
f"for `{engine}`.\n"
|
||||
"If you want a clean start on the new setting, run `/new`."
|
||||
)
|
||||
)
|
||||
return
|
||||
if chat_prefs is None:
|
||||
await reply(
|
||||
text="chat reasoning overrides are unavailable (no config path)."
|
||||
)
|
||||
return
|
||||
current = await chat_prefs.get_engine_override(msg.chat_id, engine)
|
||||
updated = EngineOverrides(
|
||||
model=current.model if current is not None else None,
|
||||
reasoning=normalized_level,
|
||||
)
|
||||
await chat_prefs.set_engine_override(msg.chat_id, engine, updated)
|
||||
await reply(
|
||||
text=(
|
||||
f"chat reasoning override set to `{normalized_level}` for `{engine}`.\n"
|
||||
"If you want a clean start on the new setting, run `/new`."
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if action == "clear":
|
||||
engine = None
|
||||
if len(tokens) > 2:
|
||||
await reply(text=REASONING_USAGE)
|
||||
return
|
||||
if len(tokens) == 2:
|
||||
engine = tokens[1].strip().lower() or None
|
||||
if not await _check_reasoning_permissions(cfg, msg):
|
||||
return
|
||||
if engine is None:
|
||||
selection = await _resolve_engine_selection(
|
||||
cfg,
|
||||
msg,
|
||||
ambient_context=ambient_context,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
topic_key=tkey,
|
||||
)
|
||||
if selection is None:
|
||||
return
|
||||
engine, _ = selection
|
||||
if engine not in engine_ids:
|
||||
available = ", ".join(cfg.runtime.engine_ids)
|
||||
await reply(
|
||||
text=f"unknown engine `{engine}`.\navailable agents: `{available}`"
|
||||
)
|
||||
return
|
||||
if tkey is not None:
|
||||
if topic_store is None:
|
||||
await reply(text="topic reasoning overrides are unavailable.")
|
||||
return
|
||||
current = await topic_store.get_engine_override(tkey[0], tkey[1], engine)
|
||||
updated = EngineOverrides(
|
||||
model=current.model if current is not None else None,
|
||||
reasoning=None,
|
||||
)
|
||||
await topic_store.set_engine_override(tkey[0], tkey[1], engine, updated)
|
||||
await reply(text="topic reasoning override cleared (using chat default).")
|
||||
return
|
||||
if chat_prefs is None:
|
||||
await reply(
|
||||
text="chat reasoning overrides are unavailable (no config path)."
|
||||
)
|
||||
return
|
||||
current = await chat_prefs.get_engine_override(msg.chat_id, engine)
|
||||
updated = EngineOverrides(
|
||||
model=current.model if current is not None else None,
|
||||
reasoning=None,
|
||||
)
|
||||
await chat_prefs.set_engine_override(msg.chat_id, engine, updated)
|
||||
await reply(text="chat reasoning override cleared.")
|
||||
return
|
||||
|
||||
await reply(text=REASONING_USAGE)
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import msgspec
|
||||
|
||||
OverrideSource = Literal["topic_override", "chat_default", "default"]
|
||||
|
||||
REASONING_LEVELS: tuple[str, ...] = ("minimal", "low", "medium", "high", "xhigh")
|
||||
OPENCODE_REASONING_LEVELS: tuple[str, ...] = ("none", *REASONING_LEVELS)
|
||||
REASONING_SUPPORTED_ENGINES = frozenset({"codex"})
|
||||
|
||||
|
||||
class EngineOverrides(msgspec.Struct, forbid_unknown_fields=False):
|
||||
model: str | None = None
|
||||
reasoning: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OverrideValueResolution:
|
||||
value: str | None
|
||||
source: OverrideSource
|
||||
topic_value: str | None
|
||||
chat_value: str | None
|
||||
|
||||
|
||||
def normalize_override_value(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
cleaned = value.strip()
|
||||
return cleaned or None
|
||||
|
||||
|
||||
def normalize_overrides(overrides: EngineOverrides | None) -> EngineOverrides | None:
|
||||
if overrides is None:
|
||||
return None
|
||||
model = normalize_override_value(overrides.model)
|
||||
reasoning = normalize_override_value(overrides.reasoning)
|
||||
if model is None and reasoning is None:
|
||||
return None
|
||||
return EngineOverrides(model=model, reasoning=reasoning)
|
||||
|
||||
|
||||
def merge_overrides(
|
||||
topic_override: EngineOverrides | None,
|
||||
chat_override: EngineOverrides | None,
|
||||
) -> EngineOverrides | None:
|
||||
topic = normalize_overrides(topic_override)
|
||||
chat = normalize_overrides(chat_override)
|
||||
if topic is None and chat is None:
|
||||
return None
|
||||
model = None
|
||||
reasoning = None
|
||||
if topic is not None and topic.model is not None:
|
||||
model = topic.model
|
||||
elif chat is not None:
|
||||
model = chat.model
|
||||
if topic is not None and topic.reasoning is not None:
|
||||
reasoning = topic.reasoning
|
||||
elif chat is not None:
|
||||
reasoning = chat.reasoning
|
||||
return normalize_overrides(EngineOverrides(model=model, reasoning=reasoning))
|
||||
|
||||
|
||||
def resolve_override_value(
|
||||
*,
|
||||
topic_override: EngineOverrides | None,
|
||||
chat_override: EngineOverrides | None,
|
||||
field: Literal["model", "reasoning"],
|
||||
) -> OverrideValueResolution:
|
||||
topic_value = normalize_override_value(
|
||||
getattr(topic_override, field, None) if topic_override is not None else None
|
||||
)
|
||||
chat_value = normalize_override_value(
|
||||
getattr(chat_override, field, None) if chat_override is not None else None
|
||||
)
|
||||
if topic_value is not None:
|
||||
return OverrideValueResolution(
|
||||
value=topic_value,
|
||||
source="topic_override",
|
||||
topic_value=topic_value,
|
||||
chat_value=chat_value,
|
||||
)
|
||||
if chat_value is not None:
|
||||
return OverrideValueResolution(
|
||||
value=chat_value,
|
||||
source="chat_default",
|
||||
topic_value=topic_value,
|
||||
chat_value=chat_value,
|
||||
)
|
||||
return OverrideValueResolution(
|
||||
value=None,
|
||||
source="default",
|
||||
topic_value=topic_value,
|
||||
chat_value=chat_value,
|
||||
)
|
||||
|
||||
|
||||
def allowed_reasoning_levels(engine: str) -> tuple[str, ...]:
|
||||
if engine == "opencode":
|
||||
return OPENCODE_REASONING_LEVELS
|
||||
return REASONING_LEVELS
|
||||
|
||||
|
||||
def supports_reasoning(engine: str) -> bool:
|
||||
return engine in REASONING_SUPPORTED_ENGINES
|
||||
@@ -14,6 +14,7 @@ from ..commands import list_command_ids
|
||||
from ..directives import DirectiveError
|
||||
from ..logging import get_logger
|
||||
from ..model import EngineId, ResumeToken
|
||||
from ..runners.run_options import EngineRunOptions
|
||||
from ..scheduler import ThreadJob, ThreadScheduler
|
||||
from ..progress import ProgressTracker
|
||||
from ..settings import TelegramTransportSettings
|
||||
@@ -42,6 +43,8 @@ from .commands.topics import (
|
||||
_handle_new_command,
|
||||
_handle_topic_command,
|
||||
)
|
||||
from .commands.model import _handle_model_command
|
||||
from .commands.reasoning import _handle_reasoning_command
|
||||
from .commands.trigger import _handle_trigger_command
|
||||
from .context import _merge_topic_context, _usage_ctx_set, _usage_topic
|
||||
from .topics import (
|
||||
@@ -55,6 +58,7 @@ from .topics import (
|
||||
from .client import poll_incoming
|
||||
from .chat_prefs import ChatPrefsStore, resolve_prefs_path
|
||||
from .chat_sessions import ChatSessionStore, resolve_sessions_path
|
||||
from .engine_overrides import merge_overrides
|
||||
from .engine_defaults import resolve_engine_for_message
|
||||
from .topic_state import TopicStateStore, resolve_state_path
|
||||
from .trigger_mode import resolve_trigger_mode, should_trigger_run
|
||||
@@ -86,6 +90,27 @@ def _chat_session_key(
|
||||
return (msg.chat_id, msg.sender_id)
|
||||
|
||||
|
||||
async def _resolve_engine_run_options(
|
||||
chat_id: int,
|
||||
thread_id: int | None,
|
||||
engine: EngineId,
|
||||
chat_prefs: ChatPrefsStore | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> EngineRunOptions | None:
|
||||
topic_override = None
|
||||
if topic_store is not None and thread_id is not None:
|
||||
topic_override = await topic_store.get_engine_override(
|
||||
chat_id, thread_id, engine
|
||||
)
|
||||
chat_override = None
|
||||
if chat_prefs is not None:
|
||||
chat_override = await chat_prefs.get_engine_override(chat_id, engine)
|
||||
merged = merge_overrides(topic_override, chat_override)
|
||||
if merged is None:
|
||||
return None
|
||||
return EngineRunOptions(model=merged.model, reasoning=merged.reasoning)
|
||||
|
||||
|
||||
def _allowed_chat_ids(cfg: TelegramBridgeConfig) -> set[int]:
|
||||
allowed = set(cfg.chat_ids or ())
|
||||
allowed.add(cfg.chat_id)
|
||||
@@ -187,6 +212,32 @@ def _dispatch_builtin_command(
|
||||
scope_chat_ids=scope_chat_ids,
|
||||
)
|
||||
|
||||
if command_id == "model":
|
||||
handlers["model"] = partial(
|
||||
_handle_model_command,
|
||||
cfg,
|
||||
msg,
|
||||
args_text,
|
||||
ambient_context,
|
||||
topic_store,
|
||||
chat_prefs,
|
||||
resolved_scope=resolved_scope,
|
||||
scope_chat_ids=scope_chat_ids,
|
||||
)
|
||||
|
||||
if command_id == "reasoning":
|
||||
handlers["reasoning"] = partial(
|
||||
_handle_reasoning_command,
|
||||
cfg,
|
||||
msg,
|
||||
args_text,
|
||||
ambient_context,
|
||||
topic_store,
|
||||
chat_prefs,
|
||||
resolved_scope=resolved_scope,
|
||||
scope_chat_ids=scope_chat_ids,
|
||||
)
|
||||
|
||||
if command_id == "trigger":
|
||||
handlers["trigger"] = partial(
|
||||
_handle_trigger_command,
|
||||
@@ -605,6 +656,24 @@ async def run_main_loop(
|
||||
stateful_mode=stateful_mode,
|
||||
context=context,
|
||||
)
|
||||
engine_for_overrides = (
|
||||
resume_token.engine
|
||||
if resume_token is not None
|
||||
else engine_override
|
||||
if engine_override is not None
|
||||
else cfg.runtime.resolve_engine(
|
||||
engine_override=None,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
overrides_thread_id = topic_key[1] if topic_key is not None else None
|
||||
run_options = await _resolve_engine_run_options(
|
||||
chat_id,
|
||||
overrides_thread_id,
|
||||
engine_for_overrides,
|
||||
chat_prefs=chat_prefs,
|
||||
topic_store=topic_store,
|
||||
)
|
||||
await _run_engine(
|
||||
exec_cfg=cfg.exec_cfg,
|
||||
runtime=cfg.runtime,
|
||||
@@ -622,6 +691,7 @@ async def run_main_loop(
|
||||
thread_id=thread_id,
|
||||
show_resume_line=show_resume_line,
|
||||
progress_ref=progress_ref,
|
||||
run_options=run_options,
|
||||
)
|
||||
|
||||
async def run_thread_job(job: ThreadJob) -> None:
|
||||
@@ -1377,6 +1447,16 @@ async def run_main_loop(
|
||||
in {"directive", "topic_default", "chat_default"}
|
||||
else None
|
||||
)
|
||||
overrides_thread_id = (
|
||||
topic_key[1] if topic_key is not None else None
|
||||
)
|
||||
engine_overrides_resolver = partial(
|
||||
_resolve_engine_run_options,
|
||||
chat_id,
|
||||
overrides_thread_id,
|
||||
chat_prefs=chat_prefs,
|
||||
topic_store=topic_store,
|
||||
)
|
||||
tg.start_soon(
|
||||
_dispatch_command,
|
||||
cfg,
|
||||
@@ -1393,6 +1473,7 @@ async def run_main_loop(
|
||||
),
|
||||
stateful_mode,
|
||||
default_engine_override,
|
||||
engine_overrides_resolver,
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import msgspec
|
||||
from ..context import RunContext
|
||||
from ..logging import get_logger
|
||||
from ..model import ResumeToken
|
||||
from .engine_overrides import EngineOverrides, normalize_overrides
|
||||
from .state_store import JsonStateStore
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -41,6 +42,7 @@ class _ThreadState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
topic_title: str | None = None
|
||||
default_engine: str | None = None
|
||||
trigger_mode: str | None = None
|
||||
engine_overrides: dict[str, EngineOverrides] = msgspec.field(default_factory=dict)
|
||||
|
||||
|
||||
class _TopicState(msgspec.Struct, forbid_unknown_fields=False):
|
||||
@@ -74,6 +76,13 @@ def _normalize_trigger_mode(value: str | None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_engine_id(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
value = value.strip().lower()
|
||||
return value or None
|
||||
|
||||
|
||||
def _context_from_state(state: _ContextState | None) -> RunContext | None:
|
||||
if state is None:
|
||||
return None
|
||||
@@ -181,6 +190,20 @@ class TopicStateStore(JsonStateStore[_TopicState]):
|
||||
return None
|
||||
return _normalize_trigger_mode(thread.trigger_mode)
|
||||
|
||||
async def get_engine_override(
|
||||
self, chat_id: int, thread_id: int, engine: str
|
||||
) -> EngineOverrides | None:
|
||||
engine_key = _normalize_engine_id(engine)
|
||||
if engine_key is None:
|
||||
return None
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
thread = self._get_thread_locked(chat_id, thread_id)
|
||||
if thread is None:
|
||||
return None
|
||||
override = thread.engine_overrides.get(engine_key)
|
||||
return normalize_overrides(override)
|
||||
|
||||
async def set_default_engine(
|
||||
self, chat_id: int, thread_id: int, engine: str | None
|
||||
) -> None:
|
||||
@@ -207,6 +230,31 @@ class TopicStateStore(JsonStateStore[_TopicState]):
|
||||
async def clear_trigger_mode(self, chat_id: int, thread_id: int) -> None:
|
||||
await self.set_trigger_mode(chat_id, thread_id, None)
|
||||
|
||||
async def set_engine_override(
|
||||
self,
|
||||
chat_id: int,
|
||||
thread_id: int,
|
||||
engine: str,
|
||||
override: EngineOverrides | None,
|
||||
) -> None:
|
||||
engine_key = _normalize_engine_id(engine)
|
||||
if engine_key is None:
|
||||
return
|
||||
normalized = normalize_overrides(override)
|
||||
async with self._lock:
|
||||
self._reload_locked_if_needed()
|
||||
thread = self._ensure_thread_locked(chat_id, thread_id)
|
||||
if normalized is None:
|
||||
thread.engine_overrides.pop(engine_key, None)
|
||||
else:
|
||||
thread.engine_overrides[engine_key] = normalized
|
||||
self._save_locked()
|
||||
|
||||
async def clear_engine_override(
|
||||
self, chat_id: int, thread_id: int, engine: str
|
||||
) -> None:
|
||||
await self.set_engine_override(chat_id, thread_id, engine, None)
|
||||
|
||||
async def set_session_resume(
|
||||
self, chat_id: int, thread_id: int, token: ResumeToken
|
||||
) -> None:
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
from takopi.model import ResumeToken
|
||||
from takopi.runners.claude import ClaudeRunner
|
||||
from takopi.runners.codex import CodexRunner
|
||||
from takopi.runners.opencode import OpenCodeRunner, OpenCodeStreamState
|
||||
from takopi.runners.pi import ENGINE as PI_ENGINE, PiRunner, PiStreamState
|
||||
from takopi.runners.run_options import EngineRunOptions, apply_run_options
|
||||
|
||||
|
||||
def test_codex_run_options_override_model_and_reasoning() -> None:
|
||||
runner = CodexRunner(codex_cmd="codex", extra_args=["-c", "notify=[]"])
|
||||
state = runner.new_state("hi", None)
|
||||
with apply_run_options(EngineRunOptions(model="gpt-4.1-mini", reasoning="low")):
|
||||
args = runner.build_args("hi", None, state=state)
|
||||
|
||||
assert args == [
|
||||
"-c",
|
||||
"notify=[]",
|
||||
"--model",
|
||||
"gpt-4.1-mini",
|
||||
"-c",
|
||||
"model_reasoning_effort=low",
|
||||
"exec",
|
||||
"--json",
|
||||
"--skip-git-repo-check",
|
||||
"--color=never",
|
||||
"-",
|
||||
]
|
||||
|
||||
|
||||
def test_claude_run_options_override_model() -> None:
|
||||
runner = ClaudeRunner(claude_cmd="claude", model="claude-sonnet")
|
||||
with apply_run_options(EngineRunOptions(model="claude-opus")):
|
||||
args = runner.build_args("hi", None, state=None)
|
||||
|
||||
assert "--model" in args
|
||||
model_idx = args.index("--model") + 1
|
||||
assert args[model_idx] == "claude-opus"
|
||||
|
||||
|
||||
def test_opencode_run_options_override_model() -> None:
|
||||
runner = OpenCodeRunner(opencode_cmd="opencode", model="claude-sonnet")
|
||||
state = OpenCodeStreamState()
|
||||
with apply_run_options(EngineRunOptions(model="gpt-4o-mini")):
|
||||
args = runner.build_args("hi", None, state=state)
|
||||
|
||||
assert "--model" in args
|
||||
model_idx = args.index("--model") + 1
|
||||
assert args[model_idx] == "gpt-4o-mini"
|
||||
|
||||
|
||||
def test_pi_run_options_override_model() -> None:
|
||||
runner = PiRunner(extra_args=[], model="pi-default", provider=None)
|
||||
state = PiStreamState(resume=ResumeToken(engine=PI_ENGINE, value="sess.jsonl"))
|
||||
with apply_run_options(EngineRunOptions(model="pi-override")):
|
||||
args = runner.build_args("hi", None, state=state)
|
||||
|
||||
assert "--model" in args
|
||||
model_idx = args.index("--model") + 1
|
||||
assert args[model_idx] == "pi-override"
|
||||
@@ -8,6 +8,8 @@ import pytest
|
||||
from takopi import commands, plugins
|
||||
from takopi.telegram.commands.executor import _CaptureTransport, _run_engine
|
||||
from takopi.telegram.commands.file_transfer import _handle_file_get, _handle_file_put
|
||||
from takopi.telegram.commands.model import _handle_model_command
|
||||
from takopi.telegram.commands.reasoning import _handle_reasoning_command
|
||||
from takopi.telegram.commands.topics import _handle_topic_command
|
||||
import takopi.telegram.loop as telegram_loop
|
||||
import takopi.telegram.topics as telegram_topics
|
||||
@@ -38,6 +40,7 @@ from takopi.telegram.render import MAX_BODY_CHARS
|
||||
from takopi.telegram.topic_state import TopicStateStore, resolve_state_path
|
||||
from takopi.telegram.chat_sessions import ChatSessionStore, resolve_sessions_path
|
||||
from takopi.telegram.chat_prefs import ChatPrefsStore, resolve_prefs_path
|
||||
from takopi.telegram.engine_overrides import EngineOverrides
|
||||
from takopi.context import RunContext
|
||||
from takopi.config import ProjectConfig, ProjectsConfig
|
||||
from takopi.runner_bridge import ExecBridgeConfig, RunningTask
|
||||
@@ -1348,6 +1351,230 @@ async def test_topic_command_recreates_stale_topic(tmp_path: Path) -> None:
|
||||
assert snapshot.context == RunContext(project="takopi", branch="master")
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_model_command_show_reports_overrides(tmp_path: Path) -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = _make_cfg(transport)
|
||||
cfg = replace(cfg, topics=TelegramTopicsSettings(enabled=True, scope="main"))
|
||||
chat_prefs = ChatPrefsStore(tmp_path / "telegram_chat_prefs_state.json")
|
||||
topic_store = TopicStateStore(tmp_path / "telegram_topics_state.json")
|
||||
await chat_prefs.set_engine_override(
|
||||
123,
|
||||
CODEX_ENGINE,
|
||||
EngineOverrides(model="gpt-4.1-mini", reasoning=None),
|
||||
)
|
||||
await topic_store.set_engine_override(
|
||||
123,
|
||||
77,
|
||||
CODEX_ENGINE,
|
||||
EngineOverrides(model="gpt-4.1", reasoning=None),
|
||||
)
|
||||
msg = TelegramIncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=10,
|
||||
text="/model",
|
||||
reply_to_message_id=None,
|
||||
reply_to_text=None,
|
||||
sender_id=123,
|
||||
thread_id=77,
|
||||
)
|
||||
|
||||
await _handle_model_command(
|
||||
cfg,
|
||||
msg,
|
||||
"",
|
||||
ambient_context=None,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
resolved_scope="main",
|
||||
scope_chat_ids=frozenset({123}),
|
||||
)
|
||||
|
||||
text = transport.send_calls[-1]["message"].text
|
||||
assert "engine: codex (global default)" in text
|
||||
assert "model: gpt-4.1 (topic override)" in text
|
||||
assert "defaults: topic: gpt-4.1, chat: gpt-4.1-mini" in text
|
||||
assert "available engines: codex" in text
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_model_command_set_and_clear_chat_override(tmp_path: Path) -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = _make_cfg(transport)
|
||||
chat_prefs = ChatPrefsStore(tmp_path / "telegram_chat_prefs_state.json")
|
||||
await chat_prefs.set_engine_override(
|
||||
123,
|
||||
CODEX_ENGINE,
|
||||
EngineOverrides(model=None, reasoning="low"),
|
||||
)
|
||||
msg = TelegramIncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=10,
|
||||
text="/model set gpt-4.1-mini",
|
||||
reply_to_message_id=None,
|
||||
reply_to_text=None,
|
||||
sender_id=456,
|
||||
chat_type="supergroup",
|
||||
)
|
||||
|
||||
await _handle_model_command(
|
||||
cfg,
|
||||
msg,
|
||||
"set gpt-4.1-mini",
|
||||
ambient_context=None,
|
||||
topic_store=None,
|
||||
chat_prefs=chat_prefs,
|
||||
)
|
||||
|
||||
override = await chat_prefs.get_engine_override(123, CODEX_ENGINE)
|
||||
assert override is not None
|
||||
assert override.model == "gpt-4.1-mini"
|
||||
assert override.reasoning == "low"
|
||||
assert (
|
||||
"chat model override set to gpt-4.1-mini for codex."
|
||||
in transport.send_calls[-1]["message"].text
|
||||
)
|
||||
|
||||
msg_clear = replace(
|
||||
msg,
|
||||
message_id=11,
|
||||
text="/model clear codex",
|
||||
)
|
||||
await _handle_model_command(
|
||||
cfg,
|
||||
msg_clear,
|
||||
"clear codex",
|
||||
ambient_context=None,
|
||||
topic_store=None,
|
||||
chat_prefs=chat_prefs,
|
||||
)
|
||||
|
||||
override = await chat_prefs.get_engine_override(123, CODEX_ENGINE)
|
||||
assert override is not None
|
||||
assert override.model is None
|
||||
assert override.reasoning == "low"
|
||||
assert "chat model override cleared." in transport.send_calls[-1]["message"].text
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_reasoning_command_set_and_clear_topic_override(tmp_path: Path) -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = _make_cfg(transport)
|
||||
cfg = replace(cfg, topics=TelegramTopicsSettings(enabled=True, scope="main"))
|
||||
topic_store = TopicStateStore(tmp_path / "telegram_topics_state.json")
|
||||
await topic_store.set_engine_override(
|
||||
123,
|
||||
77,
|
||||
CODEX_ENGINE,
|
||||
EngineOverrides(model="gpt-4.1-mini", reasoning=None),
|
||||
)
|
||||
msg = TelegramIncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=10,
|
||||
text="/reasoning set High",
|
||||
reply_to_message_id=None,
|
||||
reply_to_text=None,
|
||||
sender_id=456,
|
||||
chat_type="supergroup",
|
||||
thread_id=77,
|
||||
)
|
||||
|
||||
await _handle_reasoning_command(
|
||||
cfg,
|
||||
msg,
|
||||
"set High",
|
||||
ambient_context=None,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=None,
|
||||
resolved_scope="main",
|
||||
scope_chat_ids=frozenset({123}),
|
||||
)
|
||||
|
||||
override = await topic_store.get_engine_override(123, 77, CODEX_ENGINE)
|
||||
assert override is not None
|
||||
assert override.model == "gpt-4.1-mini"
|
||||
assert override.reasoning == "high"
|
||||
assert (
|
||||
"topic reasoning override set to high for codex."
|
||||
in transport.send_calls[-1]["message"].text
|
||||
)
|
||||
|
||||
msg_clear = replace(
|
||||
msg,
|
||||
message_id=11,
|
||||
text="/reasoning clear",
|
||||
)
|
||||
await _handle_reasoning_command(
|
||||
cfg,
|
||||
msg_clear,
|
||||
"clear",
|
||||
ambient_context=None,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=None,
|
||||
resolved_scope="main",
|
||||
scope_chat_ids=frozenset({123}),
|
||||
)
|
||||
|
||||
override = await topic_store.get_engine_override(123, 77, CODEX_ENGINE)
|
||||
assert override is not None
|
||||
assert override.model == "gpt-4.1-mini"
|
||||
assert override.reasoning is None
|
||||
assert (
|
||||
"topic reasoning override cleared (using chat default)."
|
||||
in transport.send_calls[-1]["message"].text
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_reasoning_command_show_reports_overrides(tmp_path: Path) -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = _make_cfg(transport)
|
||||
cfg = replace(cfg, topics=TelegramTopicsSettings(enabled=True, scope="main"))
|
||||
chat_prefs = ChatPrefsStore(tmp_path / "telegram_chat_prefs_state.json")
|
||||
topic_store = TopicStateStore(tmp_path / "telegram_topics_state.json")
|
||||
await chat_prefs.set_engine_override(
|
||||
123,
|
||||
CODEX_ENGINE,
|
||||
EngineOverrides(model=None, reasoning="low"),
|
||||
)
|
||||
await topic_store.set_engine_override(
|
||||
123,
|
||||
88,
|
||||
CODEX_ENGINE,
|
||||
EngineOverrides(model=None, reasoning="high"),
|
||||
)
|
||||
msg = TelegramIncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=10,
|
||||
text="/reasoning",
|
||||
reply_to_message_id=None,
|
||||
reply_to_text=None,
|
||||
sender_id=123,
|
||||
thread_id=88,
|
||||
)
|
||||
|
||||
await _handle_reasoning_command(
|
||||
cfg,
|
||||
msg,
|
||||
"",
|
||||
ambient_context=None,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
resolved_scope="main",
|
||||
scope_chat_ids=frozenset({123}),
|
||||
)
|
||||
|
||||
text = transport.send_calls[-1]["message"].text
|
||||
assert "engine: codex (global default)" in text
|
||||
assert "reasoning: high (topic override)" in text
|
||||
assert "defaults: topic: high, chat: low" in text
|
||||
assert "available levels: minimal, low, medium, high, xhigh" in text
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_send_with_resume_waits_for_token() -> None:
|
||||
transport = _FakeTransport()
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
import pytest
|
||||
|
||||
from takopi.telegram.chat_prefs import ChatPrefsStore
|
||||
from takopi.telegram.engine_overrides import (
|
||||
EngineOverrides,
|
||||
merge_overrides,
|
||||
resolve_override_value,
|
||||
)
|
||||
from takopi.telegram.topic_state import TopicStateStore
|
||||
|
||||
|
||||
def test_merge_overrides_prefers_topic_values() -> None:
|
||||
topic = EngineOverrides(model=None, reasoning="high")
|
||||
chat = EngineOverrides(model="gpt-4.1-mini", reasoning=None)
|
||||
merged = merge_overrides(topic, chat)
|
||||
|
||||
assert merged is not None
|
||||
assert merged.model == "gpt-4.1-mini"
|
||||
assert merged.reasoning == "high"
|
||||
|
||||
|
||||
def test_resolve_override_value_tracks_sources() -> None:
|
||||
topic = EngineOverrides(model="gpt-4.1", reasoning=None)
|
||||
chat = EngineOverrides(model="gpt-4.1-mini", reasoning="low")
|
||||
resolution = resolve_override_value(
|
||||
topic_override=topic,
|
||||
chat_override=chat,
|
||||
field="model",
|
||||
)
|
||||
|
||||
assert resolution.value == "gpt-4.1"
|
||||
assert resolution.source == "topic_override"
|
||||
assert resolution.topic_value == "gpt-4.1"
|
||||
assert resolution.chat_value == "gpt-4.1-mini"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_chat_prefs_engine_overrides_roundtrip(tmp_path) -> None:
|
||||
path = tmp_path / "telegram_chat_prefs_state.json"
|
||||
store = ChatPrefsStore(path)
|
||||
await store.set_engine_override(
|
||||
123,
|
||||
"codex",
|
||||
EngineOverrides(model="gpt-4.1-mini", reasoning="low"),
|
||||
)
|
||||
|
||||
override = await store.get_engine_override(123, "codex")
|
||||
assert override is not None
|
||||
assert override.model == "gpt-4.1-mini"
|
||||
assert override.reasoning == "low"
|
||||
|
||||
store2 = ChatPrefsStore(path)
|
||||
override2 = await store2.get_engine_override(123, "codex")
|
||||
assert override2 is not None
|
||||
assert override2.model == "gpt-4.1-mini"
|
||||
assert override2.reasoning == "low"
|
||||
|
||||
await store2.set_engine_override(
|
||||
123,
|
||||
"codex",
|
||||
EngineOverrides(model=None, reasoning="low"),
|
||||
)
|
||||
override3 = await store2.get_engine_override(123, "codex")
|
||||
assert override3 is not None
|
||||
assert override3.model is None
|
||||
assert override3.reasoning == "low"
|
||||
|
||||
await store2.set_engine_override(
|
||||
123,
|
||||
"codex",
|
||||
EngineOverrides(model=None, reasoning=None),
|
||||
)
|
||||
override4 = await store2.get_engine_override(123, "codex")
|
||||
assert override4 is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_topic_state_engine_overrides_roundtrip(tmp_path) -> None:
|
||||
path = tmp_path / "telegram_topics_state.json"
|
||||
store = TopicStateStore(path)
|
||||
await store.set_engine_override(
|
||||
1,
|
||||
10,
|
||||
"codex",
|
||||
EngineOverrides(model="gpt-4.1", reasoning="medium"),
|
||||
)
|
||||
|
||||
override = await store.get_engine_override(1, 10, "codex")
|
||||
assert override is not None
|
||||
assert override.model == "gpt-4.1"
|
||||
assert override.reasoning == "medium"
|
||||
|
||||
store2 = TopicStateStore(path)
|
||||
override2 = await store2.get_engine_override(1, 10, "codex")
|
||||
assert override2 is not None
|
||||
assert override2.model == "gpt-4.1"
|
||||
assert override2.reasoning == "medium"
|
||||
Reference in New Issue
Block a user