refactor: cleanup, linting, and tooling updates (#108)

This commit is contained in:
banteg
2026-01-13 01:34:08 +04:00
committed by GitHub
parent 2809974698
commit abd0aa2bb4
31 changed files with 457 additions and 486 deletions
+1
View File
@@ -1,3 +1,4 @@
after you finish work, commit with a conventional message. only commit the files you edited. after you finish work, commit with a conventional message. only commit the files you edited.
always run `just check` before code commits. always run `just check` before code commits.
if you fix anything from `just check`, rerun it and confirm it passes before committing. if you fix anything from `just check`, rerun it and confirm it passes before committing.
when using gh to edit or create PR descriptions, prefer `--body-file` to preserve newlines.
+8 -1
View File
@@ -61,5 +61,12 @@ dev = [
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = ["--cov=takopi", "--cov-report=term-missing", "--cov-fail-under=70"] addopts = ["--cov=takopi", "--cov-report=term-missing", "--cov-fail-under=75"]
testpaths = ["tests"] testpaths = ["tests"]
[tool.ruff.lint]
extend-select = ["B904", "BLE001", "S110", "RUF043"]
[tool.ty.src]
include = ["src", "tests"]
exclude = ["scripts"]
+53 -37
View File
@@ -12,6 +12,7 @@ from . import __version__
from .config import ConfigError, load_or_init_config, write_config from .config import ConfigError, load_or_init_config, write_config
from .config_migrations import migrate_config from .config_migrations import migrate_config
from .commands import get_command from .commands import get_command
from .backends import EngineBackend
from .engines import get_backend, list_backend_ids from .engines import get_backend, list_backend_ids
from .ids import RESERVED_COMMAND_IDS, RESERVED_ENGINE_IDS from .ids import RESERVED_COMMAND_IDS, RESERVED_ENGINE_IDS
from .lockfile import LockError, LockHandle, acquire_lock, token_fingerprint from .lockfile import LockError, LockHandle, acquire_lock, token_fingerprint
@@ -108,6 +109,26 @@ def _default_engine_for_setup(
return value return value
def _resolve_setup_engine(
default_engine_override: str | None,
) -> tuple[
TakopiSettings | None,
Path | None,
list[str] | None,
str,
EngineBackend,
]:
settings_hint, config_hint = _load_settings_optional()
allowlist = resolve_plugins_allowlist(settings_hint)
default_engine = _default_engine_for_setup(
default_engine_override,
settings=settings_hint,
config_path=config_hint,
)
engine_backend = get_backend(default_engine, allowlist=allowlist)
return settings_hint, config_hint, allowlist, default_engine, engine_backend
def _config_path_display(path: Path) -> str: def _config_path_display(path: Path) -> str:
home = Path.home() home = Path.home()
try: try:
@@ -148,33 +169,31 @@ def _run_auto_router(
setup_logging(debug=debug) setup_logging(debug=debug)
lock_handle: LockHandle | None = None lock_handle: LockHandle | None = None
try: try:
settings_hint, config_hint = _load_settings_optional() (
allowlist = resolve_plugins_allowlist(settings_hint) settings_hint,
default_engine = _default_engine_for_setup( config_hint,
default_engine_override, allowlist,
settings=settings_hint, default_engine,
config_path=config_hint, engine_backend,
) ) = _resolve_setup_engine(default_engine_override)
engine_backend = get_backend(default_engine, allowlist=allowlist)
transport_id = _resolve_transport_id(transport_override) transport_id = _resolve_transport_id(transport_override)
transport_backend = get_transport(transport_id, allowlist=allowlist) transport_backend = get_transport(transport_id, allowlist=allowlist)
except ConfigError as e: except ConfigError as e:
typer.echo(f"error: {e}", err=True) typer.echo(f"error: {e}", err=True)
raise typer.Exit(code=1) raise typer.Exit(code=1) from e
if onboard: if onboard:
if not _should_run_interactive(): if not _should_run_interactive():
typer.echo("error: --onboard requires a TTY", err=True) typer.echo("error: --onboard requires a TTY", err=True)
raise typer.Exit(code=1) raise typer.Exit(code=1)
if not transport_backend.interactive_setup(force=True): if not transport_backend.interactive_setup(force=True):
raise typer.Exit(code=1) raise typer.Exit(code=1)
settings_hint, config_hint = _load_settings_optional() (
allowlist = resolve_plugins_allowlist(settings_hint) settings_hint,
default_engine = _default_engine_for_setup( config_hint,
default_engine_override, allowlist,
settings=settings_hint, default_engine,
config_path=config_hint, engine_backend,
) ) = _resolve_setup_engine(default_engine_override)
engine_backend = get_backend(default_engine, allowlist=allowlist)
setup = transport_backend.check_setup( setup = transport_backend.check_setup(
engine_backend, engine_backend,
transport_override=transport_override, transport_override=transport_override,
@@ -189,27 +208,25 @@ def _run_auto_router(
default=False, default=False,
) )
if run_onboard and transport_backend.interactive_setup(force=True): if run_onboard and transport_backend.interactive_setup(force=True):
settings_hint, config_hint = _load_settings_optional() (
allowlist = resolve_plugins_allowlist(settings_hint) settings_hint,
default_engine = _default_engine_for_setup( config_hint,
default_engine_override, allowlist,
settings=settings_hint, default_engine,
config_path=config_hint, engine_backend,
) ) = _resolve_setup_engine(default_engine_override)
engine_backend = get_backend(default_engine, allowlist=allowlist)
setup = transport_backend.check_setup( setup = transport_backend.check_setup(
engine_backend, engine_backend,
transport_override=transport_override, transport_override=transport_override,
) )
elif transport_backend.interactive_setup(force=False): elif transport_backend.interactive_setup(force=False):
settings_hint, config_hint = _load_settings_optional() (
allowlist = resolve_plugins_allowlist(settings_hint) settings_hint,
default_engine = _default_engine_for_setup( config_hint,
default_engine_override, allowlist,
settings=settings_hint, default_engine,
config_path=config_hint, engine_backend,
) ) = _resolve_setup_engine(default_engine_override)
engine_backend = get_backend(default_engine, allowlist=allowlist)
setup = transport_backend.check_setup( setup = transport_backend.check_setup(
engine_backend, engine_backend,
transport_override=transport_override, transport_override=transport_override,
@@ -252,10 +269,10 @@ def _run_auto_router(
) )
except ConfigError as e: except ConfigError as e:
typer.echo(f"error: {e}", err=True) typer.echo(f"error: {e}", err=True)
raise typer.Exit(code=1) raise typer.Exit(code=1) from e
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("shutdown.interrupted") logger.info("shutdown.interrupted")
raise typer.Exit(code=130) raise typer.Exit(code=130) from None
finally: finally:
if lock_handle is not None: if lock_handle is not None:
lock_handle.release() lock_handle.release()
@@ -279,8 +296,7 @@ def _default_alias_from_path(path: Path) -> str | None:
name = path.name name = path.name
if not name: if not name:
return None return None
if name.endswith(".git"): name = name.removesuffix(".git")
name = name[: -len(".git")]
return name or None return name or None
+9 -26
View File
@@ -9,13 +9,7 @@ from .config import ConfigError
from .context import RunContext from .context import RunContext
from .ids import RESERVED_COMMAND_IDS from .ids import RESERVED_COMMAND_IDS
from .model import EngineId from .model import EngineId
from .plugins import ( from .plugins import COMMAND_GROUP, list_ids, load_plugin_backend
COMMAND_GROUP,
PluginLoadFailed,
PluginNotFound,
load_entrypoint,
list_ids,
)
from .transport import MessageRef, RenderedMessage from .transport import MessageRef, RenderedMessage
from .transport_runtime import TransportRuntime from .transport_runtime import TransportRuntime
@@ -122,25 +116,14 @@ def get_command(
) -> CommandBackend | None: ) -> CommandBackend | None:
if command_id.lower() in RESERVED_COMMAND_IDS: if command_id.lower() in RESERVED_COMMAND_IDS:
raise ConfigError(f"Command id {command_id!r} is reserved.") raise ConfigError(f"Command id {command_id!r} is reserved.")
try: return load_plugin_backend(
backend = load_entrypoint( COMMAND_GROUP,
COMMAND_GROUP, command_id,
command_id, allowlist=allowlist,
allowlist=allowlist, validator=_validate_command_backend,
validator=_validate_command_backend, kind_label="command",
) required=required,
except PluginNotFound as exc: )
if not required:
return None
if exc.available:
available = ", ".join(exc.available)
message = f"Unknown command {command_id!r}. Available: {available}."
else:
message = f"Unknown command {command_id!r}."
raise ConfigError(message) from exc
except PluginLoadFailed as exc:
raise ConfigError(f"Failed to load command {command_id!r}: {exc}") from exc
return backend
def list_command_ids(*, allowlist: Iterable[str] | None = None) -> list[str]: def list_command_ids(*, allowlist: Iterable[str] | None = None) -> list[str]:
+8 -1
View File
@@ -40,6 +40,13 @@ def config_status(path: Path) -> tuple[str, tuple[int, int] | None]:
return "ok", (stat.st_mtime_ns, stat.st_size) return "ok", (stat.st_mtime_ns, stat.st_size)
def _matches_config_path(candidate: str, config_path: Path) -> bool:
try:
return Path(candidate).resolve(strict=False) == config_path
except OSError:
return False
def _reload_config( def _reload_config(
config_path: Path, config_path: Path,
default_engine_override: str | None, default_engine_override: str | None,
@@ -76,7 +83,7 @@ async def watch_config(
logger.warning("config.watch.unavailable", path=str(config_path), status=status) logger.warning("config.watch.unavailable", path=str(config_path), status=status)
async for changes in awatch(watch_root): async for changes in awatch(watch_root):
if not any(Path(path) == config_path for _, path in changes): if not any(_matches_config_path(path, config_path) for _, path in changes):
continue continue
status, current = config_status(config_path) status, current = config_status(config_path)
+9 -23
View File
@@ -4,13 +4,7 @@ from typing import Iterable
from .backends import EngineBackend from .backends import EngineBackend
from .config import ConfigError from .config import ConfigError
from .plugins import ( from .plugins import ENGINE_GROUP, list_ids, load_plugin_backend
ENGINE_GROUP,
PluginLoadFailed,
PluginNotFound,
load_entrypoint,
list_ids,
)
from .ids import RESERVED_ENGINE_IDS from .ids import RESERVED_ENGINE_IDS
@@ -28,22 +22,14 @@ def get_backend(
) -> EngineBackend: ) -> EngineBackend:
if engine_id.lower() in RESERVED_ENGINE_IDS: if engine_id.lower() in RESERVED_ENGINE_IDS:
raise ConfigError(f"Engine id {engine_id!r} is reserved.") raise ConfigError(f"Engine id {engine_id!r} is reserved.")
try: backend = load_plugin_backend(
backend = load_entrypoint( ENGINE_GROUP,
ENGINE_GROUP, engine_id,
engine_id, allowlist=allowlist,
allowlist=allowlist, validator=_validate_engine_backend,
validator=_validate_engine_backend, kind_label="engine",
) )
except PluginNotFound as exc: assert backend is not None
if exc.available:
available = ", ".join(exc.available)
message = f"Unknown engine {engine_id!r}. Available: {available}."
else:
message = f"Unknown engine {engine_id!r}."
raise ConfigError(message) from exc
except PluginLoadFailed as exc:
raise ConfigError(f"Failed to load engine {engine_id!r}: {exc}") from exc
return backend return backend
+9 -8
View File
@@ -126,8 +126,8 @@ def _file_sink(
payload = payload.decode("utf-8", errors="replace") payload = payload.decode("utf-8", errors="replace")
_log_file_handle.write(payload + "\n") _log_file_handle.write(payload + "\n")
_log_file_handle.flush() _log_file_handle.flush()
except Exception: except Exception: # noqa: BLE001
pass return event_dict
return event_dict return event_dict
@@ -202,8 +202,8 @@ class SafeWriter(io.TextIOBase):
self._closed = True self._closed = True
try: try:
self._stream.close() self._stream.close()
except Exception: except Exception: # noqa: BLE001
pass return
def setup_logging( def setup_logging(
@@ -236,13 +236,14 @@ def setup_logging(
if _log_file_handle is not None: if _log_file_handle is not None:
try: try:
_log_file_handle.close() _log_file_handle.close()
except Exception: except Exception: # noqa: BLE001
pass _log_file_handle = None
_log_file_handle = None else:
_log_file_handle = None
if log_file: if log_file:
try: try:
_log_file_handle = open(log_file, "a", encoding="utf-8") _log_file_handle = open(log_file, "a", encoding="utf-8")
except Exception: except OSError:
_log_file_handle = None _log_file_handle = None
processors = cast( processors = cast(
+34 -1
View File
@@ -100,7 +100,7 @@ def entrypoint_distribution_name(ep: EntryPoint) -> str | None:
return None return None
try: try:
return metadata["Name"] return metadata["Name"]
except Exception: except (KeyError, TypeError):
return None return None
@@ -281,3 +281,36 @@ def load_entrypoint(
_LOADED[key] = loaded _LOADED[key] = loaded
clear_load_errors(group=group, name=name) clear_load_errors(group=group, name=name)
return loaded return loaded
def load_plugin_backend(
group: str,
name: str,
*,
allowlist: Iterable[str] | None = None,
validator: Callable[[Any, EntryPoint], None] | None = None,
kind_label: str,
required: bool = True,
) -> Any | None:
try:
return load_entrypoint(
group,
name,
allowlist=allowlist,
validator=validator,
)
except PluginNotFound as exc:
if not required:
return None
if exc.available:
available = ", ".join(exc.available)
message = f"Unknown {kind_label} {name!r}. Available: {available}."
else:
message = f"Unknown {kind_label} {name!r}."
from .config import ConfigError
raise ConfigError(message) from exc
except PluginLoadFailed as exc:
from .config import ConfigError
raise ConfigError(f"Failed to load {kind_label} {name!r}: {exc}") from exc
+2 -2
View File
@@ -427,7 +427,7 @@ class JsonlSubprocessRunner(BaseRunner):
line_text = line.decode("utf-8", errors="replace") line_text = line.decode("utf-8", errors="replace")
try: try:
decoded = self.decode_jsonl(line=line) decoded = self.decode_jsonl(line=line)
except Exception as exc: except Exception as exc: # noqa: BLE001
log_pipeline( log_pipeline(
logger, logger,
"jsonl.parse.error", "jsonl.parse.error",
@@ -470,7 +470,7 @@ class JsonlSubprocessRunner(BaseRunner):
resume=resume, resume=resume,
found_session=found_session, found_session=found_session,
) )
except Exception as exc: except Exception as exc: # noqa: BLE001
log_pipeline( log_pipeline(
logger, logger,
"runner.translate.error", "runner.translate.error",
+7 -17
View File
@@ -155,15 +155,12 @@ def _tool_result_event(
normalized = _normalize_tool_result(raw_result) normalized = _normalize_tool_result(raw_result)
preview = normalized preview = normalized
detail = dict(action.detail) detail = action.detail | {
detail.update( "tool_use_id": content.tool_use_id,
{ "result_preview": preview,
"tool_use_id": content.tool_use_id, "result_len": len(normalized),
"result_preview": preview, "is_error": is_error,
"result_len": len(normalized), }
"is_error": is_error,
}
)
return factory.action_completed( return factory.action_completed(
action_id=action.id, action_id=action.id,
kind=action.kind, kind=action.kind,
@@ -367,7 +364,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
*, *,
state: Any, state: Any,
) -> list[str]: ) -> list[str]:
_ = state
return self._build_args(prompt, resume) return self._build_args(prompt, resume)
def stdin_payload( def stdin_payload(
@@ -377,11 +373,9 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
*, *,
state: Any, state: Any,
) -> bytes | None: ) -> bytes | None:
_ = prompt, resume, state
return None return None
def env(self, *, state: Any) -> dict[str, str] | None: def env(self, *, state: Any) -> dict[str, str] | None:
_ = state
if self.use_api_billing is not True: if self.use_api_billing is not True:
env = dict(os.environ) env = dict(os.environ)
env.pop("ANTHROPIC_API_KEY", None) env.pop("ANTHROPIC_API_KEY", None)
@@ -389,7 +383,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
return None return None
def new_state(self, prompt: str, resume: ResumeToken | None) -> ClaudeStreamState: def new_state(self, prompt: str, resume: ResumeToken | None) -> ClaudeStreamState:
_ = prompt, resume
return ClaudeStreamState() return ClaudeStreamState()
def start_run( def start_run(
@@ -399,7 +392,7 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
*, *,
state: ClaudeStreamState, state: ClaudeStreamState,
) -> None: ) -> None:
_ = state, prompt, resume pass
def decode_jsonl( def decode_jsonl(
self, self,
@@ -416,7 +409,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
error: Exception, error: Exception,
state: ClaudeStreamState, state: ClaudeStreamState,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
_ = raw, line, state
if isinstance(error, msgspec.DecodeError): if isinstance(error, msgspec.DecodeError):
self.get_logger().warning( self.get_logger().warning(
"jsonl.msgspec.invalid", "jsonl.msgspec.invalid",
@@ -439,7 +431,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
line: str, line: str,
state: ClaudeStreamState, state: ClaudeStreamState,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
_ = raw, line, state
return [] return []
def translate( def translate(
@@ -450,7 +441,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
resume: ResumeToken | None, resume: ResumeToken | None,
found_session: ResumeToken | None, found_session: ResumeToken | None,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
_ = resume, found_session
return translate_claude_event( return translate_claude_event(
data, data,
title=self.session_title, title=self.session_title,
+1 -4
View File
@@ -426,7 +426,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
*, *,
state: Any, state: Any,
) -> list[str]: ) -> list[str]:
_ = prompt, state
args = [ args = [
*self.extra_args, *self.extra_args,
"exec", "exec",
@@ -441,7 +440,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
return args return args
def new_state(self, prompt: str, resume: ResumeToken | None) -> CodexRunState: def new_state(self, prompt: str, resume: ResumeToken | None) -> CodexRunState:
_ = prompt, resume
return CodexRunState(factory=EventFactory(ENGINE)) return CodexRunState(factory=EventFactory(ENGINE))
def start_run( def start_run(
@@ -451,7 +449,7 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
*, *,
state: CodexRunState, state: CodexRunState,
) -> None: ) -> None:
_ = state, prompt, resume pass
def decode_jsonl(self, *, line: bytes) -> codex_schema.ThreadEvent: def decode_jsonl(self, *, line: bytes) -> codex_schema.ThreadEvent:
return codex_schema.decode_event(line) return codex_schema.decode_event(line)
@@ -464,7 +462,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
error: Exception, error: Exception,
state: CodexRunState, state: CodexRunState,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
_ = raw, line
if isinstance(error, msgspec.DecodeError): if isinstance(error, msgspec.DecodeError):
self.get_logger().warning( self.get_logger().warning(
"jsonl.msgspec.invalid", "jsonl.msgspec.invalid",
-2
View File
@@ -84,7 +84,6 @@ class MockRunner(SessionLockMixin, ResumeTokenMixin, Runner):
async def run( async def run(
self, prompt: str, resume: ResumeToken | None self, prompt: str, resume: ResumeToken | None
) -> AsyncIterator[TakopiEvent]: ) -> AsyncIterator[TakopiEvent]:
_ = prompt
token_value = None token_value = None
if resume is not None: if resume is not None:
if resume.engine != self.engine: if resume.engine != self.engine:
@@ -158,7 +157,6 @@ class ScriptRunner(MockRunner):
self, prompt: str, resume: ResumeToken | None self, prompt: str, resume: ResumeToken | None
) -> AsyncIterator[TakopiEvent]: ) -> AsyncIterator[TakopiEvent]:
self.calls.append((prompt, resume)) self.calls.append((prompt, resume))
_ = prompt
token_value = None token_value = None
if resume is not None: if resume is not None:
if resume.engine != self.engine: if resume.engine != self.engine:
+1 -7
View File
@@ -367,7 +367,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
*, *,
state: Any, state: Any,
) -> list[str]: ) -> list[str]:
_ = state
args = ["run", "--format", "json"] args = ["run", "--format", "json"]
if resume is not None: if resume is not None:
args.extend(["--session", resume.value]) args.extend(["--session", resume.value])
@@ -383,11 +382,9 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
*, *,
state: Any, state: Any,
) -> bytes | None: ) -> bytes | None:
_ = prompt, resume, state
return None return None
def new_state(self, prompt: str, resume: ResumeToken | None) -> OpenCodeStreamState: def new_state(self, prompt: str, resume: ResumeToken | None) -> OpenCodeStreamState:
_ = prompt, resume
return OpenCodeStreamState() return OpenCodeStreamState()
def start_run( def start_run(
@@ -397,7 +394,7 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
*, *,
state: OpenCodeStreamState, state: OpenCodeStreamState,
) -> None: ) -> None:
_ = state, prompt, resume pass
def invalid_json_events( def invalid_json_events(
self, self,
@@ -406,7 +403,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
line: str, line: str,
state: OpenCodeStreamState, state: OpenCodeStreamState,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
_ = line
message = "invalid JSON from opencode; ignoring line" message = "invalid JSON from opencode; ignoring line"
return [self.note_event(message, state=state, detail={"line": raw})] return [self.note_event(message, state=state, detail={"line": raw})]
@@ -418,7 +414,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
resume: ResumeToken | None, resume: ResumeToken | None,
found_session: ResumeToken | None, found_session: ResumeToken | None,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
_ = resume, found_session
return translate_opencode_event( return translate_opencode_event(
data, data,
title=self.session_title, title=self.session_title,
@@ -436,7 +431,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
error: Exception, error: Exception,
state: OpenCodeStreamState, state: OpenCodeStreamState,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
_ = raw, line, state
if isinstance(error, msgspec.DecodeError): if isinstance(error, msgspec.DecodeError):
self.get_logger().warning( self.get_logger().warning(
"jsonl.msgspec.invalid", "jsonl.msgspec.invalid",
-6
View File
@@ -290,7 +290,6 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
*, *,
state: PiStreamState, state: PiStreamState,
) -> list[str]: ) -> list[str]:
_ = resume
args: list[str] = [*self.extra_args, "--print", "--mode", "json"] args: list[str] = [*self.extra_args, "--print", "--mode", "json"]
if self.provider: if self.provider:
args.extend(["--provider", self.provider]) args.extend(["--provider", self.provider])
@@ -307,18 +306,15 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
*, *,
state: PiStreamState, state: PiStreamState,
) -> bytes | None: ) -> bytes | None:
_ = prompt, resume, state
return None return None
def env(self, *, state: PiStreamState) -> dict[str, str] | None: def env(self, *, state: PiStreamState) -> dict[str, str] | None:
_ = state
env = dict(os.environ) env = dict(os.environ)
env.setdefault("NO_COLOR", "1") env.setdefault("NO_COLOR", "1")
env.setdefault("CI", "1") env.setdefault("CI", "1")
return env return env
def new_state(self, prompt: str, resume: ResumeToken | None) -> PiStreamState: def new_state(self, prompt: str, resume: ResumeToken | None) -> PiStreamState:
_ = prompt
if resume is None: if resume is None:
session_path = self._new_session_path() session_path = self._new_session_path()
token = ResumeToken(engine=ENGINE, value=session_path) token = ResumeToken(engine=ENGINE, value=session_path)
@@ -334,7 +330,6 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
resume: ResumeToken | None, resume: ResumeToken | None,
found_session: ResumeToken | None, found_session: ResumeToken | None,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
_ = resume, found_session
meta: dict[str, Any] = {"cwd": os.getcwd()} meta: dict[str, Any] = {"cwd": os.getcwd()}
if self.model: if self.model:
meta["model"] = self.model meta["model"] = self.model
@@ -362,7 +357,6 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
error: Exception, error: Exception,
state: PiStreamState, state: PiStreamState,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
_ = raw, line, state
if isinstance(error, msgspec.DecodeError): if isinstance(error, msgspec.DecodeError):
self.get_logger().warning( self.get_logger().warning(
"jsonl.msgspec.invalid", "jsonl.msgspec.invalid",
+1 -1
View File
@@ -103,7 +103,7 @@ def build_router(
if engine_cfg: if engine_cfg:
try: try:
runner = backend.build_runner({}, config_path) runner = backend.build_runner({}, config_path)
except Exception as fallback_exc: except Exception as fallback_exc: # noqa: BLE001
warnings.append(f"{engine_id}: {issue or str(fallback_exc)}") warnings.append(f"{engine_id}: {issue or str(fallback_exc)}")
continue continue
status = "bad_config" status = "bad_config"
+4 -9
View File
@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import Annotated, Any, Iterable, Literal from typing import Annotated, Any, ClassVar, Iterable, Literal
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
@@ -62,6 +62,9 @@ class TelegramTopicsSettings(BaseModel):
class TelegramFilesSettings(BaseModel): class TelegramFilesSettings(BaseModel):
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True) model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
max_upload_bytes: ClassVar[int] = 20 * 1024 * 1024
max_download_bytes: ClassVar[int] = 50 * 1024 * 1024
enabled: bool = False enabled: bool = False
auto_put: bool = True auto_put: bool = True
auto_put_mode: Literal["upload", "prompt"] = "upload" auto_put_mode: Literal["upload", "prompt"] = "upload"
@@ -84,14 +87,6 @@ class TelegramFilesSettings(BaseModel):
raise ValueError("files.uploads_dir must be a relative path") raise ValueError("files.uploads_dir must be a relative path")
return value return value
@property
def max_upload_bytes(self) -> int:
return 20 * 1024 * 1024
@property
def max_download_bytes(self) -> int:
return 50 * 1024 * 1024
class TelegramTransportSettings(BaseModel): class TelegramTransportSettings(BaseModel):
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True) model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
+14 -61
View File
@@ -1,14 +1,12 @@
from __future__ import annotations from __future__ import annotations
import json
import os
from pathlib import Path from pathlib import Path
import anyio
import msgspec import msgspec
from ..logging import get_logger from ..logging import get_logger
from ..model import ResumeToken from ..model import ResumeToken
from .state_store import JsonStateStore
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -38,13 +36,20 @@ def _chat_key(chat_id: int, owner_id: int | None) -> str:
return f"{chat_id}:{owner}" return f"{chat_id}:{owner}"
class ChatSessionStore: def _new_state() -> _ChatSessionsState:
return _ChatSessionsState(version=STATE_VERSION, chats={})
class ChatSessionStore(JsonStateStore[_ChatSessionsState]):
def __init__(self, path: Path) -> None: def __init__(self, path: Path) -> None:
self._path = path super().__init__(
self._lock = anyio.Lock() path,
self._loaded = False version=STATE_VERSION,
self._mtime_ns: int | None = None state_type=_ChatSessionsState,
self._state = _ChatSessionsState(version=STATE_VERSION, chats={}) state_factory=_new_state,
log_prefix="telegram.chat_sessions",
logger=logger,
)
async def get_session_resume( async def get_session_resume(
self, chat_id: int, owner_id: int | None, engine: str self, chat_id: int, owner_id: int | None, engine: str
@@ -77,58 +82,6 @@ class ChatSessionStore:
chat.sessions = {} chat.sessions = {}
self._save_locked() self._save_locked()
def _stat_mtime_ns(self) -> int | None:
try:
return self._path.stat().st_mtime_ns
except FileNotFoundError:
return None
def _reload_locked_if_needed(self) -> None:
current = self._stat_mtime_ns()
if self._loaded and current == self._mtime_ns:
return
self._load_locked()
def _load_locked(self) -> None:
self._loaded = True
self._mtime_ns = self._stat_mtime_ns()
if self._mtime_ns is None:
self._state = _ChatSessionsState(version=STATE_VERSION, chats={})
return
try:
payload = msgspec.json.decode(
self._path.read_bytes(), type=_ChatSessionsState
)
except Exception as exc:
logger.warning(
"telegram.chat_sessions.load_failed",
path=str(self._path),
error=str(exc),
error_type=exc.__class__.__name__,
)
self._state = _ChatSessionsState(version=STATE_VERSION, chats={})
return
if payload.version != STATE_VERSION:
logger.warning(
"telegram.chat_sessions.version_mismatch",
path=str(self._path),
version=payload.version,
expected=STATE_VERSION,
)
self._state = _ChatSessionsState(version=STATE_VERSION, chats={})
return
self._state = payload
def _save_locked(self) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True)
payload = msgspec.to_builtins(self._state)
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2, sort_keys=True)
handle.write("\n")
os.replace(tmp_path, self._path)
self._mtime_ns = self._stat_mtime_ns()
def _get_chat_locked(self, chat_id: int, owner_id: int | None) -> _ChatState | None: def _get_chat_locked(self, chat_id: int, owner_id: int | None) -> _ChatState | None:
return self._state.chats.get(_chat_key(chat_id, owner_id)) return self._state.chats.get(_chat_key(chat_id, owner_id))
+8 -8
View File
@@ -515,7 +515,7 @@ class TelegramOutbox:
async def execute_op(self, op: OutboxOp) -> Any: async def execute_op(self, op: OutboxOp) -> Any:
try: try:
return await op.execute() return await op.execute()
except Exception as exc: except Exception as exc: # noqa: BLE001
if isinstance(exc, RetryAfter): if isinstance(exc, RetryAfter):
raise raise
if self._on_error is not None: if self._on_error is not None:
@@ -566,7 +566,7 @@ class TelegramOutbox:
op.set_result(result) op.set_result(result)
except cancel_exc: except cancel_exc:
return return
except Exception as exc: except Exception as exc: # noqa: BLE001
async with self._cond: async with self._cond:
self._closed = True self._closed = True
self.fail_pending() self.fail_pending()
@@ -759,7 +759,7 @@ class TelegramClient:
retry_after: float | None = None retry_after: float | None = None
try: try:
response_payload = resp.json() response_payload = resp.json()
except Exception: except Exception: # noqa: BLE001
response_payload = None response_payload = None
if isinstance(response_payload, dict): if isinstance(response_payload, dict):
retry_after = retry_after_from_payload(response_payload) retry_after = retry_after_from_payload(response_payload)
@@ -785,7 +785,7 @@ class TelegramClient:
try: try:
response_payload = resp.json() response_payload = resp.json()
except Exception as exc: except Exception as exc: # noqa: BLE001
body = resp.text body = resp.text
logger.error( logger.error(
"telegram.bad_response", "telegram.bad_response",
@@ -815,7 +815,7 @@ class TelegramClient:
return None return None
try: try:
return msgspec.convert(payload, type=model) return msgspec.convert(payload, type=model)
except Exception as exc: except Exception as exc: # noqa: BLE001
logger.error( logger.error(
"telegram.decode_error", "telegram.decode_error",
method=method, method=method,
@@ -862,7 +862,7 @@ class TelegramClient:
return None return None
try: try:
return msgspec.convert(raw, type=list[Update]) return msgspec.convert(raw, type=list[Update])
except Exception as exc: except Exception as exc: # noqa: BLE001
logger.error( logger.error(
"telegram.decode_error", "telegram.decode_error",
method="getUpdates", method="getUpdates",
@@ -881,7 +881,7 @@ class TelegramClient:
return None return None
try: try:
return msgspec.convert(result, type=list[Update]) return msgspec.convert(result, type=list[Update])
except Exception as exc: except Exception as exc: # noqa: BLE001
logger.error( logger.error(
"telegram.decode_error", "telegram.decode_error",
method="getUpdates", method="getUpdates",
@@ -928,7 +928,7 @@ class TelegramClient:
retry_after: float | None = None retry_after: float | None = None
try: try:
response_payload = resp.json() response_payload = resp.json()
except Exception: except Exception: # noqa: BLE001
response_payload = None response_payload = None
if isinstance(response_payload, dict): if isinstance(response_payload, dict):
retry_after = retry_after_from_payload(response_payload) retry_after = retry_after_from_payload(response_payload)
+50 -120
View File
@@ -37,7 +37,7 @@ from ..scheduler import ThreadScheduler
from ..transport import MessageRef, RenderedMessage, SendOptions from ..transport import MessageRef, RenderedMessage, SendOptions
from ..transport_runtime import ResolvedMessage, TransportRuntime from ..transport_runtime import ResolvedMessage, TransportRuntime
from ..utils.paths import reset_run_base_dir, set_run_base_dir from ..utils.paths import reset_run_base_dir, set_run_base_dir
from .bridge import send_plain from .bridge import TelegramBridgeConfig, send_plain
from .chat_sessions import ChatSessionStore from .chat_sessions import ChatSessionStore
from .context import ( from .context import (
_format_context, _format_context,
@@ -203,13 +203,25 @@ def _reserved_commands(runtime: TransportRuntime) -> set[str]:
} }
async def _set_command_menu(cfg) -> None: def _reply_sender(
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
) -> Callable[..., Awaitable[None]]:
return partial(
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
async def _set_command_menu(cfg: TelegramBridgeConfig) -> None:
commands = build_bot_commands(cfg.runtime, include_file=cfg.files.enabled) commands = build_bot_commands(cfg.runtime, include_file=cfg.files.enabled)
if not commands: if not commands:
return return
try: try:
ok = await cfg.bot.set_my_commands(commands) ok = await cfg.bot.set_my_commands(commands)
except Exception as exc: except Exception as exc: # noqa: BLE001
logger.info( logger.info(
"startup.command_menu.failed", "startup.command_menu.failed",
error=str(exc), error=str(exc),
@@ -297,7 +309,7 @@ def _should_show_resume_line(
def resolve_file_put_paths( def resolve_file_put_paths(
plan: _FilePutPlan, plan: _FilePutPlan,
*, *,
cfg, cfg: TelegramBridgeConfig,
require_dir: bool, require_dir: bool,
) -> tuple[Path | None, Path | None, str | None]: ) -> tuple[Path | None, Path | None, str | None]:
path_value = plan.path_value path_value = plan.path_value
@@ -322,14 +334,10 @@ def resolve_file_put_paths(
return None, rel_path, None return None, rel_path, None
async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool: async def _check_file_permissions(
reply = partial( cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
send_plain, ) -> bool:
cfg.exec_cfg.transport, reply = _reply_sender(cfg, msg)
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
sender_id = msg.sender_id sender_id = msg.sender_id
if sender_id is None: if sender_id is None:
await reply(text="cannot verify sender for file transfer.") await reply(text="cannot verify sender for file transfer.")
@@ -355,19 +363,13 @@ async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool:
async def _prepare_file_put_plan( async def _prepare_file_put_plan(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
args_text: str, args_text: str,
ambient_context: RunContext | None, ambient_context: RunContext | None,
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
) -> _FilePutPlan | None: ) -> _FilePutPlan | None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
if not await _check_file_permissions(cfg, msg): if not await _check_file_permissions(cfg, msg):
return None return None
try: try:
@@ -423,7 +425,7 @@ def _format_file_put_failures(failed: Sequence[_FilePutResult]) -> str | None:
async def _save_document_payload( async def _save_document_payload(
cfg, cfg: TelegramBridgeConfig,
*, *,
document: TelegramDocument, document: TelegramDocument,
run_root: Path, run_root: Path,
@@ -524,19 +526,13 @@ async def _save_document_payload(
async def _handle_file_command( async def _handle_file_command(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
args_text: str, args_text: str,
ambient_context: RunContext | None, ambient_context: RunContext | None,
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
) -> None: ) -> None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
command, rest, error = parse_file_command(args_text) command, rest, error = parse_file_command(args_text)
if error is not None: if error is not None:
await reply(text=error) await reply(text=error)
@@ -548,7 +544,7 @@ async def _handle_file_command(
async def _handle_file_put_default( async def _handle_file_put_default(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
ambient_context: RunContext | None, ambient_context: RunContext | None,
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
@@ -557,19 +553,13 @@ async def _handle_file_put_default(
async def _save_file_put( async def _save_file_put(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
args_text: str, args_text: str,
ambient_context: RunContext | None, ambient_context: RunContext | None,
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
) -> _SavedFilePut | None: ) -> _SavedFilePut | None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
document = msg.document document = msg.document
if document is None: if document is None:
await reply(text=FILE_PUT_USAGE) await reply(text=FILE_PUT_USAGE)
@@ -613,19 +603,13 @@ async def _save_file_put(
async def _handle_file_put( async def _handle_file_put(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
args_text: str, args_text: str,
ambient_context: RunContext | None, ambient_context: RunContext | None,
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
) -> None: ) -> None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
saved = await _save_file_put( saved = await _save_file_put(
cfg, cfg,
msg, msg,
@@ -645,20 +629,14 @@ async def _handle_file_put(
async def _handle_file_put_group( async def _handle_file_put_group(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
args_text: str, args_text: str,
messages: Sequence[TelegramIncomingMessage], messages: Sequence[TelegramIncomingMessage],
ambient_context: RunContext | None, ambient_context: RunContext | None,
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
) -> None: ) -> None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
saved_group = await _save_file_put_group( saved_group = await _save_file_put_group(
cfg, cfg,
msg, msg,
@@ -700,20 +678,14 @@ async def _handle_file_put_group(
async def _save_file_put_group( async def _save_file_put_group(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
args_text: str, args_text: str,
messages: Sequence[TelegramIncomingMessage], messages: Sequence[TelegramIncomingMessage],
ambient_context: RunContext | None, ambient_context: RunContext | None,
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
) -> _SavedFilePutGroup | None: ) -> _SavedFilePutGroup | None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
documents = [item.document for item in messages if item.document is not None] documents = [item.document for item in messages if item.document is not None]
if not documents: if not documents:
await reply(text=FILE_PUT_USAGE) await reply(text=FILE_PUT_USAGE)
@@ -759,7 +731,7 @@ async def _save_file_put_group(
async def _handle_media_group( async def _handle_media_group(
cfg, cfg: TelegramBridgeConfig,
messages: Sequence[TelegramIncomingMessage], messages: Sequence[TelegramIncomingMessage],
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
run_prompt: Callable[ run_prompt: Callable[
@@ -779,13 +751,7 @@ async def _handle_media_group(
(item for item in ordered if item.text.strip()), (item for item in ordered if item.text.strip()),
ordered[0], ordered[0],
) )
reply = partial( reply = _reply_sender(cfg, command_msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=command_msg.chat_id,
user_msg_id=command_msg.message_id,
thread_id=command_msg.thread_id,
)
topic_key = _topic_key(command_msg, cfg) if topic_store is not None else None topic_key = _topic_key(command_msg, cfg) if topic_store is not None else None
chat_project = _topics_chat_project(cfg, command_msg.chat_id) chat_project = _topics_chat_project(cfg, command_msg.chat_id)
bound_context = ( bound_context = (
@@ -884,19 +850,13 @@ async def _handle_media_group(
async def _handle_file_get( async def _handle_file_get(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
args_text: str, args_text: str,
ambient_context: RunContext | None, ambient_context: RunContext | None,
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
) -> None: ) -> None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
if not await _check_file_permissions(cfg, msg): if not await _check_file_permissions(cfg, msg):
return return
try: try:
@@ -989,7 +949,7 @@ async def _handle_file_get(
async def _handle_ctx_command( async def _handle_ctx_command(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
args_text: str, args_text: str,
store: TopicStateStore, store: TopicStateStore,
@@ -997,13 +957,7 @@ async def _handle_ctx_command(
resolved_scope: str | None = None, resolved_scope: str | None = None,
scope_chat_ids: frozenset[int] | None = None, scope_chat_ids: frozenset[int] | None = None,
) -> None: ) -> None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
error = _topics_command_error( error = _topics_command_error(
cfg, cfg,
msg.chat_id, msg.chat_id,
@@ -1081,20 +1035,14 @@ async def _handle_ctx_command(
async def _handle_new_command( async def _handle_new_command(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
store: TopicStateStore, store: TopicStateStore,
*, *,
resolved_scope: str | None = None, resolved_scope: str | None = None,
scope_chat_ids: frozenset[int] | None = None, scope_chat_ids: frozenset[int] | None = None,
) -> None: ) -> None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
error = _topics_command_error( error = _topics_command_error(
cfg, cfg,
msg.chat_id, msg.chat_id,
@@ -1113,18 +1061,12 @@ async def _handle_new_command(
async def _handle_chat_new_command( async def _handle_chat_new_command(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
store: ChatSessionStore, store: ChatSessionStore,
session_key: tuple[int, int | None] | None, session_key: tuple[int, int | None] | None,
) -> None: ) -> None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
if session_key is None: if session_key is None:
await reply(text="no stored sessions to clear for this chat.") await reply(text="no stored sessions to clear for this chat.")
return return
@@ -1137,7 +1079,7 @@ async def _handle_chat_new_command(
async def _handle_topic_command( async def _handle_topic_command(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
args_text: str, args_text: str,
store: TopicStateStore, store: TopicStateStore,
@@ -1145,13 +1087,7 @@ async def _handle_topic_command(
resolved_scope: str | None = None, resolved_scope: str | None = None,
scope_chat_ids: frozenset[int] | None = None, scope_chat_ids: frozenset[int] | None = None,
) -> None: ) -> None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
error = _topics_command_error( error = _topics_command_error(
cfg, cfg,
msg.chat_id, msg.chat_id,
@@ -1203,17 +1139,11 @@ async def _handle_topic_command(
async def handle_cancel( async def handle_cancel(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
running_tasks: RunningTasks, running_tasks: RunningTasks,
) -> None: ) -> None:
reply = partial( reply = _reply_sender(cfg, msg)
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
chat_id = msg.chat_id chat_id = msg.chat_id
reply_id = msg.reply_to_message_id reply_id = msg.reply_to_message_id
@@ -1239,7 +1169,7 @@ async def handle_cancel(
async def handle_callback_cancel( async def handle_callback_cancel(
cfg, cfg: TelegramBridgeConfig,
query: TelegramCallbackQuery, query: TelegramCallbackQuery,
running_tasks: RunningTasks, running_tasks: RunningTasks,
) -> None: ) -> None:
@@ -1572,7 +1502,7 @@ class _TelegramCommandExecutor(CommandExecutor):
async def _dispatch_command( async def _dispatch_command(
cfg, cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
text: str, text: str,
command_id: str, command_id: str,
+94
View File
@@ -0,0 +1,94 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any, Callable, Generic, Protocol, TypeVar
import anyio
import msgspec
T = TypeVar("T", bound="_VersionedState")
class _Logger(Protocol):
def warning(self, event: str, **fields: Any) -> None: ...
class _VersionedState(Protocol):
version: int
class JsonStateStore(Generic[T]):
def __init__(
self,
path: Path,
*,
version: int,
state_type: type[T],
state_factory: Callable[[], T],
log_prefix: str,
logger: _Logger,
) -> None:
self._path = path
self._lock = anyio.Lock()
self._loaded = False
self._mtime_ns: int | None = None
self._state_type = state_type
self._state_factory = state_factory
self._version = version
self._log_prefix = log_prefix
self._logger = logger
self._state = state_factory()
def _stat_mtime_ns(self) -> int | None:
try:
return self._path.stat().st_mtime_ns
except FileNotFoundError:
return None
def _reload_locked_if_needed(self) -> None:
current = self._stat_mtime_ns()
if self._loaded and current == self._mtime_ns:
return
self._load_locked()
def _load_locked(self) -> None:
self._loaded = True
self._mtime_ns = self._stat_mtime_ns()
if self._mtime_ns is None:
self._state = self._state_factory()
return
try:
payload = msgspec.json.decode(
self._path.read_bytes(), type=self._state_type
)
except Exception as exc: # noqa: BLE001
self._logger.warning(
f"{self._log_prefix}.load_failed",
path=str(self._path),
error=str(exc),
error_type=exc.__class__.__name__,
)
self._state = self._state_factory()
return
if payload.version != self._version:
self._logger.warning(
f"{self._log_prefix}.version_mismatch",
path=str(self._path),
version=payload.version,
expected=self._version,
)
self._state = self._state_factory()
return
self._state = payload
def _save_locked(self) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True)
payload = msgspec.to_builtins(self._state)
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2, sort_keys=True)
handle.write("\n")
os.replace(tmp_path, self._path)
self._mtime_ns = self._stat_mtime_ns()
+14 -59
View File
@@ -1,16 +1,14 @@
from __future__ import annotations from __future__ import annotations
import json
import os
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import anyio
import msgspec import msgspec
from ..context import RunContext from ..context import RunContext
from ..logging import get_logger from ..logging import get_logger
from ..model import ResumeToken from ..model import ResumeToken
from .state_store import JsonStateStore
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -82,13 +80,20 @@ def _context_to_state(context: RunContext | None) -> _ContextState | None:
return _ContextState(project=project, branch=branch) return _ContextState(project=project, branch=branch)
class TopicStateStore: def _new_state() -> _TopicState:
return _TopicState(version=STATE_VERSION, threads={})
class TopicStateStore(JsonStateStore[_TopicState]):
def __init__(self, path: Path) -> None: def __init__(self, path: Path) -> None:
self._path = path super().__init__(
self._lock = anyio.Lock() path,
self._loaded = False version=STATE_VERSION,
self._mtime_ns: int | None = None state_type=_TopicState,
self._state = _TopicState(version=STATE_VERSION, threads={}) state_factory=_new_state,
log_prefix="telegram.topic_state",
logger=logger,
)
async def get_thread( async def get_thread(
self, chat_id: int, thread_id: int self, chat_id: int, thread_id: int
@@ -202,56 +207,6 @@ class TopicStateStore:
topic_title=thread.topic_title, topic_title=thread.topic_title,
) )
def _stat_mtime_ns(self) -> int | None:
try:
return self._path.stat().st_mtime_ns
except FileNotFoundError:
return None
def _reload_locked_if_needed(self) -> None:
current = self._stat_mtime_ns()
if self._loaded and current == self._mtime_ns:
return
self._load_locked()
def _load_locked(self) -> None:
self._loaded = True
self._mtime_ns = self._stat_mtime_ns()
if self._mtime_ns is None:
self._state = _TopicState(version=STATE_VERSION, threads={})
return
try:
payload = msgspec.json.decode(self._path.read_bytes(), type=_TopicState)
except Exception as exc:
logger.warning(
"telegram.topic_state.load_failed",
path=str(self._path),
error=str(exc),
error_type=exc.__class__.__name__,
)
self._state = _TopicState(version=STATE_VERSION, threads={})
return
if payload.version != STATE_VERSION:
logger.warning(
"telegram.topic_state.version_mismatch",
path=str(self._path),
version=payload.version,
expected=STATE_VERSION,
)
self._state = _TopicState(version=STATE_VERSION, threads={})
return
self._state = payload
def _save_locked(self) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True)
payload = msgspec.to_builtins(self._state)
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2, sort_keys=True)
handle.write("\n")
os.replace(tmp_path, self._path)
self._mtime_ns = self._stat_mtime_ns()
def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None: def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None:
return self._state.threads.get(_thread_key(chat_id, thread_id)) return self._state.threads.get(_thread_key(chat_id, thread_id))
+112 -60
View File
@@ -3,17 +3,30 @@ from __future__ import annotations
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Literal from typing import Any, Literal, TypeAlias
from .config import ConfigError, ProjectsConfig from .config import ConfigError, ProjectsConfig
from .context import RunContext from .context import RunContext
from .directives import format_context_line, parse_context_line, parse_directives from .directives import (
ParsedDirectives,
format_context_line,
parse_context_line,
parse_directives,
)
from .model import EngineId, ResumeToken from .model import EngineId, ResumeToken
from .plugins import normalize_allowlist from .plugins import normalize_allowlist
from .router import AutoRouter, EngineStatus from .router import AutoRouter, EngineStatus
from .runner import Runner from .runner import Runner
from .worktrees import WorktreeError, resolve_run_cwd from .worktrees import WorktreeError, resolve_run_cwd
ContextSource: TypeAlias = Literal[
"reply_ctx",
"directives",
"ambient",
"default_project",
"none",
]
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
class ResolvedMessage: class ResolvedMessage:
@@ -21,13 +34,7 @@ class ResolvedMessage:
resume_token: ResumeToken | None resume_token: ResumeToken | None
engine_override: EngineId | None engine_override: EngineId | None
context: RunContext | None context: RunContext | None
context_source: Literal[ context_source: ContextSource = "none"
"reply_ctx",
"directives",
"ambient",
"default_project",
"none",
] = "none"
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
@@ -58,12 +65,14 @@ class TransportRuntime:
plugin_configs: Mapping[str, Any] | None = None, plugin_configs: Mapping[str, Any] | None = None,
watch_config: bool = False, watch_config: bool = False,
) -> None: ) -> None:
self._router = router self._apply(
self._projects = projects router=router,
self._allowlist = normalize_allowlist(allowlist) projects=projects,
self._config_path = config_path allowlist=allowlist,
self._plugin_configs = dict(plugin_configs or {}) config_path=config_path,
self._watch_config = watch_config plugin_configs=plugin_configs,
watch_config=watch_config,
)
def update( def update(
self, self,
@@ -74,6 +83,25 @@ class TransportRuntime:
config_path: Path | None = None, config_path: Path | None = None,
plugin_configs: Mapping[str, Any] | None = None, plugin_configs: Mapping[str, Any] | None = None,
watch_config: bool = False, watch_config: bool = False,
) -> None:
self._apply(
router=router,
projects=projects,
allowlist=allowlist,
config_path=config_path,
plugin_configs=plugin_configs,
watch_config=watch_config,
)
def _apply(
self,
*,
router: AutoRouter,
projects: ProjectsConfig,
allowlist: Iterable[str] | None,
config_path: Path | None,
plugin_configs: Mapping[str, Any] | None,
watch_config: bool,
) -> None: ) -> None:
self._router = router self._router = router
self._projects = projects self._projects = projects
@@ -162,51 +190,16 @@ class TransportRuntime:
chat_project = self._projects.project_for_chat(chat_id) chat_project = self._projects.project_for_chat(chat_id)
default_project = chat_project or self._projects.default_project default_project = chat_project or self._projects.default_project
context_source: Literal[ context, context_source = self._resolve_context(
"reply_ctx", directives=directives,
"directives", reply_ctx=reply_ctx,
"ambient", ambient_context=ambient_context,
"default_project", default_project=default_project,
"none", )
] = "none" engine_override = self._resolve_engine_override(
context: RunContext | None = None directives_engine=directives.engine,
context=context,
if reply_ctx is not None: )
context = reply_ctx
context_source = "reply_ctx"
else:
project_key = directives.project
branch = directives.branch
if project_key is None:
if ambient_context is not None and ambient_context.project is not None:
project_key = ambient_context.project
else:
project_key = default_project
if branch is None:
if (
ambient_context is not None
and ambient_context.branch is not None
and project_key == ambient_context.project
):
branch = ambient_context.branch
if project_key is not None or branch is not None:
context = RunContext(project=project_key, branch=branch)
if directives.project is not None or directives.branch is not None:
context_source = "directives"
elif ambient_context is not None and ambient_context.project is not None:
context_source = "ambient"
elif default_project is not None:
context_source = "default_project"
engine_override = directives.engine
if engine_override is None and context is not None:
project = (
self._projects.projects.get(context.project)
if context.project is not None
else None
)
if project is not None and project.default_engine is not None:
engine_override = project.default_engine
return ResolvedMessage( return ResolvedMessage(
prompt=directives.prompt, prompt=directives.prompt,
@@ -216,6 +209,65 @@ class TransportRuntime:
context_source=context_source, context_source=context_source,
) )
def _resolve_context(
self,
*,
directives: ParsedDirectives,
reply_ctx: RunContext | None,
ambient_context: RunContext | None,
default_project: str | None,
) -> tuple[RunContext | None, ContextSource]:
if reply_ctx is not None:
return reply_ctx, "reply_ctx"
project_key = directives.project
branch = directives.branch
if project_key is None:
if ambient_context is not None and ambient_context.project is not None:
project_key = ambient_context.project
else:
project_key = default_project
if branch is None:
if (
ambient_context is not None
and ambient_context.branch is not None
and project_key == ambient_context.project
):
branch = ambient_context.branch
context: RunContext | None = None
if project_key is not None or branch is not None:
context = RunContext(project=project_key, branch=branch)
if directives.project is not None or directives.branch is not None:
context_source: ContextSource = "directives"
elif ambient_context is not None and ambient_context.project is not None:
context_source = "ambient"
elif default_project is not None:
context_source = "default_project"
else:
context_source = "none"
return context, context_source
def _resolve_engine_override(
self,
*,
directives_engine: EngineId | None,
context: RunContext | None,
) -> EngineId | None:
if directives_engine is not None:
return directives_engine
if context is None:
return None
project = (
self._projects.projects.get(context.project)
if context.project is not None
else None
)
if project is not None and project.default_engine is not None:
return project.default_engine
return None
@property @property
def default_project(self) -> str | None: def default_project(self) -> str | None:
return self._projects.default_project return self._projects.default_project
+9 -24
View File
@@ -5,14 +5,7 @@ from pathlib import Path
from typing import Iterable, Protocol, runtime_checkable from typing import Iterable, Protocol, runtime_checkable
from .backends import EngineBackend, SetupIssue from .backends import EngineBackend, SetupIssue
from .config import ConfigError from .plugins import TRANSPORT_GROUP, list_ids, load_plugin_backend
from .plugins import (
PluginLoadFailed,
PluginNotFound,
TRANSPORT_GROUP,
load_entrypoint,
list_ids,
)
from .transport_runtime import TransportRuntime from .transport_runtime import TransportRuntime
@@ -67,22 +60,14 @@ def _validate_transport_backend(backend: object, ep) -> None:
def get_transport( def get_transport(
transport_id: str, *, allowlist: Iterable[str] | None = None transport_id: str, *, allowlist: Iterable[str] | None = None
) -> TransportBackend: ) -> TransportBackend:
try: backend = load_plugin_backend(
backend = load_entrypoint( TRANSPORT_GROUP,
TRANSPORT_GROUP, transport_id,
transport_id, allowlist=allowlist,
allowlist=allowlist, validator=_validate_transport_backend,
validator=_validate_transport_backend, kind_label="transport",
) )
except PluginNotFound as exc: assert backend is not None
if exc.available:
available = ", ".join(exc.available)
message = f"Unknown transport {transport_id!r}. Available: {available}."
else:
message = f"Unknown transport {transport_id!r}."
raise ConfigError(message) from exc
except PluginLoadFailed as exc:
raise ConfigError(f"Failed to load transport {transport_id!r}: {exc}") from exc
return backend return backend
+1 -1
View File
@@ -35,7 +35,7 @@ async def drain_stderr(
tag=tag, tag=tag,
line=text, line=text,
) )
except Exception as exc: except Exception as exc: # noqa: BLE001
log_pipeline( log_pipeline(
logger, logger,
"subprocess.stderr.error", "subprocess.stderr.error",
+1 -1
View File
@@ -53,7 +53,7 @@ def _signal_process(
return return
except ProcessLookupError: except ProcessLookupError:
return return
except Exception as exc: except OSError as exc:
logger.debug( logger.debug(
log_event, log_event,
error=str(exc), error=str(exc),
+1 -1
View File
@@ -20,7 +20,7 @@ def _decode_fixture(name: str) -> list[str]:
continue continue
try: try:
decoded = claude_schema.decode_stream_json_line(line) decoded = claude_schema.decode_stream_json_line(line)
except Exception as exc: except Exception as exc: # noqa: BLE001
errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}") errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}")
continue continue
+2 -2
View File
@@ -21,12 +21,12 @@ def _decode_fixture(name: str) -> list[str]:
continue continue
try: try:
json.loads(line) json.loads(line)
except Exception as exc: except Exception as exc: # noqa: BLE001
errors.append(f"line {lineno}: invalid JSON ({exc})") errors.append(f"line {lineno}: invalid JSON ({exc})")
continue continue
try: try:
codex_schema.decode_event(line) codex_schema.decode_event(line)
except Exception as exc: except Exception as exc: # noqa: BLE001
errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}") errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}")
return errors return errors
+1 -1
View File
@@ -20,7 +20,7 @@ def _decode_fixture(name: str) -> list[str]:
continue continue
try: try:
opencode_schema.decode_event(line) opencode_schema.decode_event(line)
except Exception as exc: except Exception as exc: # noqa: BLE001
errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}") errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}")
return errors return errors
+1 -1
View File
@@ -20,7 +20,7 @@ def _decode_fixture(name: str) -> list[str]:
continue continue
try: try:
pi_schema.decode_event(line) pi_schema.decode_event(line)
except Exception as exc: except Exception as exc: # noqa: BLE001
errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}") errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}")
return errors return errors
+1 -1
View File
@@ -90,7 +90,7 @@ def test_projects_default_engine_unknown() -> None:
"projects": {"z80": {"path": "/tmp/repo", "default_engine": "nope"}}, "projects": {"z80": {"path": "/tmp/repo", "default_engine": "nope"}},
} }
settings = TakopiSettings.model_validate(config) settings = TakopiSettings.model_validate(config)
with pytest.raises(ConfigError, match="projects.z80.default_engine"): with pytest.raises(ConfigError, match=r"projects\.z80\.default_engine"):
settings.to_projects_config( settings.to_projects_config(
config_path=Path("takopi.toml"), config_path=Path("takopi.toml"),
engine_ids=["codex"], engine_ids=["codex"],
+1 -1
View File
@@ -172,7 +172,7 @@ def test_transport_config_telegram_and_extra(tmp_path: Path) -> None:
}, },
} }
) )
with pytest.raises(ConfigError, match="transports.discord"): with pytest.raises(ConfigError, match=r"transports\.discord"):
settings.transport_config("discord", config_path=config_path) settings.transport_config("discord", config_path=config_path)