From abd0aa2bb4932087db1d4e368dd0f576d2f0f212 Mon Sep 17 00:00:00 2001 From: banteg <4562643+banteg@users.noreply.github.com> Date: Tue, 13 Jan 2026 01:34:08 +0400 Subject: [PATCH] refactor: cleanup, linting, and tooling updates (#108) --- .codex/AGENTS.md | 1 + pyproject.toml | 9 +- src/takopi/cli.py | 90 ++++++++------ src/takopi/commands.py | 35 ++---- src/takopi/config_watch.py | 9 +- src/takopi/engines.py | 32 ++--- src/takopi/logging.py | 17 +-- src/takopi/plugins.py | 35 +++++- src/takopi/runner.py | 4 +- src/takopi/runners/claude.py | 24 ++-- src/takopi/runners/codex.py | 5 +- src/takopi/runners/mock.py | 2 - src/takopi/runners/opencode.py | 8 +- src/takopi/runners/pi.py | 6 - src/takopi/runtime_loader.py | 2 +- src/takopi/settings.py | 13 +- src/takopi/telegram/chat_sessions.py | 75 +++--------- src/takopi/telegram/client.py | 16 +-- src/takopi/telegram/commands.py | 170 ++++++++------------------ src/takopi/telegram/state_store.py | 94 +++++++++++++++ src/takopi/telegram/topic_state.py | 73 +++--------- src/takopi/transport_runtime.py | 172 +++++++++++++++++---------- src/takopi/transports.py | 33 ++--- src/takopi/utils/streams.py | 2 +- src/takopi/utils/subprocess.py | 2 +- tests/test_claude_schema.py | 2 +- tests/test_codex_schema.py | 4 +- tests/test_opencode_schema.py | 2 +- tests/test_pi_schema.py | 2 +- tests/test_projects_config.py | 2 +- tests/test_settings.py | 2 +- 31 files changed, 457 insertions(+), 486 deletions(-) create mode 100644 src/takopi/telegram/state_store.py diff --git a/.codex/AGENTS.md b/.codex/AGENTS.md index 10ea866..f06228b 100644 --- a/.codex/AGENTS.md +++ b/.codex/AGENTS.md @@ -1,3 +1,4 @@ after you finish work, commit with a conventional message. only commit the files you edited. always run `just check` before code commits. if you fix anything from `just check`, rerun it and confirm it passes before committing. +when using gh to edit or create PR descriptions, prefer `--body-file` to preserve newlines. diff --git a/pyproject.toml b/pyproject.toml index 450993b..09cd049 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,5 +61,12 @@ dev = [ ] [tool.pytest.ini_options] -addopts = ["--cov=takopi", "--cov-report=term-missing", "--cov-fail-under=70"] +addopts = ["--cov=takopi", "--cov-report=term-missing", "--cov-fail-under=75"] testpaths = ["tests"] + +[tool.ruff.lint] +extend-select = ["B904", "BLE001", "S110", "RUF043"] + +[tool.ty.src] +include = ["src", "tests"] +exclude = ["scripts"] diff --git a/src/takopi/cli.py b/src/takopi/cli.py index 02f180e..38e5393 100644 --- a/src/takopi/cli.py +++ b/src/takopi/cli.py @@ -12,6 +12,7 @@ from . import __version__ from .config import ConfigError, load_or_init_config, write_config from .config_migrations import migrate_config from .commands import get_command +from .backends import EngineBackend from .engines import get_backend, list_backend_ids from .ids import RESERVED_COMMAND_IDS, RESERVED_ENGINE_IDS from .lockfile import LockError, LockHandle, acquire_lock, token_fingerprint @@ -108,6 +109,26 @@ def _default_engine_for_setup( return value +def _resolve_setup_engine( + default_engine_override: str | None, +) -> tuple[ + TakopiSettings | None, + Path | None, + list[str] | None, + str, + EngineBackend, +]: + settings_hint, config_hint = _load_settings_optional() + allowlist = resolve_plugins_allowlist(settings_hint) + default_engine = _default_engine_for_setup( + default_engine_override, + settings=settings_hint, + config_path=config_hint, + ) + engine_backend = get_backend(default_engine, allowlist=allowlist) + return settings_hint, config_hint, allowlist, default_engine, engine_backend + + def _config_path_display(path: Path) -> str: home = Path.home() try: @@ -148,33 +169,31 @@ def _run_auto_router( setup_logging(debug=debug) lock_handle: LockHandle | None = None try: - settings_hint, config_hint = _load_settings_optional() - allowlist = resolve_plugins_allowlist(settings_hint) - default_engine = _default_engine_for_setup( - default_engine_override, - settings=settings_hint, - config_path=config_hint, - ) - engine_backend = get_backend(default_engine, allowlist=allowlist) + ( + settings_hint, + config_hint, + allowlist, + default_engine, + engine_backend, + ) = _resolve_setup_engine(default_engine_override) transport_id = _resolve_transport_id(transport_override) transport_backend = get_transport(transport_id, allowlist=allowlist) except ConfigError as e: typer.echo(f"error: {e}", err=True) - raise typer.Exit(code=1) + raise typer.Exit(code=1) from e if onboard: if not _should_run_interactive(): typer.echo("error: --onboard requires a TTY", err=True) raise typer.Exit(code=1) if not transport_backend.interactive_setup(force=True): raise typer.Exit(code=1) - settings_hint, config_hint = _load_settings_optional() - allowlist = resolve_plugins_allowlist(settings_hint) - default_engine = _default_engine_for_setup( - default_engine_override, - settings=settings_hint, - config_path=config_hint, - ) - engine_backend = get_backend(default_engine, allowlist=allowlist) + ( + settings_hint, + config_hint, + allowlist, + default_engine, + engine_backend, + ) = _resolve_setup_engine(default_engine_override) setup = transport_backend.check_setup( engine_backend, transport_override=transport_override, @@ -189,27 +208,25 @@ def _run_auto_router( default=False, ) if run_onboard and transport_backend.interactive_setup(force=True): - settings_hint, config_hint = _load_settings_optional() - allowlist = resolve_plugins_allowlist(settings_hint) - default_engine = _default_engine_for_setup( - default_engine_override, - settings=settings_hint, - config_path=config_hint, - ) - engine_backend = get_backend(default_engine, allowlist=allowlist) + ( + settings_hint, + config_hint, + allowlist, + default_engine, + engine_backend, + ) = _resolve_setup_engine(default_engine_override) setup = transport_backend.check_setup( engine_backend, transport_override=transport_override, ) elif transport_backend.interactive_setup(force=False): - settings_hint, config_hint = _load_settings_optional() - allowlist = resolve_plugins_allowlist(settings_hint) - default_engine = _default_engine_for_setup( - default_engine_override, - settings=settings_hint, - config_path=config_hint, - ) - engine_backend = get_backend(default_engine, allowlist=allowlist) + ( + settings_hint, + config_hint, + allowlist, + default_engine, + engine_backend, + ) = _resolve_setup_engine(default_engine_override) setup = transport_backend.check_setup( engine_backend, transport_override=transport_override, @@ -252,10 +269,10 @@ def _run_auto_router( ) except ConfigError as e: typer.echo(f"error: {e}", err=True) - raise typer.Exit(code=1) + raise typer.Exit(code=1) from e except KeyboardInterrupt: logger.info("shutdown.interrupted") - raise typer.Exit(code=130) + raise typer.Exit(code=130) from None finally: if lock_handle is not None: lock_handle.release() @@ -279,8 +296,7 @@ def _default_alias_from_path(path: Path) -> str | None: name = path.name if not name: return None - if name.endswith(".git"): - name = name[: -len(".git")] + name = name.removesuffix(".git") return name or None diff --git a/src/takopi/commands.py b/src/takopi/commands.py index 6b8d26c..c911ec6 100644 --- a/src/takopi/commands.py +++ b/src/takopi/commands.py @@ -9,13 +9,7 @@ from .config import ConfigError from .context import RunContext from .ids import RESERVED_COMMAND_IDS from .model import EngineId -from .plugins import ( - COMMAND_GROUP, - PluginLoadFailed, - PluginNotFound, - load_entrypoint, - list_ids, -) +from .plugins import COMMAND_GROUP, list_ids, load_plugin_backend from .transport import MessageRef, RenderedMessage from .transport_runtime import TransportRuntime @@ -122,25 +116,14 @@ def get_command( ) -> CommandBackend | None: if command_id.lower() in RESERVED_COMMAND_IDS: raise ConfigError(f"Command id {command_id!r} is reserved.") - try: - backend = load_entrypoint( - COMMAND_GROUP, - command_id, - allowlist=allowlist, - validator=_validate_command_backend, - ) - except PluginNotFound as exc: - if not required: - return None - if exc.available: - available = ", ".join(exc.available) - message = f"Unknown command {command_id!r}. Available: {available}." - else: - message = f"Unknown command {command_id!r}." - raise ConfigError(message) from exc - except PluginLoadFailed as exc: - raise ConfigError(f"Failed to load command {command_id!r}: {exc}") from exc - return backend + return load_plugin_backend( + COMMAND_GROUP, + command_id, + allowlist=allowlist, + validator=_validate_command_backend, + kind_label="command", + required=required, + ) def list_command_ids(*, allowlist: Iterable[str] | None = None) -> list[str]: diff --git a/src/takopi/config_watch.py b/src/takopi/config_watch.py index 5fe3c92..88acbec 100644 --- a/src/takopi/config_watch.py +++ b/src/takopi/config_watch.py @@ -40,6 +40,13 @@ def config_status(path: Path) -> tuple[str, tuple[int, int] | None]: return "ok", (stat.st_mtime_ns, stat.st_size) +def _matches_config_path(candidate: str, config_path: Path) -> bool: + try: + return Path(candidate).resolve(strict=False) == config_path + except OSError: + return False + + def _reload_config( config_path: Path, default_engine_override: str | None, @@ -76,7 +83,7 @@ async def watch_config( logger.warning("config.watch.unavailable", path=str(config_path), status=status) async for changes in awatch(watch_root): - if not any(Path(path) == config_path for _, path in changes): + if not any(_matches_config_path(path, config_path) for _, path in changes): continue status, current = config_status(config_path) diff --git a/src/takopi/engines.py b/src/takopi/engines.py index 0065d40..244fa45 100644 --- a/src/takopi/engines.py +++ b/src/takopi/engines.py @@ -4,13 +4,7 @@ from typing import Iterable from .backends import EngineBackend from .config import ConfigError -from .plugins import ( - ENGINE_GROUP, - PluginLoadFailed, - PluginNotFound, - load_entrypoint, - list_ids, -) +from .plugins import ENGINE_GROUP, list_ids, load_plugin_backend from .ids import RESERVED_ENGINE_IDS @@ -28,22 +22,14 @@ def get_backend( ) -> EngineBackend: if engine_id.lower() in RESERVED_ENGINE_IDS: raise ConfigError(f"Engine id {engine_id!r} is reserved.") - try: - backend = load_entrypoint( - ENGINE_GROUP, - engine_id, - allowlist=allowlist, - validator=_validate_engine_backend, - ) - except PluginNotFound as exc: - if exc.available: - available = ", ".join(exc.available) - message = f"Unknown engine {engine_id!r}. Available: {available}." - else: - message = f"Unknown engine {engine_id!r}." - raise ConfigError(message) from exc - except PluginLoadFailed as exc: - raise ConfigError(f"Failed to load engine {engine_id!r}: {exc}") from exc + backend = load_plugin_backend( + ENGINE_GROUP, + engine_id, + allowlist=allowlist, + validator=_validate_engine_backend, + kind_label="engine", + ) + assert backend is not None return backend diff --git a/src/takopi/logging.py b/src/takopi/logging.py index 8d6a4bc..3775c70 100644 --- a/src/takopi/logging.py +++ b/src/takopi/logging.py @@ -126,8 +126,8 @@ def _file_sink( payload = payload.decode("utf-8", errors="replace") _log_file_handle.write(payload + "\n") _log_file_handle.flush() - except Exception: - pass + except Exception: # noqa: BLE001 + return event_dict return event_dict @@ -202,8 +202,8 @@ class SafeWriter(io.TextIOBase): self._closed = True try: self._stream.close() - except Exception: - pass + except Exception: # noqa: BLE001 + return def setup_logging( @@ -236,13 +236,14 @@ def setup_logging( if _log_file_handle is not None: try: _log_file_handle.close() - except Exception: - pass - _log_file_handle = None + except Exception: # noqa: BLE001 + _log_file_handle = None + else: + _log_file_handle = None if log_file: try: _log_file_handle = open(log_file, "a", encoding="utf-8") - except Exception: + except OSError: _log_file_handle = None processors = cast( diff --git a/src/takopi/plugins.py b/src/takopi/plugins.py index c130ed0..813851b 100644 --- a/src/takopi/plugins.py +++ b/src/takopi/plugins.py @@ -100,7 +100,7 @@ def entrypoint_distribution_name(ep: EntryPoint) -> str | None: return None try: return metadata["Name"] - except Exception: + except (KeyError, TypeError): return None @@ -281,3 +281,36 @@ def load_entrypoint( _LOADED[key] = loaded clear_load_errors(group=group, name=name) return loaded + + +def load_plugin_backend( + group: str, + name: str, + *, + allowlist: Iterable[str] | None = None, + validator: Callable[[Any, EntryPoint], None] | None = None, + kind_label: str, + required: bool = True, +) -> Any | None: + try: + return load_entrypoint( + group, + name, + allowlist=allowlist, + validator=validator, + ) + except PluginNotFound as exc: + if not required: + return None + if exc.available: + available = ", ".join(exc.available) + message = f"Unknown {kind_label} {name!r}. Available: {available}." + else: + message = f"Unknown {kind_label} {name!r}." + from .config import ConfigError + + raise ConfigError(message) from exc + except PluginLoadFailed as exc: + from .config import ConfigError + + raise ConfigError(f"Failed to load {kind_label} {name!r}: {exc}") from exc diff --git a/src/takopi/runner.py b/src/takopi/runner.py index ef4f15a..a65649e 100644 --- a/src/takopi/runner.py +++ b/src/takopi/runner.py @@ -427,7 +427,7 @@ class JsonlSubprocessRunner(BaseRunner): line_text = line.decode("utf-8", errors="replace") try: decoded = self.decode_jsonl(line=line) - except Exception as exc: + except Exception as exc: # noqa: BLE001 log_pipeline( logger, "jsonl.parse.error", @@ -470,7 +470,7 @@ class JsonlSubprocessRunner(BaseRunner): resume=resume, found_session=found_session, ) - except Exception as exc: + except Exception as exc: # noqa: BLE001 log_pipeline( logger, "runner.translate.error", diff --git a/src/takopi/runners/claude.py b/src/takopi/runners/claude.py index 475b353..c3f3e93 100644 --- a/src/takopi/runners/claude.py +++ b/src/takopi/runners/claude.py @@ -155,15 +155,12 @@ def _tool_result_event( normalized = _normalize_tool_result(raw_result) preview = normalized - detail = dict(action.detail) - detail.update( - { - "tool_use_id": content.tool_use_id, - "result_preview": preview, - "result_len": len(normalized), - "is_error": is_error, - } - ) + detail = action.detail | { + "tool_use_id": content.tool_use_id, + "result_preview": preview, + "result_len": len(normalized), + "is_error": is_error, + } return factory.action_completed( action_id=action.id, kind=action.kind, @@ -367,7 +364,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner): *, state: Any, ) -> list[str]: - _ = state return self._build_args(prompt, resume) def stdin_payload( @@ -377,11 +373,9 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner): *, state: Any, ) -> bytes | None: - _ = prompt, resume, state return None def env(self, *, state: Any) -> dict[str, str] | None: - _ = state if self.use_api_billing is not True: env = dict(os.environ) env.pop("ANTHROPIC_API_KEY", None) @@ -389,7 +383,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner): return None def new_state(self, prompt: str, resume: ResumeToken | None) -> ClaudeStreamState: - _ = prompt, resume return ClaudeStreamState() def start_run( @@ -399,7 +392,7 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner): *, state: ClaudeStreamState, ) -> None: - _ = state, prompt, resume + pass def decode_jsonl( self, @@ -416,7 +409,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner): error: Exception, state: ClaudeStreamState, ) -> list[TakopiEvent]: - _ = raw, line, state if isinstance(error, msgspec.DecodeError): self.get_logger().warning( "jsonl.msgspec.invalid", @@ -439,7 +431,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner): line: str, state: ClaudeStreamState, ) -> list[TakopiEvent]: - _ = raw, line, state return [] def translate( @@ -450,7 +441,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner): resume: ResumeToken | None, found_session: ResumeToken | None, ) -> list[TakopiEvent]: - _ = resume, found_session return translate_claude_event( data, title=self.session_title, diff --git a/src/takopi/runners/codex.py b/src/takopi/runners/codex.py index dae835d..2c81b6a 100644 --- a/src/takopi/runners/codex.py +++ b/src/takopi/runners/codex.py @@ -426,7 +426,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner): *, state: Any, ) -> list[str]: - _ = prompt, state args = [ *self.extra_args, "exec", @@ -441,7 +440,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner): return args def new_state(self, prompt: str, resume: ResumeToken | None) -> CodexRunState: - _ = prompt, resume return CodexRunState(factory=EventFactory(ENGINE)) def start_run( @@ -451,7 +449,7 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner): *, state: CodexRunState, ) -> None: - _ = state, prompt, resume + pass def decode_jsonl(self, *, line: bytes) -> codex_schema.ThreadEvent: return codex_schema.decode_event(line) @@ -464,7 +462,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner): error: Exception, state: CodexRunState, ) -> list[TakopiEvent]: - _ = raw, line if isinstance(error, msgspec.DecodeError): self.get_logger().warning( "jsonl.msgspec.invalid", diff --git a/src/takopi/runners/mock.py b/src/takopi/runners/mock.py index b4dcbfd..c9e0cb4 100644 --- a/src/takopi/runners/mock.py +++ b/src/takopi/runners/mock.py @@ -84,7 +84,6 @@ class MockRunner(SessionLockMixin, ResumeTokenMixin, Runner): async def run( self, prompt: str, resume: ResumeToken | None ) -> AsyncIterator[TakopiEvent]: - _ = prompt token_value = None if resume is not None: if resume.engine != self.engine: @@ -158,7 +157,6 @@ class ScriptRunner(MockRunner): self, prompt: str, resume: ResumeToken | None ) -> AsyncIterator[TakopiEvent]: self.calls.append((prompt, resume)) - _ = prompt token_value = None if resume is not None: if resume.engine != self.engine: diff --git a/src/takopi/runners/opencode.py b/src/takopi/runners/opencode.py index d0b7d44..a28bb17 100644 --- a/src/takopi/runners/opencode.py +++ b/src/takopi/runners/opencode.py @@ -367,7 +367,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner): *, state: Any, ) -> list[str]: - _ = state args = ["run", "--format", "json"] if resume is not None: args.extend(["--session", resume.value]) @@ -383,11 +382,9 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner): *, state: Any, ) -> bytes | None: - _ = prompt, resume, state return None def new_state(self, prompt: str, resume: ResumeToken | None) -> OpenCodeStreamState: - _ = prompt, resume return OpenCodeStreamState() def start_run( @@ -397,7 +394,7 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner): *, state: OpenCodeStreamState, ) -> None: - _ = state, prompt, resume + pass def invalid_json_events( self, @@ -406,7 +403,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner): line: str, state: OpenCodeStreamState, ) -> list[TakopiEvent]: - _ = line message = "invalid JSON from opencode; ignoring line" return [self.note_event(message, state=state, detail={"line": raw})] @@ -418,7 +414,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner): resume: ResumeToken | None, found_session: ResumeToken | None, ) -> list[TakopiEvent]: - _ = resume, found_session return translate_opencode_event( data, title=self.session_title, @@ -436,7 +431,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner): error: Exception, state: OpenCodeStreamState, ) -> list[TakopiEvent]: - _ = raw, line, state if isinstance(error, msgspec.DecodeError): self.get_logger().warning( "jsonl.msgspec.invalid", diff --git a/src/takopi/runners/pi.py b/src/takopi/runners/pi.py index eeeee1e..ef64fb0 100644 --- a/src/takopi/runners/pi.py +++ b/src/takopi/runners/pi.py @@ -290,7 +290,6 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner): *, state: PiStreamState, ) -> list[str]: - _ = resume args: list[str] = [*self.extra_args, "--print", "--mode", "json"] if self.provider: args.extend(["--provider", self.provider]) @@ -307,18 +306,15 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner): *, state: PiStreamState, ) -> bytes | None: - _ = prompt, resume, state return None def env(self, *, state: PiStreamState) -> dict[str, str] | None: - _ = state env = dict(os.environ) env.setdefault("NO_COLOR", "1") env.setdefault("CI", "1") return env def new_state(self, prompt: str, resume: ResumeToken | None) -> PiStreamState: - _ = prompt if resume is None: session_path = self._new_session_path() token = ResumeToken(engine=ENGINE, value=session_path) @@ -334,7 +330,6 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner): resume: ResumeToken | None, found_session: ResumeToken | None, ) -> list[TakopiEvent]: - _ = resume, found_session meta: dict[str, Any] = {"cwd": os.getcwd()} if self.model: meta["model"] = self.model @@ -362,7 +357,6 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner): error: Exception, state: PiStreamState, ) -> list[TakopiEvent]: - _ = raw, line, state if isinstance(error, msgspec.DecodeError): self.get_logger().warning( "jsonl.msgspec.invalid", diff --git a/src/takopi/runtime_loader.py b/src/takopi/runtime_loader.py index afecb65..bf6cae6 100644 --- a/src/takopi/runtime_loader.py +++ b/src/takopi/runtime_loader.py @@ -103,7 +103,7 @@ def build_router( if engine_cfg: try: runner = backend.build_runner({}, config_path) - except Exception as fallback_exc: + except Exception as fallback_exc: # noqa: BLE001 warnings.append(f"{engine_id}: {issue or str(fallback_exc)}") continue status = "bad_config" diff --git a/src/takopi/settings.py b/src/takopi/settings.py index 4d8cb4a..e24f837 100644 --- a/src/takopi/settings.py +++ b/src/takopi/settings.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Annotated, Any, Iterable, Literal +from typing import Annotated, Any, ClassVar, Iterable, Literal from pydantic import ( BaseModel, @@ -62,6 +62,9 @@ class TelegramTopicsSettings(BaseModel): class TelegramFilesSettings(BaseModel): model_config = ConfigDict(extra="forbid", str_strip_whitespace=True) + max_upload_bytes: ClassVar[int] = 20 * 1024 * 1024 + max_download_bytes: ClassVar[int] = 50 * 1024 * 1024 + enabled: bool = False auto_put: bool = True auto_put_mode: Literal["upload", "prompt"] = "upload" @@ -84,14 +87,6 @@ class TelegramFilesSettings(BaseModel): raise ValueError("files.uploads_dir must be a relative path") return value - @property - def max_upload_bytes(self) -> int: - return 20 * 1024 * 1024 - - @property - def max_download_bytes(self) -> int: - return 50 * 1024 * 1024 - class TelegramTransportSettings(BaseModel): model_config = ConfigDict(extra="forbid", str_strip_whitespace=True) diff --git a/src/takopi/telegram/chat_sessions.py b/src/takopi/telegram/chat_sessions.py index 1e61987..6acc1f2 100644 --- a/src/takopi/telegram/chat_sessions.py +++ b/src/takopi/telegram/chat_sessions.py @@ -1,14 +1,12 @@ from __future__ import annotations -import json -import os from pathlib import Path -import anyio import msgspec from ..logging import get_logger from ..model import ResumeToken +from .state_store import JsonStateStore logger = get_logger(__name__) @@ -38,13 +36,20 @@ def _chat_key(chat_id: int, owner_id: int | None) -> str: return f"{chat_id}:{owner}" -class ChatSessionStore: +def _new_state() -> _ChatSessionsState: + return _ChatSessionsState(version=STATE_VERSION, chats={}) + + +class ChatSessionStore(JsonStateStore[_ChatSessionsState]): def __init__(self, path: Path) -> None: - self._path = path - self._lock = anyio.Lock() - self._loaded = False - self._mtime_ns: int | None = None - self._state = _ChatSessionsState(version=STATE_VERSION, chats={}) + super().__init__( + path, + version=STATE_VERSION, + state_type=_ChatSessionsState, + state_factory=_new_state, + log_prefix="telegram.chat_sessions", + logger=logger, + ) async def get_session_resume( self, chat_id: int, owner_id: int | None, engine: str @@ -77,58 +82,6 @@ class ChatSessionStore: chat.sessions = {} self._save_locked() - def _stat_mtime_ns(self) -> int | None: - try: - return self._path.stat().st_mtime_ns - except FileNotFoundError: - return None - - def _reload_locked_if_needed(self) -> None: - current = self._stat_mtime_ns() - if self._loaded and current == self._mtime_ns: - return - self._load_locked() - - def _load_locked(self) -> None: - self._loaded = True - self._mtime_ns = self._stat_mtime_ns() - if self._mtime_ns is None: - self._state = _ChatSessionsState(version=STATE_VERSION, chats={}) - return - try: - payload = msgspec.json.decode( - self._path.read_bytes(), type=_ChatSessionsState - ) - except Exception as exc: - logger.warning( - "telegram.chat_sessions.load_failed", - path=str(self._path), - error=str(exc), - error_type=exc.__class__.__name__, - ) - self._state = _ChatSessionsState(version=STATE_VERSION, chats={}) - return - if payload.version != STATE_VERSION: - logger.warning( - "telegram.chat_sessions.version_mismatch", - path=str(self._path), - version=payload.version, - expected=STATE_VERSION, - ) - self._state = _ChatSessionsState(version=STATE_VERSION, chats={}) - return - self._state = payload - - def _save_locked(self) -> None: - self._path.parent.mkdir(parents=True, exist_ok=True) - payload = msgspec.to_builtins(self._state) - tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp") - with open(tmp_path, "w", encoding="utf-8") as handle: - json.dump(payload, handle, indent=2, sort_keys=True) - handle.write("\n") - os.replace(tmp_path, self._path) - self._mtime_ns = self._stat_mtime_ns() - def _get_chat_locked(self, chat_id: int, owner_id: int | None) -> _ChatState | None: return self._state.chats.get(_chat_key(chat_id, owner_id)) diff --git a/src/takopi/telegram/client.py b/src/takopi/telegram/client.py index 46dca64..f3b2f4a 100644 --- a/src/takopi/telegram/client.py +++ b/src/takopi/telegram/client.py @@ -515,7 +515,7 @@ class TelegramOutbox: async def execute_op(self, op: OutboxOp) -> Any: try: return await op.execute() - except Exception as exc: + except Exception as exc: # noqa: BLE001 if isinstance(exc, RetryAfter): raise if self._on_error is not None: @@ -566,7 +566,7 @@ class TelegramOutbox: op.set_result(result) except cancel_exc: return - except Exception as exc: + except Exception as exc: # noqa: BLE001 async with self._cond: self._closed = True self.fail_pending() @@ -759,7 +759,7 @@ class TelegramClient: retry_after: float | None = None try: response_payload = resp.json() - except Exception: + except Exception: # noqa: BLE001 response_payload = None if isinstance(response_payload, dict): retry_after = retry_after_from_payload(response_payload) @@ -785,7 +785,7 @@ class TelegramClient: try: response_payload = resp.json() - except Exception as exc: + except Exception as exc: # noqa: BLE001 body = resp.text logger.error( "telegram.bad_response", @@ -815,7 +815,7 @@ class TelegramClient: return None try: return msgspec.convert(payload, type=model) - except Exception as exc: + except Exception as exc: # noqa: BLE001 logger.error( "telegram.decode_error", method=method, @@ -862,7 +862,7 @@ class TelegramClient: return None try: return msgspec.convert(raw, type=list[Update]) - except Exception as exc: + except Exception as exc: # noqa: BLE001 logger.error( "telegram.decode_error", method="getUpdates", @@ -881,7 +881,7 @@ class TelegramClient: return None try: return msgspec.convert(result, type=list[Update]) - except Exception as exc: + except Exception as exc: # noqa: BLE001 logger.error( "telegram.decode_error", method="getUpdates", @@ -928,7 +928,7 @@ class TelegramClient: retry_after: float | None = None try: response_payload = resp.json() - except Exception: + except Exception: # noqa: BLE001 response_payload = None if isinstance(response_payload, dict): retry_after = retry_after_from_payload(response_payload) diff --git a/src/takopi/telegram/commands.py b/src/takopi/telegram/commands.py index 4d5c271..434596f 100644 --- a/src/takopi/telegram/commands.py +++ b/src/takopi/telegram/commands.py @@ -37,7 +37,7 @@ from ..scheduler import ThreadScheduler from ..transport import MessageRef, RenderedMessage, SendOptions from ..transport_runtime import ResolvedMessage, TransportRuntime from ..utils.paths import reset_run_base_dir, set_run_base_dir -from .bridge import send_plain +from .bridge import TelegramBridgeConfig, send_plain from .chat_sessions import ChatSessionStore from .context import ( _format_context, @@ -203,13 +203,25 @@ def _reserved_commands(runtime: TransportRuntime) -> set[str]: } -async def _set_command_menu(cfg) -> None: +def _reply_sender( + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage +) -> Callable[..., Awaitable[None]]: + return partial( + send_plain, + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + thread_id=msg.thread_id, + ) + + +async def _set_command_menu(cfg: TelegramBridgeConfig) -> None: commands = build_bot_commands(cfg.runtime, include_file=cfg.files.enabled) if not commands: return try: ok = await cfg.bot.set_my_commands(commands) - except Exception as exc: + except Exception as exc: # noqa: BLE001 logger.info( "startup.command_menu.failed", error=str(exc), @@ -297,7 +309,7 @@ def _should_show_resume_line( def resolve_file_put_paths( plan: _FilePutPlan, *, - cfg, + cfg: TelegramBridgeConfig, require_dir: bool, ) -> tuple[Path | None, Path | None, str | None]: path_value = plan.path_value @@ -322,14 +334,10 @@ def resolve_file_put_paths( return None, rel_path, None -async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) +async def _check_file_permissions( + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage +) -> bool: + reply = _reply_sender(cfg, msg) sender_id = msg.sender_id if sender_id is None: await reply(text="cannot verify sender for file transfer.") @@ -355,19 +363,13 @@ async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool: async def _prepare_file_put_plan( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, args_text: str, ambient_context: RunContext | None, topic_store: TopicStateStore | None, ) -> _FilePutPlan | None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) if not await _check_file_permissions(cfg, msg): return None try: @@ -423,7 +425,7 @@ def _format_file_put_failures(failed: Sequence[_FilePutResult]) -> str | None: async def _save_document_payload( - cfg, + cfg: TelegramBridgeConfig, *, document: TelegramDocument, run_root: Path, @@ -524,19 +526,13 @@ async def _save_document_payload( async def _handle_file_command( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, args_text: str, ambient_context: RunContext | None, topic_store: TopicStateStore | None, ) -> None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) command, rest, error = parse_file_command(args_text) if error is not None: await reply(text=error) @@ -548,7 +544,7 @@ async def _handle_file_command( async def _handle_file_put_default( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, ambient_context: RunContext | None, topic_store: TopicStateStore | None, @@ -557,19 +553,13 @@ async def _handle_file_put_default( async def _save_file_put( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, args_text: str, ambient_context: RunContext | None, topic_store: TopicStateStore | None, ) -> _SavedFilePut | None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) document = msg.document if document is None: await reply(text=FILE_PUT_USAGE) @@ -613,19 +603,13 @@ async def _save_file_put( async def _handle_file_put( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, args_text: str, ambient_context: RunContext | None, topic_store: TopicStateStore | None, ) -> None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) saved = await _save_file_put( cfg, msg, @@ -645,20 +629,14 @@ async def _handle_file_put( async def _handle_file_put_group( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, args_text: str, messages: Sequence[TelegramIncomingMessage], ambient_context: RunContext | None, topic_store: TopicStateStore | None, ) -> None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) saved_group = await _save_file_put_group( cfg, msg, @@ -700,20 +678,14 @@ async def _handle_file_put_group( async def _save_file_put_group( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, args_text: str, messages: Sequence[TelegramIncomingMessage], ambient_context: RunContext | None, topic_store: TopicStateStore | None, ) -> _SavedFilePutGroup | None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) documents = [item.document for item in messages if item.document is not None] if not documents: await reply(text=FILE_PUT_USAGE) @@ -759,7 +731,7 @@ async def _save_file_put_group( async def _handle_media_group( - cfg, + cfg: TelegramBridgeConfig, messages: Sequence[TelegramIncomingMessage], topic_store: TopicStateStore | None, run_prompt: Callable[ @@ -779,13 +751,7 @@ async def _handle_media_group( (item for item in ordered if item.text.strip()), ordered[0], ) - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=command_msg.chat_id, - user_msg_id=command_msg.message_id, - thread_id=command_msg.thread_id, - ) + reply = _reply_sender(cfg, command_msg) topic_key = _topic_key(command_msg, cfg) if topic_store is not None else None chat_project = _topics_chat_project(cfg, command_msg.chat_id) bound_context = ( @@ -884,19 +850,13 @@ async def _handle_media_group( async def _handle_file_get( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, args_text: str, ambient_context: RunContext | None, topic_store: TopicStateStore | None, ) -> None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) if not await _check_file_permissions(cfg, msg): return try: @@ -989,7 +949,7 @@ async def _handle_file_get( async def _handle_ctx_command( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, args_text: str, store: TopicStateStore, @@ -997,13 +957,7 @@ async def _handle_ctx_command( resolved_scope: str | None = None, scope_chat_ids: frozenset[int] | None = None, ) -> None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) error = _topics_command_error( cfg, msg.chat_id, @@ -1081,20 +1035,14 @@ async def _handle_ctx_command( async def _handle_new_command( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, store: TopicStateStore, *, resolved_scope: str | None = None, scope_chat_ids: frozenset[int] | None = None, ) -> None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) error = _topics_command_error( cfg, msg.chat_id, @@ -1113,18 +1061,12 @@ async def _handle_new_command( async def _handle_chat_new_command( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, store: ChatSessionStore, session_key: tuple[int, int | None] | None, ) -> None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) if session_key is None: await reply(text="no stored sessions to clear for this chat.") return @@ -1137,7 +1079,7 @@ async def _handle_chat_new_command( async def _handle_topic_command( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, args_text: str, store: TopicStateStore, @@ -1145,13 +1087,7 @@ async def _handle_topic_command( resolved_scope: str | None = None, scope_chat_ids: frozenset[int] | None = None, ) -> None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) error = _topics_command_error( cfg, msg.chat_id, @@ -1203,17 +1139,11 @@ async def _handle_topic_command( async def handle_cancel( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, running_tasks: RunningTasks, ) -> None: - reply = partial( - send_plain, - cfg.exec_cfg.transport, - chat_id=msg.chat_id, - user_msg_id=msg.message_id, - thread_id=msg.thread_id, - ) + reply = _reply_sender(cfg, msg) chat_id = msg.chat_id reply_id = msg.reply_to_message_id @@ -1239,7 +1169,7 @@ async def handle_cancel( async def handle_callback_cancel( - cfg, + cfg: TelegramBridgeConfig, query: TelegramCallbackQuery, running_tasks: RunningTasks, ) -> None: @@ -1572,7 +1502,7 @@ class _TelegramCommandExecutor(CommandExecutor): async def _dispatch_command( - cfg, + cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, text: str, command_id: str, diff --git a/src/takopi/telegram/state_store.py b/src/takopi/telegram/state_store.py new file mode 100644 index 0000000..35582dd --- /dev/null +++ b/src/takopi/telegram/state_store.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any, Callable, Generic, Protocol, TypeVar + +import anyio +import msgspec + +T = TypeVar("T", bound="_VersionedState") + + +class _Logger(Protocol): + def warning(self, event: str, **fields: Any) -> None: ... + + +class _VersionedState(Protocol): + version: int + + +class JsonStateStore(Generic[T]): + def __init__( + self, + path: Path, + *, + version: int, + state_type: type[T], + state_factory: Callable[[], T], + log_prefix: str, + logger: _Logger, + ) -> None: + self._path = path + self._lock = anyio.Lock() + self._loaded = False + self._mtime_ns: int | None = None + self._state_type = state_type + self._state_factory = state_factory + self._version = version + self._log_prefix = log_prefix + self._logger = logger + self._state = state_factory() + + def _stat_mtime_ns(self) -> int | None: + try: + return self._path.stat().st_mtime_ns + except FileNotFoundError: + return None + + def _reload_locked_if_needed(self) -> None: + current = self._stat_mtime_ns() + if self._loaded and current == self._mtime_ns: + return + self._load_locked() + + def _load_locked(self) -> None: + self._loaded = True + self._mtime_ns = self._stat_mtime_ns() + if self._mtime_ns is None: + self._state = self._state_factory() + return + try: + payload = msgspec.json.decode( + self._path.read_bytes(), type=self._state_type + ) + except Exception as exc: # noqa: BLE001 + self._logger.warning( + f"{self._log_prefix}.load_failed", + path=str(self._path), + error=str(exc), + error_type=exc.__class__.__name__, + ) + self._state = self._state_factory() + return + if payload.version != self._version: + self._logger.warning( + f"{self._log_prefix}.version_mismatch", + path=str(self._path), + version=payload.version, + expected=self._version, + ) + self._state = self._state_factory() + return + self._state = payload + + def _save_locked(self) -> None: + self._path.parent.mkdir(parents=True, exist_ok=True) + payload = msgspec.to_builtins(self._state) + tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp") + with open(tmp_path, "w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + os.replace(tmp_path, self._path) + self._mtime_ns = self._stat_mtime_ns() diff --git a/src/takopi/telegram/topic_state.py b/src/takopi/telegram/topic_state.py index 074ac34..8cebd58 100644 --- a/src/takopi/telegram/topic_state.py +++ b/src/takopi/telegram/topic_state.py @@ -1,16 +1,14 @@ from __future__ import annotations -import json -import os from dataclasses import dataclass from pathlib import Path -import anyio import msgspec from ..context import RunContext from ..logging import get_logger from ..model import ResumeToken +from .state_store import JsonStateStore logger = get_logger(__name__) @@ -82,13 +80,20 @@ def _context_to_state(context: RunContext | None) -> _ContextState | None: return _ContextState(project=project, branch=branch) -class TopicStateStore: +def _new_state() -> _TopicState: + return _TopicState(version=STATE_VERSION, threads={}) + + +class TopicStateStore(JsonStateStore[_TopicState]): def __init__(self, path: Path) -> None: - self._path = path - self._lock = anyio.Lock() - self._loaded = False - self._mtime_ns: int | None = None - self._state = _TopicState(version=STATE_VERSION, threads={}) + super().__init__( + path, + version=STATE_VERSION, + state_type=_TopicState, + state_factory=_new_state, + log_prefix="telegram.topic_state", + logger=logger, + ) async def get_thread( self, chat_id: int, thread_id: int @@ -202,56 +207,6 @@ class TopicStateStore: topic_title=thread.topic_title, ) - def _stat_mtime_ns(self) -> int | None: - try: - return self._path.stat().st_mtime_ns - except FileNotFoundError: - return None - - def _reload_locked_if_needed(self) -> None: - current = self._stat_mtime_ns() - if self._loaded and current == self._mtime_ns: - return - self._load_locked() - - def _load_locked(self) -> None: - self._loaded = True - self._mtime_ns = self._stat_mtime_ns() - if self._mtime_ns is None: - self._state = _TopicState(version=STATE_VERSION, threads={}) - return - try: - payload = msgspec.json.decode(self._path.read_bytes(), type=_TopicState) - except Exception as exc: - logger.warning( - "telegram.topic_state.load_failed", - path=str(self._path), - error=str(exc), - error_type=exc.__class__.__name__, - ) - self._state = _TopicState(version=STATE_VERSION, threads={}) - return - if payload.version != STATE_VERSION: - logger.warning( - "telegram.topic_state.version_mismatch", - path=str(self._path), - version=payload.version, - expected=STATE_VERSION, - ) - self._state = _TopicState(version=STATE_VERSION, threads={}) - return - self._state = payload - - def _save_locked(self) -> None: - self._path.parent.mkdir(parents=True, exist_ok=True) - payload = msgspec.to_builtins(self._state) - tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp") - with open(tmp_path, "w", encoding="utf-8") as handle: - json.dump(payload, handle, indent=2, sort_keys=True) - handle.write("\n") - os.replace(tmp_path, self._path) - self._mtime_ns = self._stat_mtime_ns() - def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None: return self._state.threads.get(_thread_key(chat_id, thread_id)) diff --git a/src/takopi/transport_runtime.py b/src/takopi/transport_runtime.py index 2d5768d..e3bc83c 100644 --- a/src/takopi/transport_runtime.py +++ b/src/takopi/transport_runtime.py @@ -3,17 +3,30 @@ from __future__ import annotations from collections.abc import Iterable, Mapping from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, TypeAlias from .config import ConfigError, ProjectsConfig from .context import RunContext -from .directives import format_context_line, parse_context_line, parse_directives +from .directives import ( + ParsedDirectives, + format_context_line, + parse_context_line, + parse_directives, +) from .model import EngineId, ResumeToken from .plugins import normalize_allowlist from .router import AutoRouter, EngineStatus from .runner import Runner from .worktrees import WorktreeError, resolve_run_cwd +ContextSource: TypeAlias = Literal[ + "reply_ctx", + "directives", + "ambient", + "default_project", + "none", +] + @dataclass(frozen=True, slots=True) class ResolvedMessage: @@ -21,13 +34,7 @@ class ResolvedMessage: resume_token: ResumeToken | None engine_override: EngineId | None context: RunContext | None - context_source: Literal[ - "reply_ctx", - "directives", - "ambient", - "default_project", - "none", - ] = "none" + context_source: ContextSource = "none" @dataclass(frozen=True, slots=True) @@ -58,12 +65,14 @@ class TransportRuntime: plugin_configs: Mapping[str, Any] | None = None, watch_config: bool = False, ) -> None: - self._router = router - self._projects = projects - self._allowlist = normalize_allowlist(allowlist) - self._config_path = config_path - self._plugin_configs = dict(plugin_configs or {}) - self._watch_config = watch_config + self._apply( + router=router, + projects=projects, + allowlist=allowlist, + config_path=config_path, + plugin_configs=plugin_configs, + watch_config=watch_config, + ) def update( self, @@ -74,6 +83,25 @@ class TransportRuntime: config_path: Path | None = None, plugin_configs: Mapping[str, Any] | None = None, watch_config: bool = False, + ) -> None: + self._apply( + router=router, + projects=projects, + allowlist=allowlist, + config_path=config_path, + plugin_configs=plugin_configs, + watch_config=watch_config, + ) + + def _apply( + self, + *, + router: AutoRouter, + projects: ProjectsConfig, + allowlist: Iterable[str] | None, + config_path: Path | None, + plugin_configs: Mapping[str, Any] | None, + watch_config: bool, ) -> None: self._router = router self._projects = projects @@ -162,51 +190,16 @@ class TransportRuntime: chat_project = self._projects.project_for_chat(chat_id) default_project = chat_project or self._projects.default_project - context_source: Literal[ - "reply_ctx", - "directives", - "ambient", - "default_project", - "none", - ] = "none" - context: RunContext | None = None - - if reply_ctx is not None: - context = reply_ctx - context_source = "reply_ctx" - else: - project_key = directives.project - branch = directives.branch - if project_key is None: - if ambient_context is not None and ambient_context.project is not None: - project_key = ambient_context.project - else: - project_key = default_project - if branch is None: - if ( - ambient_context is not None - and ambient_context.branch is not None - and project_key == ambient_context.project - ): - branch = ambient_context.branch - if project_key is not None or branch is not None: - context = RunContext(project=project_key, branch=branch) - if directives.project is not None or directives.branch is not None: - context_source = "directives" - elif ambient_context is not None and ambient_context.project is not None: - context_source = "ambient" - elif default_project is not None: - context_source = "default_project" - - engine_override = directives.engine - if engine_override is None and context is not None: - project = ( - self._projects.projects.get(context.project) - if context.project is not None - else None - ) - if project is not None and project.default_engine is not None: - engine_override = project.default_engine + context, context_source = self._resolve_context( + directives=directives, + reply_ctx=reply_ctx, + ambient_context=ambient_context, + default_project=default_project, + ) + engine_override = self._resolve_engine_override( + directives_engine=directives.engine, + context=context, + ) return ResolvedMessage( prompt=directives.prompt, @@ -216,6 +209,65 @@ class TransportRuntime: context_source=context_source, ) + def _resolve_context( + self, + *, + directives: ParsedDirectives, + reply_ctx: RunContext | None, + ambient_context: RunContext | None, + default_project: str | None, + ) -> tuple[RunContext | None, ContextSource]: + if reply_ctx is not None: + return reply_ctx, "reply_ctx" + + project_key = directives.project + branch = directives.branch + if project_key is None: + if ambient_context is not None and ambient_context.project is not None: + project_key = ambient_context.project + else: + project_key = default_project + if branch is None: + if ( + ambient_context is not None + and ambient_context.branch is not None + and project_key == ambient_context.project + ): + branch = ambient_context.branch + context: RunContext | None = None + if project_key is not None or branch is not None: + context = RunContext(project=project_key, branch=branch) + + if directives.project is not None or directives.branch is not None: + context_source: ContextSource = "directives" + elif ambient_context is not None and ambient_context.project is not None: + context_source = "ambient" + elif default_project is not None: + context_source = "default_project" + else: + context_source = "none" + + return context, context_source + + def _resolve_engine_override( + self, + *, + directives_engine: EngineId | None, + context: RunContext | None, + ) -> EngineId | None: + if directives_engine is not None: + return directives_engine + if context is None: + return None + project = ( + self._projects.projects.get(context.project) + if context.project is not None + else None + ) + if project is not None and project.default_engine is not None: + return project.default_engine + return None + @property def default_project(self) -> str | None: return self._projects.default_project diff --git a/src/takopi/transports.py b/src/takopi/transports.py index 4e948cd..21e625f 100644 --- a/src/takopi/transports.py +++ b/src/takopi/transports.py @@ -5,14 +5,7 @@ from pathlib import Path from typing import Iterable, Protocol, runtime_checkable from .backends import EngineBackend, SetupIssue -from .config import ConfigError -from .plugins import ( - PluginLoadFailed, - PluginNotFound, - TRANSPORT_GROUP, - load_entrypoint, - list_ids, -) +from .plugins import TRANSPORT_GROUP, list_ids, load_plugin_backend from .transport_runtime import TransportRuntime @@ -67,22 +60,14 @@ def _validate_transport_backend(backend: object, ep) -> None: def get_transport( transport_id: str, *, allowlist: Iterable[str] | None = None ) -> TransportBackend: - try: - backend = load_entrypoint( - TRANSPORT_GROUP, - transport_id, - allowlist=allowlist, - validator=_validate_transport_backend, - ) - except PluginNotFound as exc: - if exc.available: - available = ", ".join(exc.available) - message = f"Unknown transport {transport_id!r}. Available: {available}." - else: - message = f"Unknown transport {transport_id!r}." - raise ConfigError(message) from exc - except PluginLoadFailed as exc: - raise ConfigError(f"Failed to load transport {transport_id!r}: {exc}") from exc + backend = load_plugin_backend( + TRANSPORT_GROUP, + transport_id, + allowlist=allowlist, + validator=_validate_transport_backend, + kind_label="transport", + ) + assert backend is not None return backend diff --git a/src/takopi/utils/streams.py b/src/takopi/utils/streams.py index dd6a3bf..2171e4f 100644 --- a/src/takopi/utils/streams.py +++ b/src/takopi/utils/streams.py @@ -35,7 +35,7 @@ async def drain_stderr( tag=tag, line=text, ) - except Exception as exc: + except Exception as exc: # noqa: BLE001 log_pipeline( logger, "subprocess.stderr.error", diff --git a/src/takopi/utils/subprocess.py b/src/takopi/utils/subprocess.py index 64ef849..5ee1a96 100644 --- a/src/takopi/utils/subprocess.py +++ b/src/takopi/utils/subprocess.py @@ -53,7 +53,7 @@ def _signal_process( return except ProcessLookupError: return - except Exception as exc: + except OSError as exc: logger.debug( log_event, error=str(exc), diff --git a/tests/test_claude_schema.py b/tests/test_claude_schema.py index 8bee334..69f2523 100644 --- a/tests/test_claude_schema.py +++ b/tests/test_claude_schema.py @@ -20,7 +20,7 @@ def _decode_fixture(name: str) -> list[str]: continue try: decoded = claude_schema.decode_stream_json_line(line) - except Exception as exc: + except Exception as exc: # noqa: BLE001 errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}") continue diff --git a/tests/test_codex_schema.py b/tests/test_codex_schema.py index 946e8d8..ad0c7fa 100644 --- a/tests/test_codex_schema.py +++ b/tests/test_codex_schema.py @@ -21,12 +21,12 @@ def _decode_fixture(name: str) -> list[str]: continue try: json.loads(line) - except Exception as exc: + except Exception as exc: # noqa: BLE001 errors.append(f"line {lineno}: invalid JSON ({exc})") continue try: codex_schema.decode_event(line) - except Exception as exc: + except Exception as exc: # noqa: BLE001 errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}") return errors diff --git a/tests/test_opencode_schema.py b/tests/test_opencode_schema.py index 63157e6..31366e2 100644 --- a/tests/test_opencode_schema.py +++ b/tests/test_opencode_schema.py @@ -20,7 +20,7 @@ def _decode_fixture(name: str) -> list[str]: continue try: opencode_schema.decode_event(line) - except Exception as exc: + except Exception as exc: # noqa: BLE001 errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}") return errors diff --git a/tests/test_pi_schema.py b/tests/test_pi_schema.py index 9bee5ac..d4dbc7e 100644 --- a/tests/test_pi_schema.py +++ b/tests/test_pi_schema.py @@ -20,7 +20,7 @@ def _decode_fixture(name: str) -> list[str]: continue try: pi_schema.decode_event(line) - except Exception as exc: + except Exception as exc: # noqa: BLE001 errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}") return errors diff --git a/tests/test_projects_config.py b/tests/test_projects_config.py index be52428..7550125 100644 --- a/tests/test_projects_config.py +++ b/tests/test_projects_config.py @@ -90,7 +90,7 @@ def test_projects_default_engine_unknown() -> None: "projects": {"z80": {"path": "/tmp/repo", "default_engine": "nope"}}, } settings = TakopiSettings.model_validate(config) - with pytest.raises(ConfigError, match="projects.z80.default_engine"): + with pytest.raises(ConfigError, match=r"projects\.z80\.default_engine"): settings.to_projects_config( config_path=Path("takopi.toml"), engine_ids=["codex"], diff --git a/tests/test_settings.py b/tests/test_settings.py index c04a50f..1590efe 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -172,7 +172,7 @@ def test_transport_config_telegram_and_extra(tmp_path: Path) -> None: }, } ) - with pytest.raises(ConfigError, match="transports.discord"): + with pytest.raises(ConfigError, match=r"transports\.discord"): settings.transport_config("discord", config_path=config_path)