refactor: cleanup, linting, and tooling updates (#108)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
after you finish work, commit with a conventional message. only commit the files you edited.
|
||||
always run `just check` before code commits.
|
||||
if you fix anything from `just check`, rerun it and confirm it passes before committing.
|
||||
when using gh to edit or create PR descriptions, prefer `--body-file` to preserve newlines.
|
||||
|
||||
+8
-1
@@ -61,5 +61,12 @@ dev = [
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = ["--cov=takopi", "--cov-report=term-missing", "--cov-fail-under=70"]
|
||||
addopts = ["--cov=takopi", "--cov-report=term-missing", "--cov-fail-under=75"]
|
||||
testpaths = ["tests"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
extend-select = ["B904", "BLE001", "S110", "RUF043"]
|
||||
|
||||
[tool.ty.src]
|
||||
include = ["src", "tests"]
|
||||
exclude = ["scripts"]
|
||||
|
||||
+53
-37
@@ -12,6 +12,7 @@ from . import __version__
|
||||
from .config import ConfigError, load_or_init_config, write_config
|
||||
from .config_migrations import migrate_config
|
||||
from .commands import get_command
|
||||
from .backends import EngineBackend
|
||||
from .engines import get_backend, list_backend_ids
|
||||
from .ids import RESERVED_COMMAND_IDS, RESERVED_ENGINE_IDS
|
||||
from .lockfile import LockError, LockHandle, acquire_lock, token_fingerprint
|
||||
@@ -108,6 +109,26 @@ def _default_engine_for_setup(
|
||||
return value
|
||||
|
||||
|
||||
def _resolve_setup_engine(
|
||||
default_engine_override: str | None,
|
||||
) -> tuple[
|
||||
TakopiSettings | None,
|
||||
Path | None,
|
||||
list[str] | None,
|
||||
str,
|
||||
EngineBackend,
|
||||
]:
|
||||
settings_hint, config_hint = _load_settings_optional()
|
||||
allowlist = resolve_plugins_allowlist(settings_hint)
|
||||
default_engine = _default_engine_for_setup(
|
||||
default_engine_override,
|
||||
settings=settings_hint,
|
||||
config_path=config_hint,
|
||||
)
|
||||
engine_backend = get_backend(default_engine, allowlist=allowlist)
|
||||
return settings_hint, config_hint, allowlist, default_engine, engine_backend
|
||||
|
||||
|
||||
def _config_path_display(path: Path) -> str:
|
||||
home = Path.home()
|
||||
try:
|
||||
@@ -148,33 +169,31 @@ def _run_auto_router(
|
||||
setup_logging(debug=debug)
|
||||
lock_handle: LockHandle | None = None
|
||||
try:
|
||||
settings_hint, config_hint = _load_settings_optional()
|
||||
allowlist = resolve_plugins_allowlist(settings_hint)
|
||||
default_engine = _default_engine_for_setup(
|
||||
default_engine_override,
|
||||
settings=settings_hint,
|
||||
config_path=config_hint,
|
||||
)
|
||||
engine_backend = get_backend(default_engine, allowlist=allowlist)
|
||||
(
|
||||
settings_hint,
|
||||
config_hint,
|
||||
allowlist,
|
||||
default_engine,
|
||||
engine_backend,
|
||||
) = _resolve_setup_engine(default_engine_override)
|
||||
transport_id = _resolve_transport_id(transport_override)
|
||||
transport_backend = get_transport(transport_id, allowlist=allowlist)
|
||||
except ConfigError as e:
|
||||
typer.echo(f"error: {e}", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
raise typer.Exit(code=1) from e
|
||||
if onboard:
|
||||
if not _should_run_interactive():
|
||||
typer.echo("error: --onboard requires a TTY", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
if not transport_backend.interactive_setup(force=True):
|
||||
raise typer.Exit(code=1)
|
||||
settings_hint, config_hint = _load_settings_optional()
|
||||
allowlist = resolve_plugins_allowlist(settings_hint)
|
||||
default_engine = _default_engine_for_setup(
|
||||
default_engine_override,
|
||||
settings=settings_hint,
|
||||
config_path=config_hint,
|
||||
)
|
||||
engine_backend = get_backend(default_engine, allowlist=allowlist)
|
||||
(
|
||||
settings_hint,
|
||||
config_hint,
|
||||
allowlist,
|
||||
default_engine,
|
||||
engine_backend,
|
||||
) = _resolve_setup_engine(default_engine_override)
|
||||
setup = transport_backend.check_setup(
|
||||
engine_backend,
|
||||
transport_override=transport_override,
|
||||
@@ -189,27 +208,25 @@ def _run_auto_router(
|
||||
default=False,
|
||||
)
|
||||
if run_onboard and transport_backend.interactive_setup(force=True):
|
||||
settings_hint, config_hint = _load_settings_optional()
|
||||
allowlist = resolve_plugins_allowlist(settings_hint)
|
||||
default_engine = _default_engine_for_setup(
|
||||
default_engine_override,
|
||||
settings=settings_hint,
|
||||
config_path=config_hint,
|
||||
)
|
||||
engine_backend = get_backend(default_engine, allowlist=allowlist)
|
||||
(
|
||||
settings_hint,
|
||||
config_hint,
|
||||
allowlist,
|
||||
default_engine,
|
||||
engine_backend,
|
||||
) = _resolve_setup_engine(default_engine_override)
|
||||
setup = transport_backend.check_setup(
|
||||
engine_backend,
|
||||
transport_override=transport_override,
|
||||
)
|
||||
elif transport_backend.interactive_setup(force=False):
|
||||
settings_hint, config_hint = _load_settings_optional()
|
||||
allowlist = resolve_plugins_allowlist(settings_hint)
|
||||
default_engine = _default_engine_for_setup(
|
||||
default_engine_override,
|
||||
settings=settings_hint,
|
||||
config_path=config_hint,
|
||||
)
|
||||
engine_backend = get_backend(default_engine, allowlist=allowlist)
|
||||
(
|
||||
settings_hint,
|
||||
config_hint,
|
||||
allowlist,
|
||||
default_engine,
|
||||
engine_backend,
|
||||
) = _resolve_setup_engine(default_engine_override)
|
||||
setup = transport_backend.check_setup(
|
||||
engine_backend,
|
||||
transport_override=transport_override,
|
||||
@@ -252,10 +269,10 @@ def _run_auto_router(
|
||||
)
|
||||
except ConfigError as e:
|
||||
typer.echo(f"error: {e}", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
raise typer.Exit(code=1) from e
|
||||
except KeyboardInterrupt:
|
||||
logger.info("shutdown.interrupted")
|
||||
raise typer.Exit(code=130)
|
||||
raise typer.Exit(code=130) from None
|
||||
finally:
|
||||
if lock_handle is not None:
|
||||
lock_handle.release()
|
||||
@@ -279,8 +296,7 @@ def _default_alias_from_path(path: Path) -> str | None:
|
||||
name = path.name
|
||||
if not name:
|
||||
return None
|
||||
if name.endswith(".git"):
|
||||
name = name[: -len(".git")]
|
||||
name = name.removesuffix(".git")
|
||||
return name or None
|
||||
|
||||
|
||||
|
||||
+9
-26
@@ -9,13 +9,7 @@ from .config import ConfigError
|
||||
from .context import RunContext
|
||||
from .ids import RESERVED_COMMAND_IDS
|
||||
from .model import EngineId
|
||||
from .plugins import (
|
||||
COMMAND_GROUP,
|
||||
PluginLoadFailed,
|
||||
PluginNotFound,
|
||||
load_entrypoint,
|
||||
list_ids,
|
||||
)
|
||||
from .plugins import COMMAND_GROUP, list_ids, load_plugin_backend
|
||||
from .transport import MessageRef, RenderedMessage
|
||||
from .transport_runtime import TransportRuntime
|
||||
|
||||
@@ -122,25 +116,14 @@ def get_command(
|
||||
) -> CommandBackend | None:
|
||||
if command_id.lower() in RESERVED_COMMAND_IDS:
|
||||
raise ConfigError(f"Command id {command_id!r} is reserved.")
|
||||
try:
|
||||
backend = load_entrypoint(
|
||||
COMMAND_GROUP,
|
||||
command_id,
|
||||
allowlist=allowlist,
|
||||
validator=_validate_command_backend,
|
||||
)
|
||||
except PluginNotFound as exc:
|
||||
if not required:
|
||||
return None
|
||||
if exc.available:
|
||||
available = ", ".join(exc.available)
|
||||
message = f"Unknown command {command_id!r}. Available: {available}."
|
||||
else:
|
||||
message = f"Unknown command {command_id!r}."
|
||||
raise ConfigError(message) from exc
|
||||
except PluginLoadFailed as exc:
|
||||
raise ConfigError(f"Failed to load command {command_id!r}: {exc}") from exc
|
||||
return backend
|
||||
return load_plugin_backend(
|
||||
COMMAND_GROUP,
|
||||
command_id,
|
||||
allowlist=allowlist,
|
||||
validator=_validate_command_backend,
|
||||
kind_label="command",
|
||||
required=required,
|
||||
)
|
||||
|
||||
|
||||
def list_command_ids(*, allowlist: Iterable[str] | None = None) -> list[str]:
|
||||
|
||||
@@ -40,6 +40,13 @@ def config_status(path: Path) -> tuple[str, tuple[int, int] | None]:
|
||||
return "ok", (stat.st_mtime_ns, stat.st_size)
|
||||
|
||||
|
||||
def _matches_config_path(candidate: str, config_path: Path) -> bool:
|
||||
try:
|
||||
return Path(candidate).resolve(strict=False) == config_path
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def _reload_config(
|
||||
config_path: Path,
|
||||
default_engine_override: str | None,
|
||||
@@ -76,7 +83,7 @@ async def watch_config(
|
||||
logger.warning("config.watch.unavailable", path=str(config_path), status=status)
|
||||
|
||||
async for changes in awatch(watch_root):
|
||||
if not any(Path(path) == config_path for _, path in changes):
|
||||
if not any(_matches_config_path(path, config_path) for _, path in changes):
|
||||
continue
|
||||
|
||||
status, current = config_status(config_path)
|
||||
|
||||
+9
-23
@@ -4,13 +4,7 @@ from typing import Iterable
|
||||
|
||||
from .backends import EngineBackend
|
||||
from .config import ConfigError
|
||||
from .plugins import (
|
||||
ENGINE_GROUP,
|
||||
PluginLoadFailed,
|
||||
PluginNotFound,
|
||||
load_entrypoint,
|
||||
list_ids,
|
||||
)
|
||||
from .plugins import ENGINE_GROUP, list_ids, load_plugin_backend
|
||||
from .ids import RESERVED_ENGINE_IDS
|
||||
|
||||
|
||||
@@ -28,22 +22,14 @@ def get_backend(
|
||||
) -> EngineBackend:
|
||||
if engine_id.lower() in RESERVED_ENGINE_IDS:
|
||||
raise ConfigError(f"Engine id {engine_id!r} is reserved.")
|
||||
try:
|
||||
backend = load_entrypoint(
|
||||
ENGINE_GROUP,
|
||||
engine_id,
|
||||
allowlist=allowlist,
|
||||
validator=_validate_engine_backend,
|
||||
)
|
||||
except PluginNotFound as exc:
|
||||
if exc.available:
|
||||
available = ", ".join(exc.available)
|
||||
message = f"Unknown engine {engine_id!r}. Available: {available}."
|
||||
else:
|
||||
message = f"Unknown engine {engine_id!r}."
|
||||
raise ConfigError(message) from exc
|
||||
except PluginLoadFailed as exc:
|
||||
raise ConfigError(f"Failed to load engine {engine_id!r}: {exc}") from exc
|
||||
backend = load_plugin_backend(
|
||||
ENGINE_GROUP,
|
||||
engine_id,
|
||||
allowlist=allowlist,
|
||||
validator=_validate_engine_backend,
|
||||
kind_label="engine",
|
||||
)
|
||||
assert backend is not None
|
||||
return backend
|
||||
|
||||
|
||||
|
||||
@@ -126,8 +126,8 @@ def _file_sink(
|
||||
payload = payload.decode("utf-8", errors="replace")
|
||||
_log_file_handle.write(payload + "\n")
|
||||
_log_file_handle.flush()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception: # noqa: BLE001
|
||||
return event_dict
|
||||
return event_dict
|
||||
|
||||
|
||||
@@ -202,8 +202,8 @@ class SafeWriter(io.TextIOBase):
|
||||
self._closed = True
|
||||
try:
|
||||
self._stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception: # noqa: BLE001
|
||||
return
|
||||
|
||||
|
||||
def setup_logging(
|
||||
@@ -236,13 +236,14 @@ def setup_logging(
|
||||
if _log_file_handle is not None:
|
||||
try:
|
||||
_log_file_handle.close()
|
||||
except Exception:
|
||||
pass
|
||||
_log_file_handle = None
|
||||
except Exception: # noqa: BLE001
|
||||
_log_file_handle = None
|
||||
else:
|
||||
_log_file_handle = None
|
||||
if log_file:
|
||||
try:
|
||||
_log_file_handle = open(log_file, "a", encoding="utf-8")
|
||||
except Exception:
|
||||
except OSError:
|
||||
_log_file_handle = None
|
||||
|
||||
processors = cast(
|
||||
|
||||
+34
-1
@@ -100,7 +100,7 @@ def entrypoint_distribution_name(ep: EntryPoint) -> str | None:
|
||||
return None
|
||||
try:
|
||||
return metadata["Name"]
|
||||
except Exception:
|
||||
except (KeyError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
@@ -281,3 +281,36 @@ def load_entrypoint(
|
||||
_LOADED[key] = loaded
|
||||
clear_load_errors(group=group, name=name)
|
||||
return loaded
|
||||
|
||||
|
||||
def load_plugin_backend(
|
||||
group: str,
|
||||
name: str,
|
||||
*,
|
||||
allowlist: Iterable[str] | None = None,
|
||||
validator: Callable[[Any, EntryPoint], None] | None = None,
|
||||
kind_label: str,
|
||||
required: bool = True,
|
||||
) -> Any | None:
|
||||
try:
|
||||
return load_entrypoint(
|
||||
group,
|
||||
name,
|
||||
allowlist=allowlist,
|
||||
validator=validator,
|
||||
)
|
||||
except PluginNotFound as exc:
|
||||
if not required:
|
||||
return None
|
||||
if exc.available:
|
||||
available = ", ".join(exc.available)
|
||||
message = f"Unknown {kind_label} {name!r}. Available: {available}."
|
||||
else:
|
||||
message = f"Unknown {kind_label} {name!r}."
|
||||
from .config import ConfigError
|
||||
|
||||
raise ConfigError(message) from exc
|
||||
except PluginLoadFailed as exc:
|
||||
from .config import ConfigError
|
||||
|
||||
raise ConfigError(f"Failed to load {kind_label} {name!r}: {exc}") from exc
|
||||
|
||||
@@ -427,7 +427,7 @@ class JsonlSubprocessRunner(BaseRunner):
|
||||
line_text = line.decode("utf-8", errors="replace")
|
||||
try:
|
||||
decoded = self.decode_jsonl(line=line)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log_pipeline(
|
||||
logger,
|
||||
"jsonl.parse.error",
|
||||
@@ -470,7 +470,7 @@ class JsonlSubprocessRunner(BaseRunner):
|
||||
resume=resume,
|
||||
found_session=found_session,
|
||||
)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log_pipeline(
|
||||
logger,
|
||||
"runner.translate.error",
|
||||
|
||||
@@ -155,15 +155,12 @@ def _tool_result_event(
|
||||
normalized = _normalize_tool_result(raw_result)
|
||||
preview = normalized
|
||||
|
||||
detail = dict(action.detail)
|
||||
detail.update(
|
||||
{
|
||||
"tool_use_id": content.tool_use_id,
|
||||
"result_preview": preview,
|
||||
"result_len": len(normalized),
|
||||
"is_error": is_error,
|
||||
}
|
||||
)
|
||||
detail = action.detail | {
|
||||
"tool_use_id": content.tool_use_id,
|
||||
"result_preview": preview,
|
||||
"result_len": len(normalized),
|
||||
"is_error": is_error,
|
||||
}
|
||||
return factory.action_completed(
|
||||
action_id=action.id,
|
||||
kind=action.kind,
|
||||
@@ -367,7 +364,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: Any,
|
||||
) -> list[str]:
|
||||
_ = state
|
||||
return self._build_args(prompt, resume)
|
||||
|
||||
def stdin_payload(
|
||||
@@ -377,11 +373,9 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: Any,
|
||||
) -> bytes | None:
|
||||
_ = prompt, resume, state
|
||||
return None
|
||||
|
||||
def env(self, *, state: Any) -> dict[str, str] | None:
|
||||
_ = state
|
||||
if self.use_api_billing is not True:
|
||||
env = dict(os.environ)
|
||||
env.pop("ANTHROPIC_API_KEY", None)
|
||||
@@ -389,7 +383,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
return None
|
||||
|
||||
def new_state(self, prompt: str, resume: ResumeToken | None) -> ClaudeStreamState:
|
||||
_ = prompt, resume
|
||||
return ClaudeStreamState()
|
||||
|
||||
def start_run(
|
||||
@@ -399,7 +392,7 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: ClaudeStreamState,
|
||||
) -> None:
|
||||
_ = state, prompt, resume
|
||||
pass
|
||||
|
||||
def decode_jsonl(
|
||||
self,
|
||||
@@ -416,7 +409,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
error: Exception,
|
||||
state: ClaudeStreamState,
|
||||
) -> list[TakopiEvent]:
|
||||
_ = raw, line, state
|
||||
if isinstance(error, msgspec.DecodeError):
|
||||
self.get_logger().warning(
|
||||
"jsonl.msgspec.invalid",
|
||||
@@ -439,7 +431,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
line: str,
|
||||
state: ClaudeStreamState,
|
||||
) -> list[TakopiEvent]:
|
||||
_ = raw, line, state
|
||||
return []
|
||||
|
||||
def translate(
|
||||
@@ -450,7 +441,6 @@ class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
resume: ResumeToken | None,
|
||||
found_session: ResumeToken | None,
|
||||
) -> list[TakopiEvent]:
|
||||
_ = resume, found_session
|
||||
return translate_claude_event(
|
||||
data,
|
||||
title=self.session_title,
|
||||
|
||||
@@ -426,7 +426,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: Any,
|
||||
) -> list[str]:
|
||||
_ = prompt, state
|
||||
args = [
|
||||
*self.extra_args,
|
||||
"exec",
|
||||
@@ -441,7 +440,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
return args
|
||||
|
||||
def new_state(self, prompt: str, resume: ResumeToken | None) -> CodexRunState:
|
||||
_ = prompt, resume
|
||||
return CodexRunState(factory=EventFactory(ENGINE))
|
||||
|
||||
def start_run(
|
||||
@@ -451,7 +449,7 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: CodexRunState,
|
||||
) -> None:
|
||||
_ = state, prompt, resume
|
||||
pass
|
||||
|
||||
def decode_jsonl(self, *, line: bytes) -> codex_schema.ThreadEvent:
|
||||
return codex_schema.decode_event(line)
|
||||
@@ -464,7 +462,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
error: Exception,
|
||||
state: CodexRunState,
|
||||
) -> list[TakopiEvent]:
|
||||
_ = raw, line
|
||||
if isinstance(error, msgspec.DecodeError):
|
||||
self.get_logger().warning(
|
||||
"jsonl.msgspec.invalid",
|
||||
|
||||
@@ -84,7 +84,6 @@ class MockRunner(SessionLockMixin, ResumeTokenMixin, Runner):
|
||||
async def run(
|
||||
self, prompt: str, resume: ResumeToken | None
|
||||
) -> AsyncIterator[TakopiEvent]:
|
||||
_ = prompt
|
||||
token_value = None
|
||||
if resume is not None:
|
||||
if resume.engine != self.engine:
|
||||
@@ -158,7 +157,6 @@ class ScriptRunner(MockRunner):
|
||||
self, prompt: str, resume: ResumeToken | None
|
||||
) -> AsyncIterator[TakopiEvent]:
|
||||
self.calls.append((prompt, resume))
|
||||
_ = prompt
|
||||
token_value = None
|
||||
if resume is not None:
|
||||
if resume.engine != self.engine:
|
||||
|
||||
@@ -367,7 +367,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: Any,
|
||||
) -> list[str]:
|
||||
_ = state
|
||||
args = ["run", "--format", "json"]
|
||||
if resume is not None:
|
||||
args.extend(["--session", resume.value])
|
||||
@@ -383,11 +382,9 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: Any,
|
||||
) -> bytes | None:
|
||||
_ = prompt, resume, state
|
||||
return None
|
||||
|
||||
def new_state(self, prompt: str, resume: ResumeToken | None) -> OpenCodeStreamState:
|
||||
_ = prompt, resume
|
||||
return OpenCodeStreamState()
|
||||
|
||||
def start_run(
|
||||
@@ -397,7 +394,7 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: OpenCodeStreamState,
|
||||
) -> None:
|
||||
_ = state, prompt, resume
|
||||
pass
|
||||
|
||||
def invalid_json_events(
|
||||
self,
|
||||
@@ -406,7 +403,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
line: str,
|
||||
state: OpenCodeStreamState,
|
||||
) -> list[TakopiEvent]:
|
||||
_ = line
|
||||
message = "invalid JSON from opencode; ignoring line"
|
||||
return [self.note_event(message, state=state, detail={"line": raw})]
|
||||
|
||||
@@ -418,7 +414,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
resume: ResumeToken | None,
|
||||
found_session: ResumeToken | None,
|
||||
) -> list[TakopiEvent]:
|
||||
_ = resume, found_session
|
||||
return translate_opencode_event(
|
||||
data,
|
||||
title=self.session_title,
|
||||
@@ -436,7 +431,6 @@ class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
error: Exception,
|
||||
state: OpenCodeStreamState,
|
||||
) -> list[TakopiEvent]:
|
||||
_ = raw, line, state
|
||||
if isinstance(error, msgspec.DecodeError):
|
||||
self.get_logger().warning(
|
||||
"jsonl.msgspec.invalid",
|
||||
|
||||
@@ -290,7 +290,6 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: PiStreamState,
|
||||
) -> list[str]:
|
||||
_ = resume
|
||||
args: list[str] = [*self.extra_args, "--print", "--mode", "json"]
|
||||
if self.provider:
|
||||
args.extend(["--provider", self.provider])
|
||||
@@ -307,18 +306,15 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
*,
|
||||
state: PiStreamState,
|
||||
) -> bytes | None:
|
||||
_ = prompt, resume, state
|
||||
return None
|
||||
|
||||
def env(self, *, state: PiStreamState) -> dict[str, str] | None:
|
||||
_ = state
|
||||
env = dict(os.environ)
|
||||
env.setdefault("NO_COLOR", "1")
|
||||
env.setdefault("CI", "1")
|
||||
return env
|
||||
|
||||
def new_state(self, prompt: str, resume: ResumeToken | None) -> PiStreamState:
|
||||
_ = prompt
|
||||
if resume is None:
|
||||
session_path = self._new_session_path()
|
||||
token = ResumeToken(engine=ENGINE, value=session_path)
|
||||
@@ -334,7 +330,6 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
resume: ResumeToken | None,
|
||||
found_session: ResumeToken | None,
|
||||
) -> list[TakopiEvent]:
|
||||
_ = resume, found_session
|
||||
meta: dict[str, Any] = {"cwd": os.getcwd()}
|
||||
if self.model:
|
||||
meta["model"] = self.model
|
||||
@@ -362,7 +357,6 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
error: Exception,
|
||||
state: PiStreamState,
|
||||
) -> list[TakopiEvent]:
|
||||
_ = raw, line, state
|
||||
if isinstance(error, msgspec.DecodeError):
|
||||
self.get_logger().warning(
|
||||
"jsonl.msgspec.invalid",
|
||||
|
||||
@@ -103,7 +103,7 @@ def build_router(
|
||||
if engine_cfg:
|
||||
try:
|
||||
runner = backend.build_runner({}, config_path)
|
||||
except Exception as fallback_exc:
|
||||
except Exception as fallback_exc: # noqa: BLE001
|
||||
warnings.append(f"{engine_id}: {issue or str(fallback_exc)}")
|
||||
continue
|
||||
status = "bad_config"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Iterable, Literal
|
||||
from typing import Annotated, Any, ClassVar, Iterable, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -62,6 +62,9 @@ class TelegramTopicsSettings(BaseModel):
|
||||
class TelegramFilesSettings(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
|
||||
|
||||
max_upload_bytes: ClassVar[int] = 20 * 1024 * 1024
|
||||
max_download_bytes: ClassVar[int] = 50 * 1024 * 1024
|
||||
|
||||
enabled: bool = False
|
||||
auto_put: bool = True
|
||||
auto_put_mode: Literal["upload", "prompt"] = "upload"
|
||||
@@ -84,14 +87,6 @@ class TelegramFilesSettings(BaseModel):
|
||||
raise ValueError("files.uploads_dir must be a relative path")
|
||||
return value
|
||||
|
||||
@property
|
||||
def max_upload_bytes(self) -> int:
|
||||
return 20 * 1024 * 1024
|
||||
|
||||
@property
|
||||
def max_download_bytes(self) -> int:
|
||||
return 50 * 1024 * 1024
|
||||
|
||||
|
||||
class TelegramTransportSettings(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
import msgspec
|
||||
|
||||
from ..logging import get_logger
|
||||
from ..model import ResumeToken
|
||||
from .state_store import JsonStateStore
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -38,13 +36,20 @@ def _chat_key(chat_id: int, owner_id: int | None) -> str:
|
||||
return f"{chat_id}:{owner}"
|
||||
|
||||
|
||||
class ChatSessionStore:
|
||||
def _new_state() -> _ChatSessionsState:
|
||||
return _ChatSessionsState(version=STATE_VERSION, chats={})
|
||||
|
||||
|
||||
class ChatSessionStore(JsonStateStore[_ChatSessionsState]):
|
||||
def __init__(self, path: Path) -> None:
|
||||
self._path = path
|
||||
self._lock = anyio.Lock()
|
||||
self._loaded = False
|
||||
self._mtime_ns: int | None = None
|
||||
self._state = _ChatSessionsState(version=STATE_VERSION, chats={})
|
||||
super().__init__(
|
||||
path,
|
||||
version=STATE_VERSION,
|
||||
state_type=_ChatSessionsState,
|
||||
state_factory=_new_state,
|
||||
log_prefix="telegram.chat_sessions",
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
async def get_session_resume(
|
||||
self, chat_id: int, owner_id: int | None, engine: str
|
||||
@@ -77,58 +82,6 @@ class ChatSessionStore:
|
||||
chat.sessions = {}
|
||||
self._save_locked()
|
||||
|
||||
def _stat_mtime_ns(self) -> int | None:
|
||||
try:
|
||||
return self._path.stat().st_mtime_ns
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def _reload_locked_if_needed(self) -> None:
|
||||
current = self._stat_mtime_ns()
|
||||
if self._loaded and current == self._mtime_ns:
|
||||
return
|
||||
self._load_locked()
|
||||
|
||||
def _load_locked(self) -> None:
|
||||
self._loaded = True
|
||||
self._mtime_ns = self._stat_mtime_ns()
|
||||
if self._mtime_ns is None:
|
||||
self._state = _ChatSessionsState(version=STATE_VERSION, chats={})
|
||||
return
|
||||
try:
|
||||
payload = msgspec.json.decode(
|
||||
self._path.read_bytes(), type=_ChatSessionsState
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"telegram.chat_sessions.load_failed",
|
||||
path=str(self._path),
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
self._state = _ChatSessionsState(version=STATE_VERSION, chats={})
|
||||
return
|
||||
if payload.version != STATE_VERSION:
|
||||
logger.warning(
|
||||
"telegram.chat_sessions.version_mismatch",
|
||||
path=str(self._path),
|
||||
version=payload.version,
|
||||
expected=STATE_VERSION,
|
||||
)
|
||||
self._state = _ChatSessionsState(version=STATE_VERSION, chats={})
|
||||
return
|
||||
self._state = payload
|
||||
|
||||
def _save_locked(self) -> None:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = msgspec.to_builtins(self._state)
|
||||
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
|
||||
with open(tmp_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(payload, handle, indent=2, sort_keys=True)
|
||||
handle.write("\n")
|
||||
os.replace(tmp_path, self._path)
|
||||
self._mtime_ns = self._stat_mtime_ns()
|
||||
|
||||
def _get_chat_locked(self, chat_id: int, owner_id: int | None) -> _ChatState | None:
|
||||
return self._state.chats.get(_chat_key(chat_id, owner_id))
|
||||
|
||||
|
||||
@@ -515,7 +515,7 @@ class TelegramOutbox:
|
||||
async def execute_op(self, op: OutboxOp) -> Any:
|
||||
try:
|
||||
return await op.execute()
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if isinstance(exc, RetryAfter):
|
||||
raise
|
||||
if self._on_error is not None:
|
||||
@@ -566,7 +566,7 @@ class TelegramOutbox:
|
||||
op.set_result(result)
|
||||
except cancel_exc:
|
||||
return
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
async with self._cond:
|
||||
self._closed = True
|
||||
self.fail_pending()
|
||||
@@ -759,7 +759,7 @@ class TelegramClient:
|
||||
retry_after: float | None = None
|
||||
try:
|
||||
response_payload = resp.json()
|
||||
except Exception:
|
||||
except Exception: # noqa: BLE001
|
||||
response_payload = None
|
||||
if isinstance(response_payload, dict):
|
||||
retry_after = retry_after_from_payload(response_payload)
|
||||
@@ -785,7 +785,7 @@ class TelegramClient:
|
||||
|
||||
try:
|
||||
response_payload = resp.json()
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
body = resp.text
|
||||
logger.error(
|
||||
"telegram.bad_response",
|
||||
@@ -815,7 +815,7 @@ class TelegramClient:
|
||||
return None
|
||||
try:
|
||||
return msgspec.convert(payload, type=model)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(
|
||||
"telegram.decode_error",
|
||||
method=method,
|
||||
@@ -862,7 +862,7 @@ class TelegramClient:
|
||||
return None
|
||||
try:
|
||||
return msgspec.convert(raw, type=list[Update])
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(
|
||||
"telegram.decode_error",
|
||||
method="getUpdates",
|
||||
@@ -881,7 +881,7 @@ class TelegramClient:
|
||||
return None
|
||||
try:
|
||||
return msgspec.convert(result, type=list[Update])
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(
|
||||
"telegram.decode_error",
|
||||
method="getUpdates",
|
||||
@@ -928,7 +928,7 @@ class TelegramClient:
|
||||
retry_after: float | None = None
|
||||
try:
|
||||
response_payload = resp.json()
|
||||
except Exception:
|
||||
except Exception: # noqa: BLE001
|
||||
response_payload = None
|
||||
if isinstance(response_payload, dict):
|
||||
retry_after = retry_after_from_payload(response_payload)
|
||||
|
||||
+50
-120
@@ -37,7 +37,7 @@ from ..scheduler import ThreadScheduler
|
||||
from ..transport import MessageRef, RenderedMessage, SendOptions
|
||||
from ..transport_runtime import ResolvedMessage, TransportRuntime
|
||||
from ..utils.paths import reset_run_base_dir, set_run_base_dir
|
||||
from .bridge import send_plain
|
||||
from .bridge import TelegramBridgeConfig, send_plain
|
||||
from .chat_sessions import ChatSessionStore
|
||||
from .context import (
|
||||
_format_context,
|
||||
@@ -203,13 +203,25 @@ def _reserved_commands(runtime: TransportRuntime) -> set[str]:
|
||||
}
|
||||
|
||||
|
||||
async def _set_command_menu(cfg) -> None:
|
||||
def _reply_sender(
|
||||
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
return partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
|
||||
|
||||
async def _set_command_menu(cfg: TelegramBridgeConfig) -> None:
|
||||
commands = build_bot_commands(cfg.runtime, include_file=cfg.files.enabled)
|
||||
if not commands:
|
||||
return
|
||||
try:
|
||||
ok = await cfg.bot.set_my_commands(commands)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.info(
|
||||
"startup.command_menu.failed",
|
||||
error=str(exc),
|
||||
@@ -297,7 +309,7 @@ def _should_show_resume_line(
|
||||
def resolve_file_put_paths(
|
||||
plan: _FilePutPlan,
|
||||
*,
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
require_dir: bool,
|
||||
) -> tuple[Path | None, Path | None, str | None]:
|
||||
path_value = plan.path_value
|
||||
@@ -322,14 +334,10 @@ def resolve_file_put_paths(
|
||||
return None, rel_path, None
|
||||
|
||||
|
||||
async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
async def _check_file_permissions(
|
||||
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||
) -> bool:
|
||||
reply = _reply_sender(cfg, msg)
|
||||
sender_id = msg.sender_id
|
||||
if sender_id is None:
|
||||
await reply(text="cannot verify sender for file transfer.")
|
||||
@@ -355,19 +363,13 @@ async def _check_file_permissions(cfg, msg: TelegramIncomingMessage) -> bool:
|
||||
|
||||
|
||||
async def _prepare_file_put_plan(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> _FilePutPlan | None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
if not await _check_file_permissions(cfg, msg):
|
||||
return None
|
||||
try:
|
||||
@@ -423,7 +425,7 @@ def _format_file_put_failures(failed: Sequence[_FilePutResult]) -> str | None:
|
||||
|
||||
|
||||
async def _save_document_payload(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
*,
|
||||
document: TelegramDocument,
|
||||
run_root: Path,
|
||||
@@ -524,19 +526,13 @@ async def _save_document_payload(
|
||||
|
||||
|
||||
async def _handle_file_command(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
command, rest, error = parse_file_command(args_text)
|
||||
if error is not None:
|
||||
await reply(text=error)
|
||||
@@ -548,7 +544,7 @@ async def _handle_file_command(
|
||||
|
||||
|
||||
async def _handle_file_put_default(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
@@ -557,19 +553,13 @@ async def _handle_file_put_default(
|
||||
|
||||
|
||||
async def _save_file_put(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> _SavedFilePut | None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
document = msg.document
|
||||
if document is None:
|
||||
await reply(text=FILE_PUT_USAGE)
|
||||
@@ -613,19 +603,13 @@ async def _save_file_put(
|
||||
|
||||
|
||||
async def _handle_file_put(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
saved = await _save_file_put(
|
||||
cfg,
|
||||
msg,
|
||||
@@ -645,20 +629,14 @@ async def _handle_file_put(
|
||||
|
||||
|
||||
async def _handle_file_put_group(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
messages: Sequence[TelegramIncomingMessage],
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
saved_group = await _save_file_put_group(
|
||||
cfg,
|
||||
msg,
|
||||
@@ -700,20 +678,14 @@ async def _handle_file_put_group(
|
||||
|
||||
|
||||
async def _save_file_put_group(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
messages: Sequence[TelegramIncomingMessage],
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> _SavedFilePutGroup | None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
documents = [item.document for item in messages if item.document is not None]
|
||||
if not documents:
|
||||
await reply(text=FILE_PUT_USAGE)
|
||||
@@ -759,7 +731,7 @@ async def _save_file_put_group(
|
||||
|
||||
|
||||
async def _handle_media_group(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
messages: Sequence[TelegramIncomingMessage],
|
||||
topic_store: TopicStateStore | None,
|
||||
run_prompt: Callable[
|
||||
@@ -779,13 +751,7 @@ async def _handle_media_group(
|
||||
(item for item in ordered if item.text.strip()),
|
||||
ordered[0],
|
||||
)
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=command_msg.chat_id,
|
||||
user_msg_id=command_msg.message_id,
|
||||
thread_id=command_msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, command_msg)
|
||||
topic_key = _topic_key(command_msg, cfg) if topic_store is not None else None
|
||||
chat_project = _topics_chat_project(cfg, command_msg.chat_id)
|
||||
bound_context = (
|
||||
@@ -884,19 +850,13 @@ async def _handle_media_group(
|
||||
|
||||
|
||||
async def _handle_file_get(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
if not await _check_file_permissions(cfg, msg):
|
||||
return
|
||||
try:
|
||||
@@ -989,7 +949,7 @@ async def _handle_file_get(
|
||||
|
||||
|
||||
async def _handle_ctx_command(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
store: TopicStateStore,
|
||||
@@ -997,13 +957,7 @@ async def _handle_ctx_command(
|
||||
resolved_scope: str | None = None,
|
||||
scope_chat_ids: frozenset[int] | None = None,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
error = _topics_command_error(
|
||||
cfg,
|
||||
msg.chat_id,
|
||||
@@ -1081,20 +1035,14 @@ async def _handle_ctx_command(
|
||||
|
||||
|
||||
async def _handle_new_command(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
store: TopicStateStore,
|
||||
*,
|
||||
resolved_scope: str | None = None,
|
||||
scope_chat_ids: frozenset[int] | None = None,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
error = _topics_command_error(
|
||||
cfg,
|
||||
msg.chat_id,
|
||||
@@ -1113,18 +1061,12 @@ async def _handle_new_command(
|
||||
|
||||
|
||||
async def _handle_chat_new_command(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
store: ChatSessionStore,
|
||||
session_key: tuple[int, int | None] | None,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
if session_key is None:
|
||||
await reply(text="no stored sessions to clear for this chat.")
|
||||
return
|
||||
@@ -1137,7 +1079,7 @@ async def _handle_chat_new_command(
|
||||
|
||||
|
||||
async def _handle_topic_command(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
store: TopicStateStore,
|
||||
@@ -1145,13 +1087,7 @@ async def _handle_topic_command(
|
||||
resolved_scope: str | None = None,
|
||||
scope_chat_ids: frozenset[int] | None = None,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
error = _topics_command_error(
|
||||
cfg,
|
||||
msg.chat_id,
|
||||
@@ -1203,17 +1139,11 @@ async def _handle_topic_command(
|
||||
|
||||
|
||||
async def handle_cancel(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
running_tasks: RunningTasks,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = _reply_sender(cfg, msg)
|
||||
chat_id = msg.chat_id
|
||||
reply_id = msg.reply_to_message_id
|
||||
|
||||
@@ -1239,7 +1169,7 @@ async def handle_cancel(
|
||||
|
||||
|
||||
async def handle_callback_cancel(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
query: TelegramCallbackQuery,
|
||||
running_tasks: RunningTasks,
|
||||
) -> None:
|
||||
@@ -1572,7 +1502,7 @@ class _TelegramCommandExecutor(CommandExecutor):
|
||||
|
||||
|
||||
async def _dispatch_command(
|
||||
cfg,
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
text: str,
|
||||
command_id: str,
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generic, Protocol, TypeVar
|
||||
|
||||
import anyio
|
||||
import msgspec
|
||||
|
||||
T = TypeVar("T", bound="_VersionedState")
|
||||
|
||||
|
||||
class _Logger(Protocol):
|
||||
def warning(self, event: str, **fields: Any) -> None: ...
|
||||
|
||||
|
||||
class _VersionedState(Protocol):
|
||||
version: int
|
||||
|
||||
|
||||
class JsonStateStore(Generic[T]):
|
||||
def __init__(
|
||||
self,
|
||||
path: Path,
|
||||
*,
|
||||
version: int,
|
||||
state_type: type[T],
|
||||
state_factory: Callable[[], T],
|
||||
log_prefix: str,
|
||||
logger: _Logger,
|
||||
) -> None:
|
||||
self._path = path
|
||||
self._lock = anyio.Lock()
|
||||
self._loaded = False
|
||||
self._mtime_ns: int | None = None
|
||||
self._state_type = state_type
|
||||
self._state_factory = state_factory
|
||||
self._version = version
|
||||
self._log_prefix = log_prefix
|
||||
self._logger = logger
|
||||
self._state = state_factory()
|
||||
|
||||
def _stat_mtime_ns(self) -> int | None:
|
||||
try:
|
||||
return self._path.stat().st_mtime_ns
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def _reload_locked_if_needed(self) -> None:
|
||||
current = self._stat_mtime_ns()
|
||||
if self._loaded and current == self._mtime_ns:
|
||||
return
|
||||
self._load_locked()
|
||||
|
||||
def _load_locked(self) -> None:
|
||||
self._loaded = True
|
||||
self._mtime_ns = self._stat_mtime_ns()
|
||||
if self._mtime_ns is None:
|
||||
self._state = self._state_factory()
|
||||
return
|
||||
try:
|
||||
payload = msgspec.json.decode(
|
||||
self._path.read_bytes(), type=self._state_type
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._logger.warning(
|
||||
f"{self._log_prefix}.load_failed",
|
||||
path=str(self._path),
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
self._state = self._state_factory()
|
||||
return
|
||||
if payload.version != self._version:
|
||||
self._logger.warning(
|
||||
f"{self._log_prefix}.version_mismatch",
|
||||
path=str(self._path),
|
||||
version=payload.version,
|
||||
expected=self._version,
|
||||
)
|
||||
self._state = self._state_factory()
|
||||
return
|
||||
self._state = payload
|
||||
|
||||
def _save_locked(self) -> None:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = msgspec.to_builtins(self._state)
|
||||
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
|
||||
with open(tmp_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(payload, handle, indent=2, sort_keys=True)
|
||||
handle.write("\n")
|
||||
os.replace(tmp_path, self._path)
|
||||
self._mtime_ns = self._stat_mtime_ns()
|
||||
@@ -1,16 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
import msgspec
|
||||
|
||||
from ..context import RunContext
|
||||
from ..logging import get_logger
|
||||
from ..model import ResumeToken
|
||||
from .state_store import JsonStateStore
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -82,13 +80,20 @@ def _context_to_state(context: RunContext | None) -> _ContextState | None:
|
||||
return _ContextState(project=project, branch=branch)
|
||||
|
||||
|
||||
class TopicStateStore:
|
||||
def _new_state() -> _TopicState:
|
||||
return _TopicState(version=STATE_VERSION, threads={})
|
||||
|
||||
|
||||
class TopicStateStore(JsonStateStore[_TopicState]):
|
||||
def __init__(self, path: Path) -> None:
|
||||
self._path = path
|
||||
self._lock = anyio.Lock()
|
||||
self._loaded = False
|
||||
self._mtime_ns: int | None = None
|
||||
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||
super().__init__(
|
||||
path,
|
||||
version=STATE_VERSION,
|
||||
state_type=_TopicState,
|
||||
state_factory=_new_state,
|
||||
log_prefix="telegram.topic_state",
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
async def get_thread(
|
||||
self, chat_id: int, thread_id: int
|
||||
@@ -202,56 +207,6 @@ class TopicStateStore:
|
||||
topic_title=thread.topic_title,
|
||||
)
|
||||
|
||||
def _stat_mtime_ns(self) -> int | None:
|
||||
try:
|
||||
return self._path.stat().st_mtime_ns
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def _reload_locked_if_needed(self) -> None:
|
||||
current = self._stat_mtime_ns()
|
||||
if self._loaded and current == self._mtime_ns:
|
||||
return
|
||||
self._load_locked()
|
||||
|
||||
def _load_locked(self) -> None:
|
||||
self._loaded = True
|
||||
self._mtime_ns = self._stat_mtime_ns()
|
||||
if self._mtime_ns is None:
|
||||
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||
return
|
||||
try:
|
||||
payload = msgspec.json.decode(self._path.read_bytes(), type=_TopicState)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"telegram.topic_state.load_failed",
|
||||
path=str(self._path),
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||
return
|
||||
if payload.version != STATE_VERSION:
|
||||
logger.warning(
|
||||
"telegram.topic_state.version_mismatch",
|
||||
path=str(self._path),
|
||||
version=payload.version,
|
||||
expected=STATE_VERSION,
|
||||
)
|
||||
self._state = _TopicState(version=STATE_VERSION, threads={})
|
||||
return
|
||||
self._state = payload
|
||||
|
||||
def _save_locked(self) -> None:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = msgspec.to_builtins(self._state)
|
||||
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
|
||||
with open(tmp_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(payload, handle, indent=2, sort_keys=True)
|
||||
handle.write("\n")
|
||||
os.replace(tmp_path, self._path)
|
||||
self._mtime_ns = self._stat_mtime_ns()
|
||||
|
||||
def _get_thread_locked(self, chat_id: int, thread_id: int) -> _ThreadState | None:
|
||||
return self._state.threads.get(_thread_key(chat_id, thread_id))
|
||||
|
||||
|
||||
+112
-60
@@ -3,17 +3,30 @@ from __future__ import annotations
|
||||
from collections.abc import Iterable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from .config import ConfigError, ProjectsConfig
|
||||
from .context import RunContext
|
||||
from .directives import format_context_line, parse_context_line, parse_directives
|
||||
from .directives import (
|
||||
ParsedDirectives,
|
||||
format_context_line,
|
||||
parse_context_line,
|
||||
parse_directives,
|
||||
)
|
||||
from .model import EngineId, ResumeToken
|
||||
from .plugins import normalize_allowlist
|
||||
from .router import AutoRouter, EngineStatus
|
||||
from .runner import Runner
|
||||
from .worktrees import WorktreeError, resolve_run_cwd
|
||||
|
||||
ContextSource: TypeAlias = Literal[
|
||||
"reply_ctx",
|
||||
"directives",
|
||||
"ambient",
|
||||
"default_project",
|
||||
"none",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ResolvedMessage:
|
||||
@@ -21,13 +34,7 @@ class ResolvedMessage:
|
||||
resume_token: ResumeToken | None
|
||||
engine_override: EngineId | None
|
||||
context: RunContext | None
|
||||
context_source: Literal[
|
||||
"reply_ctx",
|
||||
"directives",
|
||||
"ambient",
|
||||
"default_project",
|
||||
"none",
|
||||
] = "none"
|
||||
context_source: ContextSource = "none"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@@ -58,12 +65,14 @@ class TransportRuntime:
|
||||
plugin_configs: Mapping[str, Any] | None = None,
|
||||
watch_config: bool = False,
|
||||
) -> None:
|
||||
self._router = router
|
||||
self._projects = projects
|
||||
self._allowlist = normalize_allowlist(allowlist)
|
||||
self._config_path = config_path
|
||||
self._plugin_configs = dict(plugin_configs or {})
|
||||
self._watch_config = watch_config
|
||||
self._apply(
|
||||
router=router,
|
||||
projects=projects,
|
||||
allowlist=allowlist,
|
||||
config_path=config_path,
|
||||
plugin_configs=plugin_configs,
|
||||
watch_config=watch_config,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -74,6 +83,25 @@ class TransportRuntime:
|
||||
config_path: Path | None = None,
|
||||
plugin_configs: Mapping[str, Any] | None = None,
|
||||
watch_config: bool = False,
|
||||
) -> None:
|
||||
self._apply(
|
||||
router=router,
|
||||
projects=projects,
|
||||
allowlist=allowlist,
|
||||
config_path=config_path,
|
||||
plugin_configs=plugin_configs,
|
||||
watch_config=watch_config,
|
||||
)
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
*,
|
||||
router: AutoRouter,
|
||||
projects: ProjectsConfig,
|
||||
allowlist: Iterable[str] | None,
|
||||
config_path: Path | None,
|
||||
plugin_configs: Mapping[str, Any] | None,
|
||||
watch_config: bool,
|
||||
) -> None:
|
||||
self._router = router
|
||||
self._projects = projects
|
||||
@@ -162,51 +190,16 @@ class TransportRuntime:
|
||||
chat_project = self._projects.project_for_chat(chat_id)
|
||||
default_project = chat_project or self._projects.default_project
|
||||
|
||||
context_source: Literal[
|
||||
"reply_ctx",
|
||||
"directives",
|
||||
"ambient",
|
||||
"default_project",
|
||||
"none",
|
||||
] = "none"
|
||||
context: RunContext | None = None
|
||||
|
||||
if reply_ctx is not None:
|
||||
context = reply_ctx
|
||||
context_source = "reply_ctx"
|
||||
else:
|
||||
project_key = directives.project
|
||||
branch = directives.branch
|
||||
if project_key is None:
|
||||
if ambient_context is not None and ambient_context.project is not None:
|
||||
project_key = ambient_context.project
|
||||
else:
|
||||
project_key = default_project
|
||||
if branch is None:
|
||||
if (
|
||||
ambient_context is not None
|
||||
and ambient_context.branch is not None
|
||||
and project_key == ambient_context.project
|
||||
):
|
||||
branch = ambient_context.branch
|
||||
if project_key is not None or branch is not None:
|
||||
context = RunContext(project=project_key, branch=branch)
|
||||
if directives.project is not None or directives.branch is not None:
|
||||
context_source = "directives"
|
||||
elif ambient_context is not None and ambient_context.project is not None:
|
||||
context_source = "ambient"
|
||||
elif default_project is not None:
|
||||
context_source = "default_project"
|
||||
|
||||
engine_override = directives.engine
|
||||
if engine_override is None and context is not None:
|
||||
project = (
|
||||
self._projects.projects.get(context.project)
|
||||
if context.project is not None
|
||||
else None
|
||||
)
|
||||
if project is not None and project.default_engine is not None:
|
||||
engine_override = project.default_engine
|
||||
context, context_source = self._resolve_context(
|
||||
directives=directives,
|
||||
reply_ctx=reply_ctx,
|
||||
ambient_context=ambient_context,
|
||||
default_project=default_project,
|
||||
)
|
||||
engine_override = self._resolve_engine_override(
|
||||
directives_engine=directives.engine,
|
||||
context=context,
|
||||
)
|
||||
|
||||
return ResolvedMessage(
|
||||
prompt=directives.prompt,
|
||||
@@ -216,6 +209,65 @@ class TransportRuntime:
|
||||
context_source=context_source,
|
||||
)
|
||||
|
||||
def _resolve_context(
|
||||
self,
|
||||
*,
|
||||
directives: ParsedDirectives,
|
||||
reply_ctx: RunContext | None,
|
||||
ambient_context: RunContext | None,
|
||||
default_project: str | None,
|
||||
) -> tuple[RunContext | None, ContextSource]:
|
||||
if reply_ctx is not None:
|
||||
return reply_ctx, "reply_ctx"
|
||||
|
||||
project_key = directives.project
|
||||
branch = directives.branch
|
||||
if project_key is None:
|
||||
if ambient_context is not None and ambient_context.project is not None:
|
||||
project_key = ambient_context.project
|
||||
else:
|
||||
project_key = default_project
|
||||
if branch is None:
|
||||
if (
|
||||
ambient_context is not None
|
||||
and ambient_context.branch is not None
|
||||
and project_key == ambient_context.project
|
||||
):
|
||||
branch = ambient_context.branch
|
||||
context: RunContext | None = None
|
||||
if project_key is not None or branch is not None:
|
||||
context = RunContext(project=project_key, branch=branch)
|
||||
|
||||
if directives.project is not None or directives.branch is not None:
|
||||
context_source: ContextSource = "directives"
|
||||
elif ambient_context is not None and ambient_context.project is not None:
|
||||
context_source = "ambient"
|
||||
elif default_project is not None:
|
||||
context_source = "default_project"
|
||||
else:
|
||||
context_source = "none"
|
||||
|
||||
return context, context_source
|
||||
|
||||
def _resolve_engine_override(
|
||||
self,
|
||||
*,
|
||||
directives_engine: EngineId | None,
|
||||
context: RunContext | None,
|
||||
) -> EngineId | None:
|
||||
if directives_engine is not None:
|
||||
return directives_engine
|
||||
if context is None:
|
||||
return None
|
||||
project = (
|
||||
self._projects.projects.get(context.project)
|
||||
if context.project is not None
|
||||
else None
|
||||
)
|
||||
if project is not None and project.default_engine is not None:
|
||||
return project.default_engine
|
||||
return None
|
||||
|
||||
@property
|
||||
def default_project(self) -> str | None:
|
||||
return self._projects.default_project
|
||||
|
||||
@@ -5,14 +5,7 @@ from pathlib import Path
|
||||
from typing import Iterable, Protocol, runtime_checkable
|
||||
|
||||
from .backends import EngineBackend, SetupIssue
|
||||
from .config import ConfigError
|
||||
from .plugins import (
|
||||
PluginLoadFailed,
|
||||
PluginNotFound,
|
||||
TRANSPORT_GROUP,
|
||||
load_entrypoint,
|
||||
list_ids,
|
||||
)
|
||||
from .plugins import TRANSPORT_GROUP, list_ids, load_plugin_backend
|
||||
from .transport_runtime import TransportRuntime
|
||||
|
||||
|
||||
@@ -67,22 +60,14 @@ def _validate_transport_backend(backend: object, ep) -> None:
|
||||
def get_transport(
|
||||
transport_id: str, *, allowlist: Iterable[str] | None = None
|
||||
) -> TransportBackend:
|
||||
try:
|
||||
backend = load_entrypoint(
|
||||
TRANSPORT_GROUP,
|
||||
transport_id,
|
||||
allowlist=allowlist,
|
||||
validator=_validate_transport_backend,
|
||||
)
|
||||
except PluginNotFound as exc:
|
||||
if exc.available:
|
||||
available = ", ".join(exc.available)
|
||||
message = f"Unknown transport {transport_id!r}. Available: {available}."
|
||||
else:
|
||||
message = f"Unknown transport {transport_id!r}."
|
||||
raise ConfigError(message) from exc
|
||||
except PluginLoadFailed as exc:
|
||||
raise ConfigError(f"Failed to load transport {transport_id!r}: {exc}") from exc
|
||||
backend = load_plugin_backend(
|
||||
TRANSPORT_GROUP,
|
||||
transport_id,
|
||||
allowlist=allowlist,
|
||||
validator=_validate_transport_backend,
|
||||
kind_label="transport",
|
||||
)
|
||||
assert backend is not None
|
||||
return backend
|
||||
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ async def drain_stderr(
|
||||
tag=tag,
|
||||
line=text,
|
||||
)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log_pipeline(
|
||||
logger,
|
||||
"subprocess.stderr.error",
|
||||
|
||||
@@ -53,7 +53,7 @@ def _signal_process(
|
||||
return
|
||||
except ProcessLookupError:
|
||||
return
|
||||
except Exception as exc:
|
||||
except OSError as exc:
|
||||
logger.debug(
|
||||
log_event,
|
||||
error=str(exc),
|
||||
|
||||
@@ -20,7 +20,7 @@ def _decode_fixture(name: str) -> list[str]:
|
||||
continue
|
||||
try:
|
||||
decoded = claude_schema.decode_stream_json_line(line)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}")
|
||||
continue
|
||||
|
||||
|
||||
@@ -21,12 +21,12 @@ def _decode_fixture(name: str) -> list[str]:
|
||||
continue
|
||||
try:
|
||||
json.loads(line)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(f"line {lineno}: invalid JSON ({exc})")
|
||||
continue
|
||||
try:
|
||||
codex_schema.decode_event(line)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}")
|
||||
|
||||
return errors
|
||||
|
||||
@@ -20,7 +20,7 @@ def _decode_fixture(name: str) -> list[str]:
|
||||
continue
|
||||
try:
|
||||
opencode_schema.decode_event(line)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}")
|
||||
|
||||
return errors
|
||||
|
||||
@@ -20,7 +20,7 @@ def _decode_fixture(name: str) -> list[str]:
|
||||
continue
|
||||
try:
|
||||
pi_schema.decode_event(line)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(f"line {lineno}: {exc.__class__.__name__}: {exc}")
|
||||
|
||||
return errors
|
||||
|
||||
@@ -90,7 +90,7 @@ def test_projects_default_engine_unknown() -> None:
|
||||
"projects": {"z80": {"path": "/tmp/repo", "default_engine": "nope"}},
|
||||
}
|
||||
settings = TakopiSettings.model_validate(config)
|
||||
with pytest.raises(ConfigError, match="projects.z80.default_engine"):
|
||||
with pytest.raises(ConfigError, match=r"projects\.z80\.default_engine"):
|
||||
settings.to_projects_config(
|
||||
config_path=Path("takopi.toml"),
|
||||
engine_ids=["codex"],
|
||||
|
||||
@@ -172,7 +172,7 @@ def test_transport_config_telegram_and_extra(tmp_path: Path) -> None:
|
||||
},
|
||||
}
|
||||
)
|
||||
with pytest.raises(ConfigError, match="transports.discord"):
|
||||
with pytest.raises(ConfigError, match=r"transports\.discord"):
|
||||
settings.transport_config("discord", config_path=config_path)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user