refactor: telegram modules and tighten linting (#111)
This commit is contained in:
@@ -408,7 +408,7 @@ from ..schemas import acme as acme_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENGINE: EngineId = EngineId("acme")
|
||||
ENGINE: EngineId = "acme"
|
||||
_RESUME_RE = re.compile(
|
||||
r"(?im)^\s*`?acme\s+--resume\s+(?P<token>[^`\s]+)`?\s*$"
|
||||
)
|
||||
|
||||
+1
-1
@@ -65,7 +65,7 @@ addopts = ["--cov=takopi", "--cov-report=term-missing", "--cov-fail-under=75"]
|
||||
testpaths = ["tests"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
extend-select = ["B904", "BLE001", "S110", "RUF043"]
|
||||
extend-select = ["B", "BLE001", "C4", "PERF", "RUF043", "S110", "SIM", "UP"]
|
||||
|
||||
[tool.ty.src]
|
||||
include = ["src", "tests"]
|
||||
|
||||
+2
-5
@@ -256,7 +256,7 @@ def _run_auto_router(
|
||||
)
|
||||
lock_token = transport_backend.lock_token(
|
||||
transport_config=transport_config,
|
||||
config_path=config_path,
|
||||
_config_path=config_path,
|
||||
)
|
||||
lock_handle = acquire_config_lock(config_path, lock_token)
|
||||
runtime = spec.to_runtime(config_path=config_path)
|
||||
@@ -301,10 +301,7 @@ def _default_alias_from_path(path: Path) -> str | None:
|
||||
|
||||
|
||||
def _ensure_projects_table(config: dict, config_path: Path) -> dict:
|
||||
projects = config.get("projects")
|
||||
if projects is None:
|
||||
projects = {}
|
||||
config["projects"] = projects
|
||||
projects = config.setdefault("projects", {})
|
||||
if not isinstance(projects, dict):
|
||||
raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.")
|
||||
return projects
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable, Iterable
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
|
||||
from watchfiles import awatch
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
from collections.abc import Iterable
|
||||
|
||||
from .backends import EngineBackend
|
||||
from .config import ConfigError
|
||||
|
||||
@@ -11,7 +11,7 @@ from .logging import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class LockInfo:
|
||||
pid: int | None
|
||||
token_fingerprint: str | None
|
||||
@@ -29,7 +29,7 @@ class LockError(RuntimeError):
|
||||
super().__init__(_format_lock_message(path, state))
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class LockHandle:
|
||||
path: Path
|
||||
|
||||
@@ -44,7 +44,7 @@ class LockHandle:
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
|
||||
def __enter__(self) -> "LockHandle":
|
||||
def __enter__(self) -> LockHandle:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
|
||||
@@ -107,9 +107,8 @@ def _redact_value(value: Any, memo: dict[int, Any]) -> Any:
|
||||
|
||||
|
||||
def _redact_event_dict(
|
||||
logger: Any, method_name: str, event_dict: dict[str, Any]
|
||||
_logger: Any, _method_name: str, event_dict: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
_ = logger, method_name
|
||||
return _redact_value(event_dict, memo={})
|
||||
|
||||
|
||||
@@ -222,10 +221,7 @@ def setup_logging(
|
||||
|
||||
format_value = os.environ.get("TAKOPI_LOG_FORMAT", "console").strip().lower()
|
||||
color_override = os.environ.get("TAKOPI_LOG_COLOR")
|
||||
if color_override is None:
|
||||
is_tty = sys.stdout.isatty()
|
||||
else:
|
||||
is_tty = _truthy(color_override)
|
||||
is_tty = sys.stdout.isatty() if color_override is None else _truthy(color_override)
|
||||
if format_value == "json":
|
||||
renderer: Any = structlog.processors.JSONRenderer(default=str)
|
||||
else:
|
||||
@@ -242,7 +238,9 @@ def setup_logging(
|
||||
_log_file_handle = None
|
||||
if log_file:
|
||||
try:
|
||||
_log_file_handle = open(log_file, "a", encoding="utf-8")
|
||||
_log_file_handle = open( # noqa: SIM115
|
||||
log_file, "a", encoding="utf-8"
|
||||
)
|
||||
except OSError:
|
||||
_log_file_handle = None
|
||||
|
||||
|
||||
@@ -120,9 +120,12 @@ def format_file_change_title(action: Action, *, command_width: int | None) -> st
|
||||
was_relativized = relativized != fallback
|
||||
if was_relativized:
|
||||
fallback = relativized
|
||||
if fallback and not (fallback.startswith("`") and fallback.endswith("`")):
|
||||
if was_relativized or os.sep in fallback or "/" in fallback:
|
||||
fallback = f"`{fallback}`"
|
||||
if (
|
||||
fallback
|
||||
and not (fallback.startswith("`") and fallback.endswith("`"))
|
||||
and (was_relativized or os.sep in fallback or "/" in fallback)
|
||||
):
|
||||
fallback = f"`{fallback}`"
|
||||
return f"files: {shorten(fallback, command_width)}"
|
||||
|
||||
|
||||
@@ -247,10 +250,7 @@ class MarkdownFormatter:
|
||||
|
||||
def _format_actions(self, state: ProgressState) -> list[str]:
|
||||
actions = list(state.actions)
|
||||
if self.max_actions == 0:
|
||||
actions = []
|
||||
else:
|
||||
actions = actions[-self.max_actions :]
|
||||
actions = [] if self.max_actions == 0 else actions[-self.max_actions :]
|
||||
return [
|
||||
format_action_line(
|
||||
action_state.action,
|
||||
|
||||
+7
-7
@@ -3,11 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, TypeAlias
|
||||
from typing import Any, Literal
|
||||
|
||||
EngineId: TypeAlias = str
|
||||
type EngineId = str
|
||||
|
||||
ActionKind: TypeAlias = Literal[
|
||||
type ActionKind = Literal[
|
||||
"command",
|
||||
"tool",
|
||||
"file_change",
|
||||
@@ -19,14 +19,14 @@ ActionKind: TypeAlias = Literal[
|
||||
"telemetry",
|
||||
]
|
||||
|
||||
TakopiEventType: TypeAlias = Literal[
|
||||
type TakopiEventType = Literal[
|
||||
"started",
|
||||
"action",
|
||||
"completed",
|
||||
]
|
||||
|
||||
ActionPhase: TypeAlias = Literal["started", "updated", "completed"]
|
||||
ActionLevel: TypeAlias = Literal["debug", "info", "warning", "error"]
|
||||
type ActionPhase = Literal["started", "updated", "completed"]
|
||||
type ActionLevel = Literal["debug", "info", "warning", "error"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@@ -74,4 +74,4 @@ class CompletedEvent:
|
||||
usage: dict[str, Any] | None = None
|
||||
|
||||
|
||||
TakopiEvent: TypeAlias = StartedEvent | ActionEvent | CompletedEvent
|
||||
type TakopiEvent = StartedEvent | ActionEvent | CompletedEvent
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from importlib.metadata import EntryPoint, entry_points
|
||||
import re
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
|
||||
from .ids import ID_PATTERN, is_valid_id
|
||||
|
||||
@@ -80,12 +81,7 @@ def reset_plugin_state() -> None:
|
||||
|
||||
|
||||
def _select_entrypoints(group: str) -> list[EntryPoint]:
|
||||
eps = entry_points()
|
||||
if hasattr(eps, "select"):
|
||||
return list(eps.select(group=group))
|
||||
if isinstance(eps, Mapping):
|
||||
return list(eps.get(group, []))
|
||||
return []
|
||||
return list(entry_points().select(group=group))
|
||||
|
||||
|
||||
def entrypoint_distribution_name(ep: EntryPoint) -> str | None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
from .model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal, TypeAlias
|
||||
from typing import Literal
|
||||
from collections.abc import Iterable
|
||||
|
||||
from .model import EngineId, ResumeToken
|
||||
from .runner import Runner
|
||||
@@ -17,7 +18,7 @@ class RunnerUnavailableError(RuntimeError):
|
||||
self.issue = issue
|
||||
|
||||
|
||||
EngineStatus: TypeAlias = Literal["ok", "missing_cli", "bad_config", "load_error"]
|
||||
type EngineStatus = Literal["ok", "missing_cli", "bad_config", "load_error"]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
|
||||
@@ -37,12 +37,9 @@ def _log_runner_event(evt: TakopiEvent) -> None:
|
||||
|
||||
|
||||
def _strip_resume_lines(text: str, *, is_resume_line: Callable[[str], bool]) -> str:
|
||||
stripped_lines: list[str] = []
|
||||
for line in text.splitlines():
|
||||
if is_resume_line(line):
|
||||
continue
|
||||
stripped_lines.append(line)
|
||||
prompt = "\n".join(stripped_lines).strip()
|
||||
prompt = "\n".join(
|
||||
line for line in text.splitlines() if not is_resume_line(line)
|
||||
).strip()
|
||||
return prompt or "continue"
|
||||
|
||||
|
||||
@@ -83,14 +80,14 @@ class IncomingMessage:
|
||||
thread_id: ThreadId | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ExecBridgeConfig:
|
||||
transport: Transport
|
||||
presenter: Presenter
|
||||
final_notify: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class RunningTask:
|
||||
resume: ResumeToken | None = None
|
||||
resume_ready: anyio.Event = field(default_factory=anyio.Event)
|
||||
|
||||
@@ -14,11 +14,11 @@ from ..logging import get_logger
|
||||
from ..model import Action, ActionKind, EngineId, ResumeToken, TakopiEvent
|
||||
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
||||
from ..schemas import claude as claude_schema
|
||||
from ..utils.paths import relativize_command, relativize_path
|
||||
from .tool_actions import tool_input_path, tool_kind_and_title
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
ENGINE: EngineId = EngineId("claude")
|
||||
ENGINE: EngineId = "claude"
|
||||
DEFAULT_ALLOWED_TOOLS = ["Bash", "Read", "Edit", "Write"]
|
||||
|
||||
_RESUME_RE = re.compile(
|
||||
@@ -67,55 +67,10 @@ def _coerce_comma_list(value: Any) -> str | None:
|
||||
return text or None
|
||||
|
||||
|
||||
def _tool_input_path(tool_input: dict[str, Any]) -> str | None:
|
||||
for key in ("file_path", "path"):
|
||||
value = tool_input.get(key)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _tool_kind_and_title(
|
||||
name: str, tool_input: dict[str, Any]
|
||||
) -> tuple[ActionKind, str]:
|
||||
if name in {"Bash", "Shell", "KillShell"}:
|
||||
command = tool_input.get("command")
|
||||
display = relativize_command(str(command or name))
|
||||
return "command", display
|
||||
if name in {"Edit", "Write", "NotebookEdit", "MultiEdit"}:
|
||||
path = _tool_input_path(tool_input)
|
||||
if path:
|
||||
return "file_change", relativize_path(str(path))
|
||||
return "file_change", str(name)
|
||||
if name == "Read":
|
||||
path = _tool_input_path(tool_input)
|
||||
if path:
|
||||
return "tool", f"read: `{relativize_path(str(path))}`"
|
||||
return "tool", "read"
|
||||
if name == "Glob":
|
||||
pattern = tool_input.get("pattern")
|
||||
if pattern:
|
||||
return "tool", f"glob: `{pattern}`"
|
||||
return "tool", "glob"
|
||||
if name == "Grep":
|
||||
pattern = tool_input.get("pattern")
|
||||
if pattern:
|
||||
return "tool", f"grep: {pattern}"
|
||||
return "tool", "grep"
|
||||
if name == "WebSearch":
|
||||
query = tool_input.get("query")
|
||||
return "web_search", str(query or "search")
|
||||
if name == "WebFetch":
|
||||
url = tool_input.get("url")
|
||||
return "web_search", str(url or "fetch")
|
||||
if name in {"TodoWrite", "TodoRead"}:
|
||||
return "note", "update todos" if name == "TodoWrite" else "read todos"
|
||||
if name == "AskUserQuestion":
|
||||
return "note", "ask user"
|
||||
if name in {"Task", "Agent"}:
|
||||
desc = tool_input.get("description") or tool_input.get("prompt")
|
||||
return "subagent", str(desc or name)
|
||||
return "tool", name
|
||||
return tool_kind_and_title(name, tool_input, path_keys=("file_path", "path"))
|
||||
|
||||
|
||||
def _tool_action(
|
||||
@@ -137,7 +92,7 @@ def _tool_action(
|
||||
detail["parent_tool_use_id"] = parent_tool_use_id
|
||||
|
||||
if kind == "file_change":
|
||||
path = _tool_input_path(tool_input)
|
||||
path = tool_input_path(tool_input, path_keys=("file_path", "path"))
|
||||
if path:
|
||||
detail["changes"] = [{"path": path, "kind": "update"}]
|
||||
|
||||
@@ -321,7 +276,7 @@ def translate_claude_event(
|
||||
return []
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
engine: EngineId = ENGINE
|
||||
resume_re: re.Pattern[str] = _RESUME_RE
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..utils.paths import relativize_command
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
ENGINE: EngineId = EngineId("codex")
|
||||
ENGINE: EngineId = "codex"
|
||||
|
||||
__all__ = [
|
||||
"ENGINE",
|
||||
|
||||
@@ -4,7 +4,6 @@ import re
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import TypeAlias
|
||||
|
||||
import anyio
|
||||
|
||||
@@ -18,7 +17,7 @@ from ..model import (
|
||||
)
|
||||
from ..runner import ResumeTokenMixin, Runner, SessionLockMixin
|
||||
|
||||
ENGINE: EngineId = EngineId("mock")
|
||||
ENGINE: EngineId = "mock"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@@ -52,7 +51,7 @@ class Raise:
|
||||
error: Exception
|
||||
|
||||
|
||||
ScriptStep: TypeAlias = Emit | Advance | Sleep | Wait | Return | Raise
|
||||
type ScriptStep = Emit | Advance | Sleep | Wait | Return | Raise
|
||||
|
||||
|
||||
def _resume_token(engine: EngineId, value: str | None) -> ResumeToken:
|
||||
@@ -108,9 +107,9 @@ class MockRunner(SessionLockMixin, ResumeTokenMixin, Runner):
|
||||
if (
|
||||
isinstance(event_out, ActionEvent)
|
||||
and event_out.phase == "completed"
|
||||
and event_out.ok is None
|
||||
):
|
||||
if event_out.ok is None:
|
||||
event_out = replace(event_out, ok=True)
|
||||
event_out = replace(event_out, ok=True)
|
||||
yield event_out
|
||||
await anyio.sleep(0)
|
||||
|
||||
@@ -187,9 +186,9 @@ class ScriptRunner(MockRunner):
|
||||
if (
|
||||
isinstance(event_out, ActionEvent)
|
||||
and event_out.phase == "completed"
|
||||
and event_out.ok is None
|
||||
):
|
||||
if event_out.ok is None:
|
||||
event_out = replace(event_out, ok=True)
|
||||
event_out = replace(event_out, ok=True)
|
||||
yield event_out
|
||||
await anyio.sleep(0)
|
||||
continue
|
||||
|
||||
@@ -35,11 +35,12 @@ from ..model import (
|
||||
)
|
||||
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
||||
from ..schemas import opencode as opencode_schema
|
||||
from ..utils.paths import relativize_command, relativize_path
|
||||
from ..utils.paths import relativize_path
|
||||
from .tool_actions import tool_input_path, tool_kind_and_title
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
ENGINE: EngineId = EngineId("opencode")
|
||||
ENGINE: EngineId = "opencode"
|
||||
|
||||
_RESUME_RE = re.compile(
|
||||
r"(?im)^\s*`?opencode(?:\s+run)?\s+(?:--session|-s)\s+(?P<token>ses_[A-Za-z0-9]+)`?\s*$"
|
||||
@@ -79,54 +80,12 @@ def _action_event(
|
||||
def _tool_kind_and_title(
|
||||
tool_name: str, tool_input: dict[str, Any]
|
||||
) -> tuple[ActionKind, str]:
|
||||
"""Map OpenCode tool names to Takopi action kinds and titles."""
|
||||
name_lower = tool_name.lower()
|
||||
|
||||
if name_lower in {"bash", "shell"}:
|
||||
command = tool_input.get("command")
|
||||
display = relativize_command(str(command or tool_name))
|
||||
return "command", display
|
||||
|
||||
if name_lower in {"edit", "write", "multiedit"}:
|
||||
path = tool_input.get("file_path") or tool_input.get("filePath")
|
||||
if path:
|
||||
return "file_change", relativize_path(str(path))
|
||||
return "file_change", str(tool_name)
|
||||
|
||||
if name_lower == "read":
|
||||
path = tool_input.get("file_path") or tool_input.get("filePath")
|
||||
if path:
|
||||
return "tool", f"read: `{relativize_path(str(path))}`"
|
||||
return "tool", "read"
|
||||
|
||||
if name_lower == "glob":
|
||||
pattern = tool_input.get("pattern")
|
||||
if pattern:
|
||||
return "tool", f"glob: `{pattern}`"
|
||||
return "tool", "glob"
|
||||
|
||||
if name_lower == "grep":
|
||||
pattern = tool_input.get("pattern")
|
||||
if pattern:
|
||||
return "tool", f"grep: {pattern}"
|
||||
return "tool", "grep"
|
||||
|
||||
if name_lower in {"websearch", "web_search"}:
|
||||
query = tool_input.get("query")
|
||||
return "web_search", str(query or "search")
|
||||
|
||||
if name_lower in {"webfetch", "web_fetch"}:
|
||||
url = tool_input.get("url")
|
||||
return "web_search", str(url or "fetch")
|
||||
|
||||
if name_lower in {"todowrite", "todoread"}:
|
||||
return "note", "update todos" if "write" in name_lower else "read todos"
|
||||
|
||||
if name_lower == "task":
|
||||
desc = tool_input.get("description") or tool_input.get("prompt")
|
||||
return "tool", str(desc or tool_name)
|
||||
|
||||
return "tool", tool_name
|
||||
return tool_kind_and_title(
|
||||
tool_name,
|
||||
tool_input,
|
||||
path_keys=("file_path", "filePath"),
|
||||
task_kind="tool",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_tool_title(
|
||||
@@ -137,10 +96,10 @@ def _normalize_tool_title(
|
||||
if "`" in title:
|
||||
return title
|
||||
|
||||
path = tool_input.get("file_path") or tool_input.get("filePath")
|
||||
path = tool_input_path(tool_input, path_keys=("file_path", "filePath"))
|
||||
if isinstance(path, str) and path:
|
||||
rel_path = relativize_path(path)
|
||||
if title == path or title == rel_path:
|
||||
if title in (path, rel_path):
|
||||
return f"`{rel_path}`"
|
||||
|
||||
return title
|
||||
@@ -190,9 +149,8 @@ def translate_opencode_event(
|
||||
"""Translate an OpenCode JSON event into Takopi events."""
|
||||
session_id = event.sessionID
|
||||
|
||||
if isinstance(session_id, str) and session_id:
|
||||
if state.session_id is None:
|
||||
state.session_id = session_id
|
||||
if isinstance(session_id, str) and session_id and state.session_id is None:
|
||||
state.session_id = session_id
|
||||
|
||||
match event:
|
||||
case opencode_schema.StepStart():
|
||||
@@ -340,7 +298,7 @@ def translate_opencode_event(
|
||||
return []
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
"""Runner for OpenCode CLI."""
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, UTC
|
||||
from pathlib import Path, PurePath
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
@@ -27,11 +27,12 @@ from ..model import (
|
||||
)
|
||||
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
||||
from ..schemas import pi as pi_schema
|
||||
from ..utils.paths import get_run_base_dir, relativize_command, relativize_path
|
||||
from ..utils.paths import get_run_base_dir
|
||||
from .tool_actions import tool_kind_and_title
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
ENGINE: EngineId = EngineId("pi")
|
||||
ENGINE: EngineId = "pi"
|
||||
|
||||
_RESUME_RE = re.compile(r"(?im)^\s*`?pi\s+--session\s+(?P<token>.+?)`?\s*$")
|
||||
|
||||
@@ -96,32 +97,7 @@ def _tool_kind_and_title(
|
||||
name: str,
|
||||
args: dict[str, Any],
|
||||
) -> tuple[ActionKind, str]:
|
||||
tool = name.lower()
|
||||
if tool == "bash":
|
||||
command = args.get("command")
|
||||
return "command", relativize_command(str(command or "bash"))
|
||||
if tool in {"edit", "write"}:
|
||||
path = args.get("path")
|
||||
if path:
|
||||
return "file_change", relativize_path(str(path))
|
||||
return "file_change", tool
|
||||
if tool == "read":
|
||||
path = args.get("path")
|
||||
if path:
|
||||
return "tool", f"read: `{relativize_path(str(path))}`"
|
||||
return "tool", "read"
|
||||
if tool == "grep":
|
||||
pattern = args.get("pattern")
|
||||
return "tool", f"grep: {pattern}" if pattern else "grep"
|
||||
if tool == "find":
|
||||
pattern = args.get("pattern")
|
||||
return "tool", f"find: {pattern}" if pattern else "find"
|
||||
if tool == "ls":
|
||||
path = args.get("path")
|
||||
if path:
|
||||
return "tool", f"ls: `{relativize_path(str(path))}`"
|
||||
return "tool", "ls"
|
||||
return "tool", name
|
||||
return tool_kind_and_title(name, args, path_keys=("path",))
|
||||
|
||||
|
||||
def _last_assistant_message(messages: Any) -> dict[str, Any] | None:
|
||||
@@ -418,7 +394,7 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
cwd = get_run_base_dir() or Path.cwd()
|
||||
session_dir = _default_session_dir(cwd)
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
timestamp = datetime.now(UTC).isoformat()
|
||||
safe_timestamp = timestamp.replace(":", "-").replace(".", "-")
|
||||
token = uuid4().hex
|
||||
filename = f"{safe_timestamp}_{token}.jsonl"
|
||||
@@ -442,7 +418,9 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||
def _default_session_dir(cwd: PurePath) -> Path:
|
||||
agent_dir = os.environ.get("PI_CODING_AGENT_DIR")
|
||||
base = Path(agent_dir).expanduser() if agent_dir else Path.home() / ".pi" / "agent"
|
||||
safe_path = f"--{str(cwd).lstrip('/\\\\').replace('/', '-').replace('\\', '-').replace(':', '-')}--"
|
||||
cwd_str = str(cwd).lstrip("/\\")
|
||||
safe_path_part = cwd_str.translate(str.maketrans({"/": "-", "\\": "-", ":": "-"}))
|
||||
safe_path = f"--{safe_path_part}--"
|
||||
return base / "sessions" / safe_path
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from ..model import ActionKind
|
||||
from ..utils.paths import relativize_command, relativize_path
|
||||
|
||||
|
||||
def tool_input_path(
|
||||
tool_input: Mapping[str, Any],
|
||||
*,
|
||||
path_keys: Sequence[str],
|
||||
) -> str | None:
|
||||
for key in path_keys:
|
||||
value = tool_input.get(key)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def tool_kind_and_title(
|
||||
tool_name: str,
|
||||
tool_input: Mapping[str, Any],
|
||||
*,
|
||||
path_keys: Sequence[str],
|
||||
task_kind: ActionKind = "subagent",
|
||||
) -> tuple[ActionKind, str]:
|
||||
name_lower = tool_name.lower()
|
||||
|
||||
if name_lower in {"bash", "shell", "killshell"}:
|
||||
command = tool_input.get("command")
|
||||
display = relativize_command(str(command or tool_name))
|
||||
return "command", display
|
||||
|
||||
if name_lower in {"edit", "write", "notebookedit", "multiedit"}:
|
||||
path = tool_input_path(tool_input, path_keys=path_keys)
|
||||
if path:
|
||||
return "file_change", relativize_path(str(path))
|
||||
return "file_change", str(tool_name)
|
||||
|
||||
if name_lower == "read":
|
||||
path = tool_input_path(tool_input, path_keys=path_keys)
|
||||
if path:
|
||||
return "tool", f"read: `{relativize_path(str(path))}`"
|
||||
return "tool", "read"
|
||||
|
||||
if name_lower == "glob":
|
||||
pattern = tool_input.get("pattern")
|
||||
if pattern:
|
||||
return "tool", f"glob: `{pattern}`"
|
||||
return "tool", "glob"
|
||||
|
||||
if name_lower == "grep":
|
||||
pattern = tool_input.get("pattern")
|
||||
if pattern:
|
||||
return "tool", f"grep: {pattern}"
|
||||
return "tool", "grep"
|
||||
|
||||
if name_lower == "find":
|
||||
pattern = tool_input.get("pattern")
|
||||
if pattern:
|
||||
return "tool", f"find: {pattern}"
|
||||
return "tool", "find"
|
||||
|
||||
if name_lower == "ls":
|
||||
path = tool_input_path(tool_input, path_keys=path_keys)
|
||||
if path:
|
||||
return "tool", f"ls: `{relativize_path(str(path))}`"
|
||||
return "tool", "ls"
|
||||
|
||||
if name_lower in {"websearch", "web_search"}:
|
||||
query = tool_input.get("query")
|
||||
return "web_search", str(query or "search")
|
||||
|
||||
if name_lower in {"webfetch", "web_fetch"}:
|
||||
url = tool_input.get("url")
|
||||
return "web_search", str(url or "fetch")
|
||||
|
||||
if name_lower in {"todowrite", "todoread"}:
|
||||
return "note", "update todos" if "write" in name_lower else "read todos"
|
||||
|
||||
if name_lower == "askuserquestion":
|
||||
return "note", "ask user"
|
||||
|
||||
if name_lower in {"task", "agent"}:
|
||||
desc = tool_input.get("description") or tool_input.get("prompt")
|
||||
return task_kind, str(desc or tool_name)
|
||||
|
||||
return "tool", tool_name
|
||||
@@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Mapping
|
||||
from typing import Any
|
||||
from collections.abc import Iterable, Mapping
|
||||
|
||||
from .backends import EngineBackend
|
||||
from .config import ConfigError, ProjectsConfig
|
||||
|
||||
+17
-2
@@ -2,14 +2,18 @@ from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Protocol
|
||||
from typing import Any, Protocol
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import anyio
|
||||
|
||||
from .context import RunContext
|
||||
from .logging import get_logger
|
||||
from .model import ResumeToken
|
||||
from .transport import ChannelId, MessageId, ThreadId
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ThreadJob:
|
||||
@@ -108,7 +112,18 @@ class ThreadScheduler:
|
||||
if done is not None and not done.is_set():
|
||||
await done.wait()
|
||||
|
||||
await self._run_job(job)
|
||||
try:
|
||||
await self._run_job(job)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception(
|
||||
"scheduler.job_failed",
|
||||
key=key,
|
||||
tag=job.resume_token.engine,
|
||||
chat_id=job.chat_id,
|
||||
user_msg_id=job.user_msg_id,
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
finally:
|
||||
async with self._lock:
|
||||
self._active_threads.discard(key)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, TypeAlias
|
||||
from typing import Any, Literal
|
||||
|
||||
import msgspec
|
||||
|
||||
@@ -36,7 +36,7 @@ class StreamToolResultBlock(
|
||||
is_error: bool | None = None
|
||||
|
||||
|
||||
StreamContentBlock: TypeAlias = (
|
||||
type StreamContentBlock = (
|
||||
StreamTextBlock | StreamThinkingBlock | StreamToolUseBlock | StreamToolResultBlock
|
||||
)
|
||||
|
||||
@@ -164,7 +164,7 @@ class ControlRewindFilesRequest(
|
||||
user_message_id: str
|
||||
|
||||
|
||||
ControlRequest: TypeAlias = (
|
||||
type ControlRequest = (
|
||||
ControlInterruptRequest
|
||||
| ControlCanUseToolRequest
|
||||
| ControlInitializeRequest
|
||||
@@ -196,7 +196,7 @@ class ControlErrorResponse(
|
||||
error: str
|
||||
|
||||
|
||||
ControlResponse: TypeAlias = ControlSuccessResponse | ControlErrorResponse
|
||||
type ControlResponse = ControlSuccessResponse | ControlErrorResponse
|
||||
|
||||
|
||||
class StreamControlResponse(
|
||||
@@ -217,7 +217,7 @@ class StreamControlCancelRequest(
|
||||
request_id: str | None = None
|
||||
|
||||
|
||||
StreamJsonMessage: TypeAlias = (
|
||||
type StreamJsonMessage = (
|
||||
StreamUserMessage
|
||||
| StreamAssistantMessage
|
||||
| StreamSystemMessage
|
||||
|
||||
@@ -2,27 +2,27 @@ from __future__ import annotations
|
||||
|
||||
# Headless JSONL schema derived from tag rust-v0.77.0 (git 112f40e91c12af0f7146d7e03f20283516a8af0b).
|
||||
|
||||
from typing import Any, Literal, TypeAlias
|
||||
from typing import Any, Literal
|
||||
|
||||
import msgspec
|
||||
|
||||
CommandExecutionStatus: TypeAlias = Literal[
|
||||
type CommandExecutionStatus = Literal[
|
||||
"in_progress",
|
||||
"completed",
|
||||
"failed",
|
||||
"declined",
|
||||
]
|
||||
PatchApplyStatus: TypeAlias = Literal[
|
||||
type PatchApplyStatus = Literal[
|
||||
"in_progress",
|
||||
"completed",
|
||||
"failed",
|
||||
]
|
||||
PatchChangeKind: TypeAlias = Literal[
|
||||
type PatchChangeKind = Literal[
|
||||
"add",
|
||||
"delete",
|
||||
"update",
|
||||
]
|
||||
McpToolCallStatus: TypeAlias = Literal[
|
||||
type McpToolCallStatus = Literal[
|
||||
"in_progress",
|
||||
"completed",
|
||||
"failed",
|
||||
@@ -127,7 +127,7 @@ class TodoListItem(msgspec.Struct, tag="todo_list", kw_only=True):
|
||||
items: list[TodoItem]
|
||||
|
||||
|
||||
ThreadItem: TypeAlias = (
|
||||
type ThreadItem = (
|
||||
AgentMessageItem
|
||||
| ReasoningItem
|
||||
| CommandExecutionItem
|
||||
@@ -151,7 +151,7 @@ class ItemCompleted(msgspec.Struct, tag="item.completed", kw_only=True):
|
||||
item: ThreadItem
|
||||
|
||||
|
||||
ThreadEvent: TypeAlias = (
|
||||
type ThreadEvent = (
|
||||
ThreadStarted
|
||||
| TurnStarted
|
||||
| TurnCompleted
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TypeAlias
|
||||
from typing import Any
|
||||
|
||||
import msgspec
|
||||
|
||||
@@ -42,7 +42,7 @@ class Error(_Event, tag="error"):
|
||||
message: Any = None
|
||||
|
||||
|
||||
OpenCodeEvent: TypeAlias = StepStart | StepFinish | ToolUse | Text | Error
|
||||
type OpenCodeEvent = StepStart | StepFinish | ToolUse | Text | Error
|
||||
|
||||
_DECODER = msgspec.json.Decoder(OpenCodeEvent)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TypeAlias
|
||||
from typing import Any
|
||||
|
||||
import msgspec
|
||||
|
||||
@@ -84,7 +84,7 @@ class AutoRetryEnd(_Event, tag="auto_retry_end"):
|
||||
finalError: str | None = None
|
||||
|
||||
|
||||
PiEvent: TypeAlias = (
|
||||
type PiEvent = (
|
||||
AgentStart
|
||||
| AgentEnd
|
||||
| MessageStart
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, ClassVar, Iterable, Literal
|
||||
from typing import Annotated, Any, ClassVar, Literal
|
||||
from collections.abc import Iterable
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
|
||||
@@ -50,9 +50,7 @@ def _build_startup_message(
|
||||
notes.append(f"failed to load: {', '.join(failed_engines)}")
|
||||
if notes:
|
||||
engine_list = f"{engine_list} ({'; '.join(notes)})"
|
||||
project_aliases = sorted(
|
||||
{alias for alias in runtime.project_aliases()}, key=str.lower
|
||||
)
|
||||
project_aliases = sorted(set(runtime.project_aliases()), key=str.lower)
|
||||
project_list = ", ".join(project_aliases) if project_aliases else "none"
|
||||
return (
|
||||
f"\N{OCTOPUS} **takopi is ready**\n\n"
|
||||
@@ -78,8 +76,7 @@ class TelegramBackend(TransportBackend):
|
||||
def interactive_setup(self, *, force: bool) -> bool:
|
||||
return interactive_setup(force=force)
|
||||
|
||||
def lock_token(self, *, transport_config: object, config_path: Path) -> str | None:
|
||||
_ = config_path
|
||||
def lock_token(self, *, transport_config: object, _config_path: Path) -> str | None:
|
||||
settings = _expect_transport_settings(transport_config)
|
||||
return settings.bot_token
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ def _is_cancelled_label(label: str) -> bool:
|
||||
return stripped.lower() == "cancelled"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TelegramBridgeConfig:
|
||||
bot: BotClient
|
||||
runtime: TransportRuntime
|
||||
|
||||
+86
-985
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,537 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, TypeVar
|
||||
|
||||
import httpx
|
||||
import msgspec
|
||||
|
||||
from ..logging import get_logger
|
||||
from .api_models import Chat, ChatMember, File, ForumTopic, Message, Update, User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RetryAfter(Exception):
|
||||
def __init__(self, retry_after: float, description: str | None = None) -> None:
|
||||
super().__init__(description or f"retry after {retry_after}")
|
||||
self.retry_after = float(retry_after)
|
||||
self.description = description
|
||||
|
||||
|
||||
class TelegramRetryAfter(RetryAfter):
|
||||
pass
|
||||
|
||||
|
||||
def retry_after_from_payload(payload: dict[str, Any]) -> float | None:
|
||||
params = payload.get("parameters")
|
||||
if isinstance(params, dict):
|
||||
retry_after = params.get("retry_after")
|
||||
if isinstance(retry_after, (int, float)):
|
||||
return float(retry_after)
|
||||
return None
|
||||
|
||||
|
||||
class BotClient(Protocol):
|
||||
async def close(self) -> None: ...
|
||||
|
||||
async def get_updates(
|
||||
self,
|
||||
offset: int | None,
|
||||
timeout_s: int = 50,
|
||||
allowed_updates: list[str] | None = None,
|
||||
) -> list[Update] | None: ...
|
||||
|
||||
async def get_file(self, file_id: str) -> File | None: ...
|
||||
|
||||
async def download_file(self, file_path: str) -> bytes | None: ...
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_to_message_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
message_thread_id: int | None = None,
|
||||
entities: list[dict] | None = None,
|
||||
parse_mode: str | None = None,
|
||||
reply_markup: dict[str, Any] | None = None,
|
||||
*,
|
||||
replace_message_id: int | None = None,
|
||||
) -> Message | None: ...
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: int,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
reply_to_message_id: int | None = None,
|
||||
message_thread_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
caption: str | None = None,
|
||||
) -> Message | None: ...
|
||||
|
||||
async def edit_message_text(
|
||||
self,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
text: str,
|
||||
entities: list[dict] | None = None,
|
||||
parse_mode: str | None = None,
|
||||
reply_markup: dict[str, Any] | None = None,
|
||||
*,
|
||||
wait: bool = True,
|
||||
) -> Message | None: ...
|
||||
|
||||
async def delete_message(
|
||||
self,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
) -> bool: ...
|
||||
|
||||
async def set_my_commands(
|
||||
self,
|
||||
commands: list[dict[str, Any]],
|
||||
*,
|
||||
scope: dict[str, Any] | None = None,
|
||||
language_code: str | None = None,
|
||||
) -> bool: ...
|
||||
|
||||
async def get_me(self) -> User | None: ...
|
||||
|
||||
async def answer_callback_query(
|
||||
self,
|
||||
callback_query_id: str,
|
||||
text: str | None = None,
|
||||
show_alert: bool | None = None,
|
||||
) -> bool: ...
|
||||
|
||||
async def get_chat(self, chat_id: int) -> Chat | None: ...
|
||||
|
||||
async def get_chat_member(
|
||||
self, chat_id: int, user_id: int
|
||||
) -> ChatMember | None: ...
|
||||
|
||||
async def create_forum_topic(
|
||||
self,
|
||||
chat_id: int,
|
||||
name: str,
|
||||
) -> ForumTopic | None: ...
|
||||
|
||||
async def edit_forum_topic(
|
||||
self,
|
||||
chat_id: int,
|
||||
message_thread_id: int,
|
||||
name: str,
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class HttpBotClient:
|
||||
def __init__(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
timeout_s: float = 120,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
if not token:
|
||||
raise ValueError("Telegram token is empty")
|
||||
self._base = f"https://api.telegram.org/bot{token}"
|
||||
self._file_base = f"https://api.telegram.org/file/bot{token}"
|
||||
self._http_client = http_client or httpx.AsyncClient(timeout=timeout_s)
|
||||
self._owns_http_client = http_client is None
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._owns_http_client:
|
||||
await self._http_client.aclose()
|
||||
|
||||
def _parse_telegram_envelope(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
resp: httpx.Response,
|
||||
payload: Any,
|
||||
) -> Any | None:
|
||||
if not isinstance(payload, dict):
|
||||
logger.error(
|
||||
"telegram.invalid_payload",
|
||||
method=method,
|
||||
url=str(resp.request.url),
|
||||
payload=payload,
|
||||
)
|
||||
return None
|
||||
|
||||
if not payload.get("ok"):
|
||||
if payload.get("error_code") == 429:
|
||||
retry_after = retry_after_from_payload(payload)
|
||||
retry_after = 5.0 if retry_after is None else retry_after
|
||||
logger.warning(
|
||||
"telegram.rate_limited",
|
||||
method=method,
|
||||
url=str(resp.request.url),
|
||||
retry_after=retry_after,
|
||||
)
|
||||
raise TelegramRetryAfter(retry_after)
|
||||
logger.error(
|
||||
"telegram.api_error",
|
||||
method=method,
|
||||
url=str(resp.request.url),
|
||||
payload=payload,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.debug("telegram.response", method=method, payload=payload)
|
||||
return payload.get("result")
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
*,
|
||||
json: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
) -> Any | None:
|
||||
request_payload = json if json is not None else data
|
||||
logger.debug("telegram.request", method=method, payload=request_payload)
|
||||
try:
|
||||
if json is not None:
|
||||
resp = await self._http_client.post(f"{self._base}/{method}", json=json)
|
||||
else:
|
||||
resp = await self._http_client.post(
|
||||
f"{self._base}/{method}", data=data, files=files
|
||||
)
|
||||
except httpx.HTTPError as exc:
|
||||
url = getattr(exc.request, "url", None)
|
||||
logger.error(
|
||||
"telegram.network_error",
|
||||
method=method,
|
||||
url=str(url) if url is not None else None,
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if resp.status_code == 429:
|
||||
retry_after: float | None = None
|
||||
try:
|
||||
response_payload = resp.json()
|
||||
except Exception: # noqa: BLE001
|
||||
response_payload = None
|
||||
if isinstance(response_payload, dict):
|
||||
retry_after = retry_after_from_payload(response_payload)
|
||||
retry_after = 5.0 if retry_after is None else retry_after
|
||||
logger.warning(
|
||||
"telegram.rate_limited",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
retry_after=retry_after,
|
||||
)
|
||||
raise TelegramRetryAfter(retry_after) from exc
|
||||
body = resp.text
|
||||
logger.error(
|
||||
"telegram.http_error",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
error=str(exc),
|
||||
body=body,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
response_payload = resp.json()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
body = resp.text
|
||||
logger.error(
|
||||
"telegram.bad_response",
|
||||
method=method,
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
body=body,
|
||||
)
|
||||
return None
|
||||
|
||||
return self._parse_telegram_envelope(
|
||||
method=method,
|
||||
resp=resp,
|
||||
payload=response_payload,
|
||||
)
|
||||
|
||||
def _decode_result(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
payload: Any,
|
||||
model: type[T],
|
||||
) -> T | None:
|
||||
if payload is None:
|
||||
return None
|
||||
try:
|
||||
return msgspec.convert(payload, type=model)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(
|
||||
"telegram.decode_error",
|
||||
method=method,
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
|
||||
return await self._request(method, json=json_data)
|
||||
|
||||
async def _post_form(
|
||||
self,
|
||||
method: str,
|
||||
data: dict[str, Any],
|
||||
files: dict[str, Any],
|
||||
) -> Any | None:
|
||||
return await self._request(method, data=data, files=files)
|
||||
|
||||
async def get_updates(
|
||||
self,
|
||||
offset: int | None,
|
||||
timeout_s: int = 50,
|
||||
allowed_updates: list[str] | None = None,
|
||||
) -> list[Update] | None:
|
||||
params: dict[str, Any] = {"timeout": timeout_s}
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
if allowed_updates is not None:
|
||||
params["allowed_updates"] = allowed_updates
|
||||
result = await self._post("getUpdates", params)
|
||||
if result is None or not isinstance(result, list):
|
||||
return None
|
||||
try:
|
||||
return msgspec.convert(result, type=list[Update])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(
|
||||
"telegram.decode_error",
|
||||
method="getUpdates",
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_file(self, file_id: str) -> File | None:
|
||||
result = await self._post("getFile", {"file_id": file_id})
|
||||
return self._decode_result(method="getFile", payload=result, model=File)
|
||||
|
||||
async def download_file(self, file_path: str) -> bytes | None:
|
||||
url = f"{self._file_base}/{file_path}"
|
||||
try:
|
||||
resp = await self._http_client.get(url)
|
||||
except httpx.HTTPError as exc:
|
||||
request_url = getattr(exc.request, "url", None)
|
||||
logger.error(
|
||||
"telegram.file_network_error",
|
||||
url=str(request_url) if request_url is not None else None,
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
return None
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if resp.status_code == 429:
|
||||
retry_after: float | None = None
|
||||
try:
|
||||
response_payload = resp.json()
|
||||
except Exception: # noqa: BLE001
|
||||
response_payload = None
|
||||
if isinstance(response_payload, dict):
|
||||
retry_after = retry_after_from_payload(response_payload)
|
||||
retry_after = 5.0 if retry_after is None else retry_after
|
||||
logger.warning(
|
||||
"telegram.rate_limited",
|
||||
method="download_file",
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
retry_after=retry_after,
|
||||
)
|
||||
raise TelegramRetryAfter(retry_after) from exc
|
||||
|
||||
logger.error(
|
||||
"telegram.file_http_error",
|
||||
status=resp.status_code,
|
||||
url=str(resp.request.url),
|
||||
error=str(exc),
|
||||
body=resp.text,
|
||||
)
|
||||
return None
|
||||
return resp.content
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_to_message_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
message_thread_id: int | None = None,
|
||||
entities: list[dict] | None = None,
|
||||
parse_mode: str | None = None,
|
||||
reply_markup: dict[str, Any] | None = None,
|
||||
*,
|
||||
replace_message_id: int | None = None,
|
||||
) -> Message | None:
|
||||
params: dict[str, Any] = {"chat_id": chat_id, "text": text}
|
||||
if disable_notification is not None:
|
||||
params["disable_notification"] = disable_notification
|
||||
if reply_to_message_id is not None:
|
||||
params["reply_to_message_id"] = reply_to_message_id
|
||||
if message_thread_id is not None:
|
||||
params["message_thread_id"] = message_thread_id
|
||||
if entities is not None:
|
||||
params["entities"] = entities
|
||||
if parse_mode is not None:
|
||||
params["parse_mode"] = parse_mode
|
||||
if reply_markup is not None:
|
||||
params["reply_markup"] = reply_markup
|
||||
result = await self._post("sendMessage", params)
|
||||
return self._decode_result(method="sendMessage", payload=result, model=Message)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: int,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
reply_to_message_id: int | None = None,
|
||||
message_thread_id: int | None = None,
|
||||
disable_notification: bool | None = False,
|
||||
caption: str | None = None,
|
||||
) -> Message | None:
|
||||
params: dict[str, Any] = {"chat_id": chat_id}
|
||||
if disable_notification is not None:
|
||||
params["disable_notification"] = disable_notification
|
||||
if reply_to_message_id is not None:
|
||||
params["reply_to_message_id"] = reply_to_message_id
|
||||
if message_thread_id is not None:
|
||||
params["message_thread_id"] = message_thread_id
|
||||
if caption is not None:
|
||||
params["caption"] = caption
|
||||
result = await self._post_form(
|
||||
"sendDocument",
|
||||
params,
|
||||
files={"document": (filename, content)},
|
||||
)
|
||||
return self._decode_result(method="sendDocument", payload=result, model=Message)
|
||||
|
||||
async def edit_message_text(
|
||||
self,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
text: str,
|
||||
entities: list[dict] | None = None,
|
||||
parse_mode: str | None = None,
|
||||
reply_markup: dict[str, Any] | None = None,
|
||||
*,
|
||||
wait: bool = True,
|
||||
) -> Message | None:
|
||||
params: dict[str, Any] = {
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"text": text,
|
||||
}
|
||||
if entities is not None:
|
||||
params["entities"] = entities
|
||||
if parse_mode is not None:
|
||||
params["parse_mode"] = parse_mode
|
||||
if reply_markup is not None:
|
||||
params["reply_markup"] = reply_markup
|
||||
result = await self._post("editMessageText", params)
|
||||
return self._decode_result(
|
||||
method="editMessageText",
|
||||
payload=result,
|
||||
model=Message,
|
||||
)
|
||||
|
||||
async def delete_message(
|
||||
self,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
) -> bool:
|
||||
result = await self._post(
|
||||
"deleteMessage",
|
||||
{"chat_id": chat_id, "message_id": message_id},
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
async def set_my_commands(
|
||||
self,
|
||||
commands: list[dict[str, Any]],
|
||||
*,
|
||||
scope: dict[str, Any] | None = None,
|
||||
language_code: str | None = None,
|
||||
) -> bool:
|
||||
params: dict[str, Any] = {"commands": commands}
|
||||
if scope is not None:
|
||||
params["scope"] = scope
|
||||
if language_code is not None:
|
||||
params["language_code"] = language_code
|
||||
result = await self._post("setMyCommands", params)
|
||||
return bool(result)
|
||||
|
||||
async def get_me(self) -> User | None:
|
||||
result = await self._post("getMe", {})
|
||||
return self._decode_result(method="getMe", payload=result, model=User)
|
||||
|
||||
async def answer_callback_query(
|
||||
self,
|
||||
callback_query_id: str,
|
||||
text: str | None = None,
|
||||
show_alert: bool | None = None,
|
||||
) -> bool:
|
||||
params: dict[str, Any] = {"callback_query_id": callback_query_id}
|
||||
if text is not None:
|
||||
params["text"] = text
|
||||
if show_alert is not None:
|
||||
params["show_alert"] = show_alert
|
||||
result = await self._post("answerCallbackQuery", params)
|
||||
return bool(result)
|
||||
|
||||
async def get_chat(self, chat_id: int) -> Chat | None:
|
||||
result = await self._post("getChat", {"chat_id": chat_id})
|
||||
return self._decode_result(method="getChat", payload=result, model=Chat)
|
||||
|
||||
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
|
||||
result = await self._post(
|
||||
"getChatMember", {"chat_id": chat_id, "user_id": user_id}
|
||||
)
|
||||
return self._decode_result(
|
||||
method="getChatMember",
|
||||
payload=result,
|
||||
model=ChatMember,
|
||||
)
|
||||
|
||||
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
|
||||
result = await self._post(
|
||||
"createForumTopic", {"chat_id": chat_id, "name": name}
|
||||
)
|
||||
return self._decode_result(
|
||||
method="createForumTopic",
|
||||
payload=result,
|
||||
model=ForumTopic,
|
||||
)
|
||||
|
||||
async def edit_forum_topic(
|
||||
self,
|
||||
chat_id: int,
|
||||
message_thread_id: int,
|
||||
name: str,
|
||||
) -> bool:
|
||||
result = await self._post(
|
||||
"editForumTopic",
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
"message_thread_id": message_thread_id,
|
||||
"name": name,
|
||||
},
|
||||
)
|
||||
return bool(result)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .cancel import handle_callback_cancel, handle_cancel
|
||||
from .menu import build_bot_commands
|
||||
from .parse import is_cancel_command
|
||||
|
||||
__all__ = [
|
||||
"build_bot_commands",
|
||||
"handle_callback_cancel",
|
||||
"handle_cancel",
|
||||
"is_cancel_command",
|
||||
]
|
||||
@@ -0,0 +1,160 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...context import RunContext
|
||||
from ...directives import DirectiveError
|
||||
from ..chat_prefs import ChatPrefsStore
|
||||
from ..engine_defaults import resolve_engine_for_message
|
||||
from ..files import split_command_args
|
||||
from ..topic_state import TopicStateStore
|
||||
from ..topics import _topic_key
|
||||
from ..types import TelegramIncomingMessage
|
||||
from .reply import make_reply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bridge import TelegramBridgeConfig
|
||||
|
||||
AGENT_USAGE = "usage: `/agent`, `/agent set <engine>`, or `/agent clear`"
|
||||
|
||||
|
||||
async def _check_agent_permissions(
|
||||
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||
) -> bool:
|
||||
reply = make_reply(cfg, msg)
|
||||
sender_id = msg.sender_id
|
||||
if sender_id is None:
|
||||
await reply(text="cannot verify sender for agent defaults.")
|
||||
return False
|
||||
is_private = msg.chat_type == "private"
|
||||
if msg.chat_type is None:
|
||||
is_private = msg.chat_id > 0
|
||||
if is_private:
|
||||
return True
|
||||
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
|
||||
if member is None:
|
||||
await reply(text="failed to verify agent permissions.")
|
||||
return False
|
||||
if member.status in {"creator", "administrator"}:
|
||||
return True
|
||||
await reply(text="changing default agents is restricted to group admins.")
|
||||
return False
|
||||
|
||||
|
||||
async def _handle_agent_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
chat_prefs: ChatPrefsStore | None,
|
||||
*,
|
||||
resolved_scope: str | None = None,
|
||||
scope_chat_ids: frozenset[int] | None = None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
tkey = (
|
||||
_topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
|
||||
if topic_store is not None
|
||||
else None
|
||||
)
|
||||
tokens = split_command_args(args_text)
|
||||
action = tokens[0].lower() if tokens else "show"
|
||||
|
||||
if action in {"show", ""}:
|
||||
try:
|
||||
resolved = cfg.runtime.resolve_message(
|
||||
text="",
|
||||
reply_text=msg.reply_to_text,
|
||||
ambient_context=ambient_context,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
except DirectiveError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return
|
||||
selection = await resolve_engine_for_message(
|
||||
runtime=cfg.runtime,
|
||||
context=resolved.context,
|
||||
explicit_engine=None,
|
||||
chat_id=msg.chat_id,
|
||||
topic_key=tkey,
|
||||
topic_store=topic_store,
|
||||
chat_prefs=chat_prefs,
|
||||
)
|
||||
source_labels = {
|
||||
"directive": "directive",
|
||||
"topic_default": "topic default",
|
||||
"chat_default": "chat default",
|
||||
"project_default": "project default",
|
||||
"global_default": "global default",
|
||||
}
|
||||
agent_line = f"agent: {selection.engine} ({source_labels[selection.source]})"
|
||||
topic_default = selection.topic_default or "none"
|
||||
if tkey is None:
|
||||
topic_default = "none"
|
||||
if chat_prefs is None:
|
||||
chat_default = "unavailable"
|
||||
else:
|
||||
chat_default = selection.chat_default or "none"
|
||||
project_default = (
|
||||
selection.project_default
|
||||
if selection.project_default is not None
|
||||
else "none"
|
||||
)
|
||||
defaults_line = (
|
||||
"defaults: "
|
||||
f"topic: {topic_default}, "
|
||||
f"chat: {chat_default}, "
|
||||
f"project: {project_default}, "
|
||||
f"global: {cfg.runtime.default_engine}"
|
||||
)
|
||||
available = ", ".join(cfg.runtime.engine_ids)
|
||||
available_line = f"available: {available}"
|
||||
await reply(text="\n\n".join([agent_line, defaults_line, available_line]))
|
||||
return
|
||||
|
||||
if action == "set":
|
||||
if len(tokens) < 2:
|
||||
await reply(text=AGENT_USAGE)
|
||||
return
|
||||
if not await _check_agent_permissions(cfg, msg):
|
||||
return
|
||||
engine = tokens[1].strip().lower()
|
||||
if engine not in cfg.runtime.engine_ids:
|
||||
available = ", ".join(cfg.runtime.engine_ids)
|
||||
await reply(
|
||||
text=f"unknown engine `{engine}`.\navailable agents: `{available}`",
|
||||
)
|
||||
return
|
||||
if tkey is not None:
|
||||
if topic_store is None:
|
||||
await reply(text="topic defaults are unavailable.")
|
||||
return
|
||||
await topic_store.set_default_engine(tkey[0], tkey[1], engine)
|
||||
await reply(text=f"topic default agent set to `{engine}`")
|
||||
return
|
||||
if chat_prefs is None:
|
||||
await reply(text="chat defaults are unavailable (no config path).")
|
||||
return
|
||||
await chat_prefs.set_default_engine(msg.chat_id, engine)
|
||||
await reply(text=f"chat default agent set to `{engine}`")
|
||||
return
|
||||
|
||||
if action == "clear":
|
||||
if not await _check_agent_permissions(cfg, msg):
|
||||
return
|
||||
if tkey is not None:
|
||||
if topic_store is None:
|
||||
await reply(text="topic defaults are unavailable.")
|
||||
return
|
||||
await topic_store.clear_default_engine(tkey[0], tkey[1])
|
||||
await reply(text="topic default agent cleared.")
|
||||
return
|
||||
if chat_prefs is None:
|
||||
await reply(text="chat defaults are unavailable (no config path).")
|
||||
return
|
||||
await chat_prefs.clear_default_engine(msg.chat_id)
|
||||
await reply(text="chat default agent cleared.")
|
||||
return
|
||||
|
||||
await reply(text=AGENT_USAGE)
|
||||
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...logging import get_logger
|
||||
from ...runner_bridge import RunningTasks
|
||||
from ...transport import MessageRef
|
||||
from ..types import TelegramCallbackQuery, TelegramIncomingMessage
|
||||
from .reply import make_reply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bridge import TelegramBridgeConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def handle_cancel(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
running_tasks: RunningTasks,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
chat_id = msg.chat_id
|
||||
reply_id = msg.reply_to_message_id
|
||||
|
||||
if reply_id is None:
|
||||
if msg.reply_to_text:
|
||||
await reply(text="nothing is currently running for that message.")
|
||||
return
|
||||
await reply(text="reply to the progress message to cancel.")
|
||||
return
|
||||
|
||||
progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id)
|
||||
running_task = running_tasks.get(progress_ref)
|
||||
if running_task is None:
|
||||
await reply(text="nothing is currently running for that message.")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"cancel.requested",
|
||||
chat_id=chat_id,
|
||||
progress_message_id=reply_id,
|
||||
)
|
||||
running_task.cancel_requested.set()
|
||||
|
||||
|
||||
async def handle_callback_cancel(
|
||||
cfg: TelegramBridgeConfig,
|
||||
query: TelegramCallbackQuery,
|
||||
running_tasks: RunningTasks,
|
||||
) -> None:
|
||||
progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id)
|
||||
running_task = running_tasks.get(progress_ref)
|
||||
if running_task is None:
|
||||
await cfg.bot.answer_callback_query(
|
||||
callback_query_id=query.callback_query_id,
|
||||
text="nothing is currently running for that message.",
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"cancel.requested",
|
||||
chat_id=query.chat_id,
|
||||
progress_message_id=query.message_id,
|
||||
)
|
||||
running_task.cancel_requested.set()
|
||||
await cfg.bot.answer_callback_query(
|
||||
callback_query_id=query.callback_query_id,
|
||||
text="cancelling...",
|
||||
)
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import anyio
|
||||
|
||||
from ...commands import CommandContext, get_command
|
||||
from ...config import ConfigError
|
||||
from ...logging import get_logger
|
||||
from ...model import EngineId, ResumeToken
|
||||
from ...runner_bridge import RunningTasks
|
||||
from ...scheduler import ThreadScheduler
|
||||
from ...transport import MessageRef
|
||||
from ..files import split_command_args
|
||||
from ..types import TelegramIncomingMessage
|
||||
from .executor import _TelegramCommandExecutor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bridge import TelegramBridgeConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def _dispatch_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
text: str,
|
||||
command_id: str,
|
||||
args_text: str,
|
||||
running_tasks: RunningTasks,
|
||||
scheduler: ThreadScheduler,
|
||||
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
|
||||
stateful_mode: bool,
|
||||
default_engine_override: EngineId | None,
|
||||
) -> None:
|
||||
allowlist = cfg.runtime.allowlist
|
||||
chat_id = msg.chat_id
|
||||
user_msg_id = msg.message_id
|
||||
reply_ref = (
|
||||
MessageRef(
|
||||
channel_id=chat_id,
|
||||
message_id=msg.reply_to_message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
if msg.reply_to_message_id is not None
|
||||
else None
|
||||
)
|
||||
executor = _TelegramCommandExecutor(
|
||||
exec_cfg=cfg.exec_cfg,
|
||||
runtime=cfg.runtime,
|
||||
running_tasks=running_tasks,
|
||||
scheduler=scheduler,
|
||||
on_thread_known=on_thread_known,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
thread_id=msg.thread_id,
|
||||
show_resume_line=cfg.show_resume_line,
|
||||
stateful_mode=stateful_mode,
|
||||
default_engine_override=default_engine_override,
|
||||
)
|
||||
message_ref = MessageRef(
|
||||
channel_id=chat_id,
|
||||
message_id=user_msg_id,
|
||||
thread_id=msg.thread_id,
|
||||
sender_id=msg.sender_id,
|
||||
raw=msg.raw,
|
||||
)
|
||||
try:
|
||||
backend = get_command(command_id, allowlist=allowlist, required=False)
|
||||
except ConfigError as exc:
|
||||
await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True)
|
||||
return
|
||||
if backend is None:
|
||||
return
|
||||
try:
|
||||
plugin_config = cfg.runtime.plugin_config(command_id)
|
||||
except ConfigError as exc:
|
||||
await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True)
|
||||
return
|
||||
ctx = CommandContext(
|
||||
command=command_id,
|
||||
text=text,
|
||||
args_text=args_text,
|
||||
args=split_command_args(args_text),
|
||||
message=message_ref,
|
||||
reply_to=reply_ref,
|
||||
reply_text=msg.reply_to_text,
|
||||
config_path=cfg.runtime.config_path,
|
||||
plugin_config=plugin_config,
|
||||
runtime=cfg.runtime,
|
||||
executor=executor,
|
||||
)
|
||||
try:
|
||||
result = await backend.handle(ctx)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"command.failed",
|
||||
command=command_id,
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True)
|
||||
return
|
||||
if result is not None:
|
||||
reply_to = message_ref if result.reply_to is None else result.reply_to
|
||||
await executor.send(result.text, reply_to=reply_to, notify=result.notify)
|
||||
@@ -0,0 +1,382 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
|
||||
from ...commands import CommandExecutor, RunMode, RunRequest, RunResult
|
||||
from ...config import ConfigError
|
||||
from ...context import RunContext
|
||||
from ...logging import bind_run_context, clear_context, get_logger
|
||||
from ...model import EngineId, ResumeToken, TakopiEvent
|
||||
from ...progress import ProgressTracker
|
||||
from ...router import RunnerUnavailableError
|
||||
from ...runner import Runner
|
||||
from ...runner_bridge import (
|
||||
ExecBridgeConfig,
|
||||
IncomingMessage as RunnerIncomingMessage,
|
||||
RunningTasks,
|
||||
handle_message,
|
||||
)
|
||||
from ...scheduler import ThreadScheduler
|
||||
from ...transport import MessageRef, RenderedMessage, SendOptions
|
||||
from ...transport_runtime import TransportRuntime
|
||||
from ...utils.paths import reset_run_base_dir, set_run_base_dir
|
||||
from ..bridge import send_plain
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _ResumeLineProxy:
|
||||
runner: Runner
|
||||
|
||||
@property
|
||||
def engine(self) -> str:
|
||||
return self.runner.engine
|
||||
|
||||
def is_resume_line(self, line: str) -> bool:
|
||||
return self.runner.is_resume_line(line)
|
||||
|
||||
def format_resume(self, _: ResumeToken) -> str:
|
||||
return ""
|
||||
|
||||
def extract_resume(self, text: str | None) -> ResumeToken | None:
|
||||
return self.runner.extract_resume(text)
|
||||
|
||||
def run(
|
||||
self, prompt: str, resume: ResumeToken | None
|
||||
) -> AsyncIterator[TakopiEvent]:
|
||||
return self.runner.run(prompt, resume)
|
||||
|
||||
|
||||
def _should_show_resume_line(
|
||||
*,
|
||||
show_resume_line: bool,
|
||||
stateful_mode: bool,
|
||||
context: RunContext | None,
|
||||
) -> bool:
|
||||
if show_resume_line or not stateful_mode:
|
||||
return True
|
||||
return context is None or context.project is None
|
||||
|
||||
|
||||
async def _send_runner_unavailable(
|
||||
exec_cfg: ExecBridgeConfig,
|
||||
*,
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
resume_token: ResumeToken | None,
|
||||
runner: Runner,
|
||||
reason: str,
|
||||
thread_id: int | None = None,
|
||||
) -> None:
|
||||
tracker = ProgressTracker(engine=runner.engine)
|
||||
tracker.set_resume(resume_token)
|
||||
state = tracker.snapshot(resume_formatter=runner.format_resume)
|
||||
message = exec_cfg.presenter.render_final(
|
||||
state,
|
||||
elapsed_s=0.0,
|
||||
status="error",
|
||||
answer=f"error:\n{reason}",
|
||||
)
|
||||
reply_to = MessageRef(channel_id=chat_id, message_id=user_msg_id)
|
||||
await exec_cfg.transport.send(
|
||||
channel_id=chat_id,
|
||||
message=message,
|
||||
options=SendOptions(reply_to=reply_to, notify=True, thread_id=thread_id),
|
||||
)
|
||||
|
||||
|
||||
async def _run_engine(
|
||||
*,
|
||||
exec_cfg: ExecBridgeConfig,
|
||||
runtime: TransportRuntime,
|
||||
running_tasks: RunningTasks | None,
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
resume_token: ResumeToken | None,
|
||||
context: RunContext | None,
|
||||
reply_ref: MessageRef | None = None,
|
||||
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
|
||||
| None = None,
|
||||
engine_override: EngineId | None = None,
|
||||
thread_id: int | None = None,
|
||||
show_resume_line: bool = True,
|
||||
) -> None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
exec_cfg.transport,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
try:
|
||||
try:
|
||||
entry = runtime.resolve_runner(
|
||||
resume_token=resume_token,
|
||||
engine_override=engine_override,
|
||||
)
|
||||
except RunnerUnavailableError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return
|
||||
runner: Runner = entry.runner
|
||||
if not show_resume_line:
|
||||
runner = cast(Runner, _ResumeLineProxy(runner))
|
||||
if not entry.available:
|
||||
reason = entry.issue or "engine unavailable"
|
||||
await _send_runner_unavailable(
|
||||
exec_cfg,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
resume_token=resume_token,
|
||||
runner=runner,
|
||||
reason=reason,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
return
|
||||
try:
|
||||
cwd = runtime.resolve_run_cwd(context)
|
||||
except ConfigError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return
|
||||
run_base_token = set_run_base_dir(cwd)
|
||||
try:
|
||||
run_fields = {
|
||||
"chat_id": chat_id,
|
||||
"user_msg_id": user_msg_id,
|
||||
"engine": runner.engine,
|
||||
"resume": resume_token.value if resume_token else None,
|
||||
}
|
||||
if context is not None:
|
||||
run_fields["project"] = context.project
|
||||
run_fields["branch"] = context.branch
|
||||
if cwd is not None:
|
||||
run_fields["cwd"] = str(cwd)
|
||||
bind_run_context(**run_fields)
|
||||
context_line = runtime.format_context_line(context)
|
||||
incoming = RunnerIncomingMessage(
|
||||
channel_id=chat_id,
|
||||
message_id=user_msg_id,
|
||||
text=text,
|
||||
reply_to=reply_ref,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
await handle_message(
|
||||
exec_cfg,
|
||||
runner=runner,
|
||||
incoming=incoming,
|
||||
resume_token=resume_token,
|
||||
context=context,
|
||||
context_line=context_line,
|
||||
strip_resume_line=runtime.is_resume_line,
|
||||
running_tasks=running_tasks,
|
||||
on_thread_known=on_thread_known,
|
||||
)
|
||||
finally:
|
||||
reset_run_base_dir(run_base_token)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"handle.worker_failed",
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
finally:
|
||||
clear_context()
|
||||
|
||||
|
||||
class _CaptureTransport:
|
||||
def __init__(self) -> None:
|
||||
self._next_id = 1
|
||||
self.last_message: RenderedMessage | None = None
|
||||
|
||||
async def send(
|
||||
self,
|
||||
*,
|
||||
channel_id: int | str,
|
||||
message: RenderedMessage,
|
||||
options: SendOptions | None = None,
|
||||
) -> MessageRef:
|
||||
thread_id = options.thread_id if options is not None else None
|
||||
ref = MessageRef(channel_id=channel_id, message_id=self._next_id)
|
||||
self._next_id += 1
|
||||
self.last_message = message
|
||||
return MessageRef(
|
||||
channel_id=ref.channel_id,
|
||||
message_id=ref.message_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
async def edit(
|
||||
self, *, ref: MessageRef, message: RenderedMessage, wait: bool = True
|
||||
) -> MessageRef:
|
||||
self.last_message = message
|
||||
return ref
|
||||
|
||||
async def delete(self, *, ref: MessageRef) -> bool:
|
||||
return True
|
||||
|
||||
async def close(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _TelegramCommandExecutor(CommandExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
exec_cfg: ExecBridgeConfig,
|
||||
runtime: TransportRuntime,
|
||||
running_tasks: RunningTasks,
|
||||
scheduler: ThreadScheduler,
|
||||
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
thread_id: int | None,
|
||||
show_resume_line: bool,
|
||||
stateful_mode: bool,
|
||||
default_engine_override: EngineId | None,
|
||||
) -> None:
|
||||
self._exec_cfg = exec_cfg
|
||||
self._runtime = runtime
|
||||
self._running_tasks = running_tasks
|
||||
self._scheduler = scheduler
|
||||
self._on_thread_known = on_thread_known
|
||||
self._chat_id = chat_id
|
||||
self._user_msg_id = user_msg_id
|
||||
self._thread_id = thread_id
|
||||
self._show_resume_line = show_resume_line
|
||||
self._stateful_mode = stateful_mode
|
||||
self._default_engine_override = default_engine_override
|
||||
self._reply_ref = MessageRef(
|
||||
channel_id=chat_id,
|
||||
message_id=user_msg_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
def _apply_default_context(self, request: RunRequest) -> RunRequest:
|
||||
if request.context is not None:
|
||||
return request
|
||||
context = self._runtime.default_context_for_chat(self._chat_id)
|
||||
if context is None:
|
||||
return request
|
||||
return RunRequest(
|
||||
prompt=request.prompt,
|
||||
engine=request.engine,
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _apply_default_engine(self, request: RunRequest) -> RunRequest:
|
||||
if request.engine is not None or self._default_engine_override is None:
|
||||
return request
|
||||
return RunRequest(
|
||||
prompt=request.prompt,
|
||||
engine=self._default_engine_override,
|
||||
context=request.context,
|
||||
)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
message: RenderedMessage | str,
|
||||
*,
|
||||
reply_to: MessageRef | None = None,
|
||||
notify: bool = True,
|
||||
) -> MessageRef | None:
|
||||
rendered = (
|
||||
message
|
||||
if isinstance(message, RenderedMessage)
|
||||
else RenderedMessage(text=message)
|
||||
)
|
||||
reply_ref = self._reply_ref if reply_to is None else reply_to
|
||||
return await self._exec_cfg.transport.send(
|
||||
channel_id=self._chat_id,
|
||||
message=rendered,
|
||||
options=SendOptions(
|
||||
reply_to=reply_ref,
|
||||
notify=notify,
|
||||
thread_id=self._thread_id,
|
||||
),
|
||||
)
|
||||
|
||||
async def run_one(
|
||||
self, request: RunRequest, *, mode: RunMode = "emit"
|
||||
) -> RunResult:
|
||||
request = self._apply_default_context(request)
|
||||
request = self._apply_default_engine(request)
|
||||
effective_show_resume_line = _should_show_resume_line(
|
||||
show_resume_line=self._show_resume_line,
|
||||
stateful_mode=self._stateful_mode,
|
||||
context=request.context,
|
||||
)
|
||||
engine = self._runtime.resolve_engine(
|
||||
engine_override=request.engine,
|
||||
context=request.context,
|
||||
)
|
||||
on_thread_known = (
|
||||
self._scheduler.note_thread_known
|
||||
if self._on_thread_known is None
|
||||
else self._on_thread_known
|
||||
)
|
||||
if mode == "capture":
|
||||
capture = _CaptureTransport()
|
||||
exec_cfg = ExecBridgeConfig(
|
||||
transport=capture,
|
||||
presenter=self._exec_cfg.presenter,
|
||||
final_notify=False,
|
||||
)
|
||||
await _run_engine(
|
||||
exec_cfg=exec_cfg,
|
||||
runtime=self._runtime,
|
||||
running_tasks={},
|
||||
chat_id=self._chat_id,
|
||||
user_msg_id=self._user_msg_id,
|
||||
text=request.prompt,
|
||||
resume_token=None,
|
||||
context=request.context,
|
||||
reply_ref=self._reply_ref,
|
||||
on_thread_known=on_thread_known,
|
||||
engine_override=engine,
|
||||
thread_id=self._thread_id,
|
||||
show_resume_line=effective_show_resume_line,
|
||||
)
|
||||
return RunResult(engine=engine, message=capture.last_message)
|
||||
await _run_engine(
|
||||
exec_cfg=self._exec_cfg,
|
||||
runtime=self._runtime,
|
||||
running_tasks=self._running_tasks,
|
||||
chat_id=self._chat_id,
|
||||
user_msg_id=self._user_msg_id,
|
||||
text=request.prompt,
|
||||
resume_token=None,
|
||||
context=request.context,
|
||||
reply_ref=self._reply_ref,
|
||||
on_thread_known=on_thread_known,
|
||||
engine_override=engine,
|
||||
thread_id=self._thread_id,
|
||||
show_resume_line=effective_show_resume_line,
|
||||
)
|
||||
return RunResult(engine=engine, message=None)
|
||||
|
||||
async def run_many(
|
||||
self,
|
||||
requests: Sequence[RunRequest],
|
||||
*,
|
||||
mode: RunMode = "emit",
|
||||
parallel: bool = False,
|
||||
) -> list[RunResult]:
|
||||
if not parallel:
|
||||
return [await self.run_one(request, mode=mode) for request in requests]
|
||||
results: list[RunResult | None] = [None] * len(requests)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
async def run_idx(idx: int, request: RunRequest) -> None:
|
||||
results[idx] = await self.run_one(request, mode=mode)
|
||||
|
||||
for idx, request in enumerate(requests):
|
||||
tg.start_soon(run_idx, idx, request)
|
||||
|
||||
return [result for result in results if result is not None]
|
||||
@@ -0,0 +1,589 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...config import ConfigError
|
||||
from ...context import RunContext
|
||||
from ...directives import DirectiveError
|
||||
from ...transport_runtime import ResolvedMessage
|
||||
from ..context import _format_context
|
||||
from ..files import (
|
||||
default_upload_name,
|
||||
default_upload_path,
|
||||
deny_reason,
|
||||
format_bytes,
|
||||
normalize_relative_path,
|
||||
parse_file_command,
|
||||
parse_file_prompt,
|
||||
resolve_path_within_root,
|
||||
write_bytes_atomic,
|
||||
ZipTooLargeError,
|
||||
zip_directory,
|
||||
)
|
||||
from ..topic_state import TopicStateStore
|
||||
from ..topics import _maybe_update_topic_context, _topic_key
|
||||
from ..types import TelegramDocument, TelegramIncomingMessage
|
||||
from .reply import make_reply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bridge import TelegramBridgeConfig
|
||||
|
||||
FILE_PUT_USAGE = "usage: `/file put <path>`"
|
||||
FILE_GET_USAGE = "usage: `/file get <path>`"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _FilePutPlan:
|
||||
resolved: ResolvedMessage
|
||||
run_root: Path
|
||||
path_value: str | None
|
||||
force: bool
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _FilePutResult:
|
||||
name: str
|
||||
rel_path: Path | None
|
||||
size: int | None
|
||||
error: str | None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _SavedFilePut:
|
||||
context: RunContext | None
|
||||
rel_path: Path
|
||||
size: int
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _SavedFilePutGroup:
|
||||
context: RunContext | None
|
||||
base_dir: Path | None
|
||||
saved: list[_FilePutResult]
|
||||
failed: list[_FilePutResult]
|
||||
|
||||
|
||||
def resolve_file_put_paths(
|
||||
plan: _FilePutPlan,
|
||||
*,
|
||||
cfg: TelegramBridgeConfig,
|
||||
require_dir: bool,
|
||||
) -> tuple[Path | None, Path | None, str | None]:
|
||||
path_value = plan.path_value
|
||||
if not path_value:
|
||||
return None, None, None
|
||||
if require_dir or path_value.endswith("/"):
|
||||
base_dir = normalize_relative_path(path_value)
|
||||
if base_dir is None:
|
||||
return None, None, "invalid upload path."
|
||||
deny_rule = deny_reason(base_dir, cfg.files.deny_globs)
|
||||
if deny_rule is not None:
|
||||
return None, None, f"path denied by rule: {deny_rule}"
|
||||
base_target = resolve_path_within_root(plan.run_root, base_dir)
|
||||
if base_target is None:
|
||||
return None, None, "upload path escapes the repo root."
|
||||
if base_target.exists() and not base_target.is_dir():
|
||||
return None, None, "upload path is a file."
|
||||
return base_dir, None, None
|
||||
rel_path = normalize_relative_path(path_value)
|
||||
if rel_path is None:
|
||||
return None, None, "invalid upload path."
|
||||
return None, rel_path, None
|
||||
|
||||
|
||||
async def _check_file_permissions(
|
||||
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||
) -> bool:
|
||||
reply = make_reply(cfg, msg)
|
||||
sender_id = msg.sender_id
|
||||
if sender_id is None:
|
||||
await reply(text="cannot verify sender for file transfer.")
|
||||
return False
|
||||
if cfg.files.allowed_user_ids:
|
||||
if sender_id not in cfg.files.allowed_user_ids:
|
||||
await reply(text="file transfer is not allowed for this user.")
|
||||
return False
|
||||
return True
|
||||
is_private = msg.chat_type == "private"
|
||||
if msg.chat_type is None:
|
||||
is_private = msg.chat_id > 0
|
||||
if is_private:
|
||||
return True
|
||||
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
|
||||
if member is None:
|
||||
await reply(text="failed to verify file transfer permissions.")
|
||||
return False
|
||||
if member.status in {"creator", "administrator"}:
|
||||
return True
|
||||
await reply(text="file transfer is restricted to group admins.")
|
||||
return False
|
||||
|
||||
|
||||
async def _prepare_file_put_plan(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> _FilePutPlan | None:
|
||||
reply = make_reply(cfg, msg)
|
||||
if not await _check_file_permissions(cfg, msg):
|
||||
return None
|
||||
try:
|
||||
resolved = cfg.runtime.resolve_message(
|
||||
text=args_text,
|
||||
reply_text=msg.reply_to_text,
|
||||
ambient_context=ambient_context,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
except DirectiveError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return None
|
||||
topic_key = _topic_key(msg, cfg) if topic_store is not None else None
|
||||
await _maybe_update_topic_context(
|
||||
cfg=cfg,
|
||||
topic_store=topic_store,
|
||||
topic_key=topic_key,
|
||||
context=resolved.context,
|
||||
context_source=resolved.context_source,
|
||||
)
|
||||
if resolved.context is None or resolved.context.project is None:
|
||||
await reply(text="no project context available for file upload.")
|
||||
return None
|
||||
try:
|
||||
run_root = cfg.runtime.resolve_run_cwd(resolved.context)
|
||||
except ConfigError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return None
|
||||
if run_root is None:
|
||||
await reply(text="no project context available for file upload.")
|
||||
return None
|
||||
path_value, force, error = parse_file_prompt(resolved.prompt, allow_empty=True)
|
||||
if error is not None:
|
||||
await reply(text=error)
|
||||
return None
|
||||
return _FilePutPlan(
|
||||
resolved=resolved,
|
||||
run_root=run_root,
|
||||
path_value=path_value,
|
||||
force=force,
|
||||
)
|
||||
|
||||
|
||||
def _format_file_put_failures(failed: Sequence[_FilePutResult]) -> str | None:
|
||||
if not failed:
|
||||
return None
|
||||
errors = ", ".join(
|
||||
f"`{item.name}` ({item.error})" for item in failed if item.error is not None
|
||||
)
|
||||
if not errors:
|
||||
return None
|
||||
return f"failed: {errors}"
|
||||
|
||||
|
||||
async def _save_document_payload(
|
||||
cfg: TelegramBridgeConfig,
|
||||
*,
|
||||
document: TelegramDocument,
|
||||
run_root: Path,
|
||||
rel_path: Path | None,
|
||||
base_dir: Path | None,
|
||||
force: bool,
|
||||
) -> _FilePutResult:
|
||||
name = default_upload_name(document.file_name, None)
|
||||
if (
|
||||
document.file_size is not None
|
||||
and document.file_size > cfg.files.max_upload_bytes
|
||||
):
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error="file is too large to upload.",
|
||||
)
|
||||
file_info = await cfg.bot.get_file(document.file_id)
|
||||
if file_info is None:
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error="failed to fetch file metadata.",
|
||||
)
|
||||
file_path = file_info.file_path
|
||||
name = default_upload_name(document.file_name, file_path)
|
||||
resolved_path = rel_path
|
||||
if resolved_path is None:
|
||||
if base_dir is None:
|
||||
resolved_path = default_upload_path(
|
||||
cfg.files.uploads_dir, document.file_name, file_path
|
||||
)
|
||||
else:
|
||||
resolved_path = base_dir / name
|
||||
deny_rule = deny_reason(resolved_path, cfg.files.deny_globs)
|
||||
if deny_rule is not None:
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error=f"path denied by rule: {deny_rule}",
|
||||
)
|
||||
target = resolve_path_within_root(run_root, resolved_path)
|
||||
if target is None:
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error="upload path escapes the repo root.",
|
||||
)
|
||||
if target.exists():
|
||||
if target.is_dir():
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error="upload target is a directory.",
|
||||
)
|
||||
if not force:
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error="file already exists; use --force to overwrite.",
|
||||
)
|
||||
payload = await cfg.bot.download_file(file_path)
|
||||
if payload is None:
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error="failed to download file.",
|
||||
)
|
||||
if len(payload) > cfg.files.max_upload_bytes:
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error="file is too large to upload.",
|
||||
)
|
||||
try:
|
||||
write_bytes_atomic(target, payload)
|
||||
except OSError as exc:
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=None,
|
||||
size=None,
|
||||
error=f"failed to write file: {exc}",
|
||||
)
|
||||
return _FilePutResult(
|
||||
name=name,
|
||||
rel_path=resolved_path,
|
||||
size=len(payload),
|
||||
error=None,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_file_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
command, rest, error = parse_file_command(args_text)
|
||||
if error is not None:
|
||||
await reply(text=error)
|
||||
return
|
||||
if command == "put":
|
||||
await _handle_file_put(cfg, msg, rest, ambient_context, topic_store)
|
||||
else:
|
||||
await _handle_file_get(cfg, msg, rest, ambient_context, topic_store)
|
||||
|
||||
|
||||
async def _handle_file_put_default(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> None:
|
||||
await _handle_file_put(cfg, msg, "", ambient_context, topic_store)
|
||||
|
||||
|
||||
async def _save_file_put(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> _SavedFilePut | None:
|
||||
reply = make_reply(cfg, msg)
|
||||
document = msg.document
|
||||
if document is None:
|
||||
await reply(text=FILE_PUT_USAGE)
|
||||
return None
|
||||
plan = await _prepare_file_put_plan(
|
||||
cfg,
|
||||
msg,
|
||||
args_text,
|
||||
ambient_context,
|
||||
topic_store,
|
||||
)
|
||||
if plan is None:
|
||||
return None
|
||||
base_dir, rel_path, error = resolve_file_put_paths(
|
||||
plan,
|
||||
cfg=cfg,
|
||||
require_dir=False,
|
||||
)
|
||||
if error is not None:
|
||||
await reply(text=error)
|
||||
return None
|
||||
result = await _save_document_payload(
|
||||
cfg,
|
||||
document=document,
|
||||
run_root=plan.run_root,
|
||||
rel_path=rel_path,
|
||||
base_dir=base_dir,
|
||||
force=plan.force,
|
||||
)
|
||||
if result.error is not None:
|
||||
await reply(text=result.error)
|
||||
return None
|
||||
if result.rel_path is None or result.size is None:
|
||||
await reply(text="failed to save file.")
|
||||
return None
|
||||
return _SavedFilePut(
|
||||
context=plan.resolved.context,
|
||||
rel_path=result.rel_path,
|
||||
size=result.size,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_file_put(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
saved = await _save_file_put(
|
||||
cfg,
|
||||
msg,
|
||||
args_text,
|
||||
ambient_context,
|
||||
topic_store,
|
||||
)
|
||||
if saved is None:
|
||||
return
|
||||
context_label = _format_context(cfg.runtime, saved.context)
|
||||
await reply(
|
||||
text=(
|
||||
f"saved `{saved.rel_path.as_posix()}` "
|
||||
f"in `{context_label}` ({format_bytes(saved.size)})"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_file_put_group(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
messages: Sequence[TelegramIncomingMessage],
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
saved_group = await _save_file_put_group(
|
||||
cfg,
|
||||
msg,
|
||||
args_text,
|
||||
messages,
|
||||
ambient_context,
|
||||
topic_store,
|
||||
)
|
||||
if saved_group is None:
|
||||
return
|
||||
context_label = _format_context(cfg.runtime, saved_group.context)
|
||||
total_bytes = sum(item.size or 0 for item in saved_group.saved)
|
||||
dir_label: Path | None = saved_group.base_dir
|
||||
if dir_label is None and saved_group.saved:
|
||||
first_path = saved_group.saved[0].rel_path
|
||||
if first_path is not None:
|
||||
dir_label = first_path.parent
|
||||
if saved_group.saved:
|
||||
saved_names = ", ".join(f"`{item.name}`" for item in saved_group.saved)
|
||||
if dir_label is not None:
|
||||
dir_text = dir_label.as_posix()
|
||||
if not dir_text.endswith("/"):
|
||||
dir_text = f"{dir_text}/"
|
||||
text = (
|
||||
f"saved {saved_names} to `{dir_text}` "
|
||||
f"in `{context_label}` ({format_bytes(total_bytes)})"
|
||||
)
|
||||
else:
|
||||
text = (
|
||||
f"saved {saved_names} in `{context_label}` "
|
||||
f"({format_bytes(total_bytes)})"
|
||||
)
|
||||
else:
|
||||
text = "failed to upload files."
|
||||
failure_text = _format_file_put_failures(saved_group.failed)
|
||||
if failure_text is not None:
|
||||
text = f"{text}\n\n{failure_text}"
|
||||
await reply(text=text)
|
||||
|
||||
|
||||
async def _save_file_put_group(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
messages: Sequence[TelegramIncomingMessage],
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> _SavedFilePutGroup | None:
|
||||
reply = make_reply(cfg, msg)
|
||||
documents = [item.document for item in messages if item.document is not None]
|
||||
if not documents:
|
||||
await reply(text=FILE_PUT_USAGE)
|
||||
return None
|
||||
plan = await _prepare_file_put_plan(
|
||||
cfg,
|
||||
msg,
|
||||
args_text,
|
||||
ambient_context,
|
||||
topic_store,
|
||||
)
|
||||
if plan is None:
|
||||
return None
|
||||
base_dir, _, error = resolve_file_put_paths(
|
||||
plan,
|
||||
cfg=cfg,
|
||||
require_dir=True,
|
||||
)
|
||||
if error is not None:
|
||||
await reply(text=error)
|
||||
return None
|
||||
saved: list[_FilePutResult] = []
|
||||
failed: list[_FilePutResult] = []
|
||||
for document in documents:
|
||||
result = await _save_document_payload(
|
||||
cfg,
|
||||
document=document,
|
||||
run_root=plan.run_root,
|
||||
rel_path=None,
|
||||
base_dir=base_dir,
|
||||
force=plan.force,
|
||||
)
|
||||
if result.error is None:
|
||||
saved.append(result)
|
||||
else:
|
||||
failed.append(result)
|
||||
return _SavedFilePutGroup(
|
||||
context=plan.resolved.context,
|
||||
base_dir=base_dir,
|
||||
saved=saved,
|
||||
failed=failed,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_file_get(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
ambient_context: RunContext | None,
|
||||
topic_store: TopicStateStore | None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
if not await _check_file_permissions(cfg, msg):
|
||||
return
|
||||
try:
|
||||
resolved = cfg.runtime.resolve_message(
|
||||
text=args_text,
|
||||
reply_text=msg.reply_to_text,
|
||||
ambient_context=ambient_context,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
except DirectiveError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return
|
||||
topic_key = _topic_key(msg, cfg) if topic_store is not None else None
|
||||
await _maybe_update_topic_context(
|
||||
cfg=cfg,
|
||||
topic_store=topic_store,
|
||||
topic_key=topic_key,
|
||||
context=resolved.context,
|
||||
context_source=resolved.context_source,
|
||||
)
|
||||
if resolved.context is None or resolved.context.project is None:
|
||||
await reply(text="no project context available for file download.")
|
||||
return
|
||||
try:
|
||||
run_root = cfg.runtime.resolve_run_cwd(resolved.context)
|
||||
except ConfigError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return
|
||||
if run_root is None:
|
||||
await reply(text="no project context available for file download.")
|
||||
return
|
||||
path_value = resolved.prompt
|
||||
if not path_value.strip():
|
||||
await reply(text=FILE_GET_USAGE)
|
||||
return
|
||||
rel_path = normalize_relative_path(path_value)
|
||||
if rel_path is None:
|
||||
await reply(text="invalid download path.")
|
||||
return
|
||||
deny_rule = deny_reason(rel_path, cfg.files.deny_globs)
|
||||
if deny_rule is not None:
|
||||
await reply(text=f"path denied by rule: {deny_rule}")
|
||||
return
|
||||
target = resolve_path_within_root(run_root, rel_path)
|
||||
if target is None:
|
||||
await reply(text="download path escapes the repo root.")
|
||||
return
|
||||
if not target.exists():
|
||||
await reply(text="file does not exist.")
|
||||
return
|
||||
if target.is_dir():
|
||||
try:
|
||||
payload = zip_directory(
|
||||
run_root,
|
||||
rel_path,
|
||||
cfg.files.deny_globs,
|
||||
max_bytes=cfg.files.max_download_bytes,
|
||||
)
|
||||
except ZipTooLargeError:
|
||||
await reply(text="file is too large to send.")
|
||||
return
|
||||
except OSError as exc:
|
||||
await reply(text=f"failed to read directory: {exc}")
|
||||
return
|
||||
filename = f"{rel_path.name or 'archive'}.zip"
|
||||
else:
|
||||
try:
|
||||
size = target.stat().st_size
|
||||
if size > cfg.files.max_download_bytes:
|
||||
await reply(text="file is too large to send.")
|
||||
return
|
||||
payload = target.read_bytes()
|
||||
except OSError as exc:
|
||||
await reply(text=f"failed to read file: {exc}")
|
||||
return
|
||||
filename = target.name
|
||||
if len(payload) > cfg.files.max_download_bytes:
|
||||
await reply(text="file is too large to send.")
|
||||
return
|
||||
sent = await cfg.bot.send_document(
|
||||
chat_id=msg.chat_id,
|
||||
filename=filename,
|
||||
content=payload,
|
||||
reply_to_message_id=msg.message_id,
|
||||
message_thread_id=msg.thread_id,
|
||||
)
|
||||
if sent is None:
|
||||
await reply(text="failed to send file.")
|
||||
return
|
||||
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...context import RunContext
|
||||
from ...directives import DirectiveError
|
||||
from ...transport_runtime import ResolvedMessage
|
||||
from ..context import _merge_topic_context
|
||||
from ..files import parse_file_command
|
||||
from ..topic_state import TopicStateStore
|
||||
from ..topics import _topic_key, _topics_chat_project
|
||||
from ..types import TelegramIncomingMessage
|
||||
from .file_transfer import (
|
||||
FILE_PUT_USAGE,
|
||||
_format_file_put_failures,
|
||||
_handle_file_put_group,
|
||||
_save_file_put_group,
|
||||
)
|
||||
from .parse import _parse_slash_command
|
||||
from .reply import make_reply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bridge import TelegramBridgeConfig
|
||||
|
||||
|
||||
async def _handle_media_group(
|
||||
cfg: TelegramBridgeConfig,
|
||||
messages: Sequence[TelegramIncomingMessage],
|
||||
topic_store: TopicStateStore | None,
|
||||
run_prompt: Callable[
|
||||
[TelegramIncomingMessage, str, ResolvedMessage], Awaitable[None]
|
||||
]
|
||||
| None = None,
|
||||
resolve_prompt: Callable[
|
||||
[TelegramIncomingMessage, str, RunContext | None],
|
||||
Awaitable[ResolvedMessage | None],
|
||||
]
|
||||
| None = None,
|
||||
) -> None:
|
||||
if not messages:
|
||||
return
|
||||
ordered = sorted(messages, key=lambda item: item.message_id)
|
||||
command_msg = next(
|
||||
(item for item in ordered if item.text.strip()),
|
||||
ordered[0],
|
||||
)
|
||||
reply = make_reply(cfg, command_msg)
|
||||
topic_key = _topic_key(command_msg, cfg) if topic_store is not None else None
|
||||
chat_project = _topics_chat_project(cfg, command_msg.chat_id)
|
||||
bound_context = (
|
||||
await topic_store.get_context(*topic_key)
|
||||
if topic_store is not None and topic_key is not None
|
||||
else None
|
||||
)
|
||||
ambient_context = _merge_topic_context(
|
||||
chat_project=chat_project, bound=bound_context
|
||||
)
|
||||
command_id, args_text = _parse_slash_command(command_msg.text)
|
||||
if command_id == "file":
|
||||
command, rest, error = parse_file_command(args_text)
|
||||
if error is not None:
|
||||
await reply(text=error)
|
||||
return
|
||||
if command == "put":
|
||||
await _handle_file_put_group(
|
||||
cfg,
|
||||
command_msg,
|
||||
rest,
|
||||
ordered,
|
||||
ambient_context,
|
||||
topic_store,
|
||||
)
|
||||
return
|
||||
if cfg.files.enabled and cfg.files.auto_put:
|
||||
caption_text = command_msg.text.strip()
|
||||
if cfg.files.auto_put_mode == "prompt" and caption_text:
|
||||
if resolve_prompt is None:
|
||||
try:
|
||||
resolved = cfg.runtime.resolve_message(
|
||||
text=caption_text,
|
||||
reply_text=command_msg.reply_to_text,
|
||||
ambient_context=ambient_context,
|
||||
chat_id=command_msg.chat_id,
|
||||
)
|
||||
except DirectiveError as exc:
|
||||
await reply(text=f"error:\n{exc}")
|
||||
return
|
||||
else:
|
||||
resolved = await resolve_prompt(
|
||||
command_msg, caption_text, ambient_context
|
||||
)
|
||||
if resolved is None:
|
||||
return
|
||||
saved_group = await _save_file_put_group(
|
||||
cfg,
|
||||
command_msg,
|
||||
"",
|
||||
ordered,
|
||||
resolved.context,
|
||||
topic_store,
|
||||
)
|
||||
if saved_group is None:
|
||||
return
|
||||
if not saved_group.saved:
|
||||
failure_text = _format_file_put_failures(saved_group.failed)
|
||||
text = "failed to upload files."
|
||||
if failure_text is not None:
|
||||
text = f"{text}\n\n{failure_text}"
|
||||
await reply(text=text)
|
||||
return
|
||||
if saved_group.failed:
|
||||
failure_text = _format_file_put_failures(saved_group.failed)
|
||||
if failure_text is not None:
|
||||
await reply(text=f"some files failed to upload.\n\n{failure_text}")
|
||||
if run_prompt is None:
|
||||
await reply(text=FILE_PUT_USAGE)
|
||||
return
|
||||
paths = [
|
||||
item.rel_path.as_posix()
|
||||
for item in saved_group.saved
|
||||
if item.rel_path is not None
|
||||
]
|
||||
files_text = "\n".join(f"- {path}" for path in paths)
|
||||
prompt_base = resolved.prompt
|
||||
annotation = f"[uploaded files]\n{files_text}"
|
||||
if prompt_base and prompt_base.strip():
|
||||
prompt = f"{prompt_base}\n\n{annotation}"
|
||||
else:
|
||||
prompt = annotation
|
||||
await run_prompt(command_msg, prompt, resolved)
|
||||
return
|
||||
if not caption_text:
|
||||
await _handle_file_put_group(
|
||||
cfg,
|
||||
command_msg,
|
||||
"",
|
||||
ordered,
|
||||
ambient_context,
|
||||
topic_store,
|
||||
)
|
||||
return
|
||||
await reply(text=FILE_PUT_USAGE)
|
||||
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...commands import get_command
|
||||
from ...config import ConfigError
|
||||
from ...ids import RESERVED_COMMAND_IDS, is_valid_id
|
||||
from ...logging import get_logger
|
||||
from ...plugins import COMMAND_GROUP, list_entrypoints
|
||||
from ...transport_runtime import TransportRuntime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bridge import TelegramBridgeConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_MAX_BOT_COMMANDS = 100
|
||||
|
||||
|
||||
def build_bot_commands(
|
||||
runtime: TransportRuntime, *, include_file: bool = True
|
||||
) -> list[dict[str, str]]:
|
||||
commands: list[dict[str, str]] = []
|
||||
seen: set[str] = set()
|
||||
for engine_id in runtime.available_engine_ids():
|
||||
cmd = engine_id.lower()
|
||||
if cmd in seen:
|
||||
continue
|
||||
commands.append({"command": cmd, "description": f"use agent: {cmd}"})
|
||||
seen.add(cmd)
|
||||
for alias in runtime.project_aliases():
|
||||
cmd = alias.lower()
|
||||
if cmd in seen:
|
||||
continue
|
||||
if not is_valid_id(cmd):
|
||||
logger.debug(
|
||||
"startup.command_menu.skip_project",
|
||||
alias=alias,
|
||||
)
|
||||
continue
|
||||
commands.append({"command": cmd, "description": f"work on: {cmd}"})
|
||||
seen.add(cmd)
|
||||
allowlist = runtime.allowlist
|
||||
for ep in list_entrypoints(
|
||||
COMMAND_GROUP,
|
||||
allowlist=allowlist,
|
||||
reserved_ids=RESERVED_COMMAND_IDS,
|
||||
):
|
||||
try:
|
||||
backend = get_command(ep.name, allowlist=allowlist)
|
||||
except ConfigError as exc:
|
||||
logger.info(
|
||||
"startup.command_menu.skip_command",
|
||||
command=ep.name,
|
||||
error=str(exc),
|
||||
)
|
||||
continue
|
||||
cmd = backend.id.lower()
|
||||
if cmd in seen:
|
||||
continue
|
||||
if not is_valid_id(cmd):
|
||||
logger.debug(
|
||||
"startup.command_menu.skip_command_id",
|
||||
command=cmd,
|
||||
)
|
||||
continue
|
||||
description = backend.description or f"command: {cmd}"
|
||||
commands.append({"command": cmd, "description": description})
|
||||
seen.add(cmd)
|
||||
if include_file and "file" not in seen:
|
||||
commands.append({"command": "file", "description": "upload or fetch files"})
|
||||
seen.add("file")
|
||||
if "cancel" not in seen:
|
||||
commands.append({"command": "cancel", "description": "cancel run"})
|
||||
if len(commands) > _MAX_BOT_COMMANDS:
|
||||
logger.warning(
|
||||
"startup.command_menu.too_many",
|
||||
count=len(commands),
|
||||
limit=_MAX_BOT_COMMANDS,
|
||||
)
|
||||
commands = commands[:_MAX_BOT_COMMANDS]
|
||||
if not any(cmd["command"] == "cancel" for cmd in commands):
|
||||
commands[-1] = {"command": "cancel", "description": "cancel run"}
|
||||
return commands
|
||||
|
||||
|
||||
def _reserved_commands(runtime: TransportRuntime) -> set[str]:
|
||||
return {
|
||||
*{engine.lower() for engine in runtime.engine_ids},
|
||||
*{alias.lower() for alias in runtime.project_aliases()},
|
||||
*RESERVED_COMMAND_IDS,
|
||||
}
|
||||
|
||||
|
||||
async def _set_command_menu(cfg: TelegramBridgeConfig) -> None:
|
||||
commands = build_bot_commands(cfg.runtime, include_file=cfg.files.enabled)
|
||||
if not commands:
|
||||
return
|
||||
try:
|
||||
ok = await cfg.bot.set_my_commands(commands)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.info(
|
||||
"startup.command_menu.failed",
|
||||
error=str(exc),
|
||||
error_type=exc.__class__.__name__,
|
||||
)
|
||||
return
|
||||
if not ok:
|
||||
logger.info("startup.command_menu.rejected")
|
||||
return
|
||||
logger.info(
|
||||
"startup.command_menu.updated",
|
||||
commands=[cmd["command"] for cmd in commands],
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def is_cancel_command(text: str) -> bool:
|
||||
stripped = text.strip()
|
||||
if not stripped:
|
||||
return False
|
||||
command = stripped.split(maxsplit=1)[0]
|
||||
return command == "/cancel" or command.startswith("/cancel@")
|
||||
|
||||
|
||||
def _parse_slash_command(text: str) -> tuple[str | None, str]:
|
||||
stripped = text.lstrip()
|
||||
if not stripped.startswith("/"):
|
||||
return None, text
|
||||
lines = stripped.splitlines()
|
||||
if not lines:
|
||||
return None, text
|
||||
first_line = lines[0]
|
||||
token, _, rest = first_line.partition(" ")
|
||||
command = token[1:]
|
||||
if not command:
|
||||
return None, text
|
||||
if "@" in command:
|
||||
command = command.split("@", 1)[0]
|
||||
args_text = rest
|
||||
if len(lines) > 1:
|
||||
tail = "\n".join(lines[1:])
|
||||
args_text = f"{args_text}\n{tail}" if args_text else tail
|
||||
return command.lower(), args_text
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..bridge import send_plain
|
||||
from ..types import TelegramIncomingMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bridge import TelegramBridgeConfig
|
||||
|
||||
|
||||
def make_reply(
|
||||
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
return partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
@@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...markdown import MarkdownParts
|
||||
from ...transport import RenderedMessage, SendOptions
|
||||
from ..chat_sessions import ChatSessionStore
|
||||
from ..context import (
|
||||
_format_context,
|
||||
_format_ctx_status,
|
||||
_merge_topic_context,
|
||||
_parse_project_branch_args,
|
||||
_usage_ctx_set,
|
||||
_usage_topic,
|
||||
)
|
||||
from ..files import split_command_args
|
||||
from ..render import prepare_telegram
|
||||
from ..topic_state import TopicStateStore
|
||||
from ..topics import (
|
||||
_maybe_rename_topic,
|
||||
_topic_key,
|
||||
_topic_title,
|
||||
_topics_chat_project,
|
||||
_topics_command_error,
|
||||
)
|
||||
from ..types import TelegramIncomingMessage
|
||||
from .reply import make_reply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..bridge import TelegramBridgeConfig
|
||||
|
||||
|
||||
async def _handle_ctx_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
store: TopicStateStore,
|
||||
*,
|
||||
resolved_scope: str | None = None,
|
||||
scope_chat_ids: frozenset[int] | None = None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
error = _topics_command_error(
|
||||
cfg,
|
||||
msg.chat_id,
|
||||
resolved_scope=resolved_scope,
|
||||
scope_chat_ids=scope_chat_ids,
|
||||
)
|
||||
if error is not None:
|
||||
await reply(text=error)
|
||||
return
|
||||
chat_project = _topics_chat_project(cfg, msg.chat_id)
|
||||
tkey = _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
|
||||
if tkey is None:
|
||||
await reply(text="this command only works inside a topic.")
|
||||
return
|
||||
tokens = split_command_args(args_text)
|
||||
action = tokens[0].lower() if tokens else "show"
|
||||
if action in {"show", ""}:
|
||||
snapshot = await store.get_thread(*tkey)
|
||||
bound = snapshot.context if snapshot is not None else None
|
||||
ambient = _merge_topic_context(chat_project=chat_project, bound=bound)
|
||||
resolved = cfg.runtime.resolve_message(
|
||||
text="",
|
||||
reply_text=msg.reply_to_text,
|
||||
chat_id=msg.chat_id,
|
||||
ambient_context=ambient,
|
||||
)
|
||||
text = _format_ctx_status(
|
||||
cfg=cfg,
|
||||
runtime=cfg.runtime,
|
||||
bound=bound,
|
||||
resolved=resolved.context,
|
||||
context_source=resolved.context_source,
|
||||
snapshot=snapshot,
|
||||
chat_project=chat_project,
|
||||
)
|
||||
await reply(text=text)
|
||||
return
|
||||
if action == "set":
|
||||
rest = " ".join(tokens[1:])
|
||||
context, error = _parse_project_branch_args(
|
||||
rest,
|
||||
runtime=cfg.runtime,
|
||||
require_branch=False,
|
||||
chat_project=chat_project,
|
||||
)
|
||||
if error is not None:
|
||||
await reply(
|
||||
text=f"error:\n{error}\n{_usage_ctx_set(chat_project=chat_project)}",
|
||||
)
|
||||
return
|
||||
if context is None:
|
||||
await reply(
|
||||
text=f"error:\n{_usage_ctx_set(chat_project=chat_project)}",
|
||||
)
|
||||
return
|
||||
await store.set_context(*tkey, context)
|
||||
await _maybe_rename_topic(
|
||||
cfg,
|
||||
store,
|
||||
chat_id=tkey[0],
|
||||
thread_id=tkey[1],
|
||||
context=context,
|
||||
)
|
||||
await reply(
|
||||
text=f"topic bound to `{_format_context(cfg.runtime, context)}`",
|
||||
)
|
||||
return
|
||||
if action == "clear":
|
||||
await store.clear_context(*tkey)
|
||||
await reply(text="topic binding cleared.")
|
||||
return
|
||||
await reply(
|
||||
text="unknown `/ctx` command. use `/ctx`, `/ctx set`, or `/ctx clear`.",
|
||||
)
|
||||
|
||||
|
||||
async def _handle_new_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
store: TopicStateStore,
|
||||
*,
|
||||
resolved_scope: str | None = None,
|
||||
scope_chat_ids: frozenset[int] | None = None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
error = _topics_command_error(
|
||||
cfg,
|
||||
msg.chat_id,
|
||||
resolved_scope=resolved_scope,
|
||||
scope_chat_ids=scope_chat_ids,
|
||||
)
|
||||
if error is not None:
|
||||
await reply(text=error)
|
||||
return
|
||||
tkey = _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
|
||||
if tkey is None:
|
||||
await reply(text="this command only works inside a topic.")
|
||||
return
|
||||
await store.clear_sessions(*tkey)
|
||||
await reply(text="cleared stored sessions for this topic.")
|
||||
|
||||
|
||||
async def _handle_chat_new_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
store: ChatSessionStore,
|
||||
session_key: tuple[int, int | None] | None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
if session_key is None:
|
||||
await reply(text="no stored sessions to clear for this chat.")
|
||||
return
|
||||
await store.clear_sessions(session_key[0], session_key[1])
|
||||
if msg.chat_type == "private":
|
||||
text = "cleared stored sessions for this chat."
|
||||
else:
|
||||
text = "cleared stored sessions for you in this chat."
|
||||
await reply(text=text)
|
||||
|
||||
|
||||
async def _handle_topic_command(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: TelegramIncomingMessage,
|
||||
args_text: str,
|
||||
store: TopicStateStore,
|
||||
*,
|
||||
resolved_scope: str | None = None,
|
||||
scope_chat_ids: frozenset[int] | None = None,
|
||||
) -> None:
|
||||
reply = make_reply(cfg, msg)
|
||||
error = _topics_command_error(
|
||||
cfg,
|
||||
msg.chat_id,
|
||||
resolved_scope=resolved_scope,
|
||||
scope_chat_ids=scope_chat_ids,
|
||||
)
|
||||
if error is not None:
|
||||
await reply(text=error)
|
||||
return
|
||||
chat_project = _topics_chat_project(cfg, msg.chat_id)
|
||||
context, error = _parse_project_branch_args(
|
||||
args_text,
|
||||
runtime=cfg.runtime,
|
||||
require_branch=True,
|
||||
chat_project=chat_project,
|
||||
)
|
||||
if error is not None or context is None:
|
||||
usage = _usage_topic(chat_project=chat_project)
|
||||
text = f"error:\n{error}\n{usage}" if error else usage
|
||||
await reply(text=text)
|
||||
return
|
||||
existing = await store.find_thread_for_context(msg.chat_id, context)
|
||||
if existing is not None:
|
||||
await reply(
|
||||
text=f"topic already exists for {_format_context(cfg.runtime, context)} "
|
||||
"in this chat.",
|
||||
)
|
||||
return
|
||||
title = _topic_title(runtime=cfg.runtime, context=context)
|
||||
created = await cfg.bot.create_forum_topic(msg.chat_id, title)
|
||||
if created is None:
|
||||
await reply(text="failed to create topic.")
|
||||
return
|
||||
thread_id = created.message_thread_id
|
||||
await store.set_context(
|
||||
msg.chat_id,
|
||||
thread_id,
|
||||
context,
|
||||
topic_title=title,
|
||||
)
|
||||
await reply(text=f"created topic `{title}`.")
|
||||
bound_text = f"topic bound to `{_format_context(cfg.runtime, context)}`"
|
||||
rendered_text, entities = prepare_telegram(MarkdownParts(header=bound_text))
|
||||
await cfg.exec_cfg.transport.send(
|
||||
channel_id=msg.chat_id,
|
||||
message=RenderedMessage(text=rendered_text, extra={"entities": entities}),
|
||||
options=SendOptions(thread_id=thread_id),
|
||||
)
|
||||
+16
-29
@@ -20,26 +20,25 @@ from ..transport import MessageRef
|
||||
from ..transport_runtime import ResolvedMessage
|
||||
from ..context import RunContext
|
||||
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
|
||||
from .commands import (
|
||||
from .commands.agent import _handle_agent_command
|
||||
from .commands.cancel import handle_callback_cancel, handle_cancel
|
||||
from .commands.dispatch import _dispatch_command
|
||||
from .commands.executor import _run_engine, _should_show_resume_line
|
||||
from .commands.file_transfer import (
|
||||
FILE_PUT_USAGE,
|
||||
_dispatch_command,
|
||||
_handle_agent_command,
|
||||
_handle_chat_new_command,
|
||||
_handle_ctx_command,
|
||||
_handle_file_command,
|
||||
_handle_file_put_default,
|
||||
_handle_media_group,
|
||||
_save_file_put,
|
||||
)
|
||||
from .commands.media import _handle_media_group
|
||||
from .commands.menu import _reserved_commands, _set_command_menu
|
||||
from .commands.parse import _parse_slash_command, is_cancel_command
|
||||
from .commands.reply import make_reply
|
||||
from .commands.topics import (
|
||||
_handle_chat_new_command,
|
||||
_handle_ctx_command,
|
||||
_handle_new_command,
|
||||
_handle_topic_command,
|
||||
_parse_slash_command,
|
||||
_reserved_commands,
|
||||
_save_file_put,
|
||||
_should_show_resume_line,
|
||||
_run_engine,
|
||||
_set_command_menu,
|
||||
handle_callback_cancel,
|
||||
handle_cancel,
|
||||
is_cancel_command,
|
||||
)
|
||||
from .context import _merge_topic_context, _usage_ctx_set, _usage_topic
|
||||
from .topics import (
|
||||
@@ -519,13 +518,7 @@ async def run_main_loop(
|
||||
text: str,
|
||||
ambient_context: RunContext | None,
|
||||
) -> ResolvedMessage | None:
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = make_reply(cfg, msg)
|
||||
try:
|
||||
resolved = cfg.runtime.resolve_message(
|
||||
text=text,
|
||||
@@ -757,13 +750,7 @@ async def run_main_loop(
|
||||
if reply_id is not None
|
||||
else None
|
||||
)
|
||||
reply = partial(
|
||||
send_plain,
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
reply = make_reply(cfg, msg)
|
||||
text = msg.text
|
||||
if msg.voice is not None:
|
||||
text = await transcribe_voice(
|
||||
|
||||
@@ -286,8 +286,8 @@ def _confirm(message: str, *, default: bool = True) -> bool | None:
|
||||
exit_with_result(event)
|
||||
|
||||
@bindings.add(Keys.Any)
|
||||
def other(event):
|
||||
_ = event
|
||||
def other(_event):
|
||||
return None
|
||||
|
||||
question = Question(
|
||||
PromptSession(get_prompt_tokens, key_bindings=bindings, style=merged_style).app
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from collections.abc import Awaitable, Callable, Hashable
|
||||
|
||||
import anyio
|
||||
|
||||
from .client_api import RetryAfter
|
||||
|
||||
SEND_PRIORITY = 0
|
||||
DELETE_PRIORITY = 1
|
||||
EDIT_PRIORITY = 2
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class OutboxOp:
|
||||
execute: Callable[[], Awaitable[Any]]
|
||||
priority: int
|
||||
queued_at: float
|
||||
chat_id: int | None
|
||||
label: str | None = None
|
||||
done: anyio.Event = field(default_factory=anyio.Event)
|
||||
result: Any = None
|
||||
|
||||
def set_result(self, result: Any) -> None:
|
||||
if self.done.is_set():
|
||||
return
|
||||
self.result = result
|
||||
self.done.set()
|
||||
|
||||
|
||||
class TelegramOutbox:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
interval_for_chat: Callable[[int | None], float],
|
||||
clock: Callable[[], float] = time.monotonic,
|
||||
sleep: Callable[[float], Awaitable[None]] = anyio.sleep,
|
||||
on_error: Callable[[OutboxOp, Exception], None] | None = None,
|
||||
on_outbox_error: Callable[[Exception], None] | None = None,
|
||||
) -> None:
|
||||
self._interval_for_chat = interval_for_chat
|
||||
self._clock = clock
|
||||
self._sleep = sleep
|
||||
self._on_error = on_error
|
||||
self._on_outbox_error = on_outbox_error
|
||||
self._pending: dict[Hashable, OutboxOp] = {}
|
||||
self._cond = anyio.Condition()
|
||||
self._start_lock = anyio.Lock()
|
||||
self._closed = False
|
||||
self._tg: TaskGroup | None = None
|
||||
self.next_at = 0.0
|
||||
self.retry_at = 0.0
|
||||
|
||||
async def ensure_worker(self) -> None:
|
||||
async with self._start_lock:
|
||||
if self._tg is not None or self._closed:
|
||||
return
|
||||
self._tg = await anyio.create_task_group().__aenter__()
|
||||
self._tg.start_soon(self.run)
|
||||
|
||||
async def enqueue(self, *, key: Hashable, op: OutboxOp, wait: bool = True) -> Any:
|
||||
await self.ensure_worker()
|
||||
async with self._cond:
|
||||
if self._closed:
|
||||
op.set_result(None)
|
||||
return op.result
|
||||
previous = self._pending.get(key)
|
||||
if previous is not None:
|
||||
op.queued_at = previous.queued_at
|
||||
previous.set_result(None)
|
||||
self._pending[key] = op
|
||||
self._cond.notify()
|
||||
if not wait:
|
||||
return None
|
||||
await op.done.wait()
|
||||
return op.result
|
||||
|
||||
async def drop_pending(self, *, key: Hashable) -> None:
|
||||
async with self._cond:
|
||||
pending = self._pending.pop(key, None)
|
||||
if pending is not None:
|
||||
pending.set_result(None)
|
||||
self._cond.notify()
|
||||
|
||||
async def close(self) -> None:
|
||||
async with self._cond:
|
||||
self._closed = True
|
||||
self.fail_pending()
|
||||
self._cond.notify_all()
|
||||
if self._tg is not None:
|
||||
await self._tg.__aexit__(None, None, None)
|
||||
self._tg = None
|
||||
|
||||
def fail_pending(self) -> None:
|
||||
for pending in list(self._pending.values()):
|
||||
pending.set_result(None)
|
||||
self._pending.clear()
|
||||
|
||||
def pick_locked(self) -> tuple[Hashable, OutboxOp] | None:
|
||||
if not self._pending:
|
||||
return None
|
||||
return min(
|
||||
self._pending.items(),
|
||||
key=lambda item: (item[1].priority, item[1].queued_at),
|
||||
)
|
||||
|
||||
async def execute_op(self, op: OutboxOp) -> Any:
|
||||
try:
|
||||
return await op.execute()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if isinstance(exc, RetryAfter):
|
||||
raise
|
||||
if self._on_error is not None:
|
||||
self._on_error(op, exc)
|
||||
return None
|
||||
|
||||
async def sleep_until(self, deadline: float) -> None:
|
||||
delay = deadline - self._clock()
|
||||
if delay > 0:
|
||||
await self._sleep(delay)
|
||||
|
||||
async def run(self) -> None:
|
||||
cancel_exc = anyio.get_cancelled_exc_class()
|
||||
try:
|
||||
while True:
|
||||
async with self._cond:
|
||||
while not self._pending and not self._closed:
|
||||
await self._cond.wait()
|
||||
if self._closed and not self._pending:
|
||||
return
|
||||
blocked_until = max(self.next_at, self.retry_at)
|
||||
if self._clock() < blocked_until:
|
||||
await self.sleep_until(blocked_until)
|
||||
continue
|
||||
async with self._cond:
|
||||
if self._closed and not self._pending:
|
||||
return
|
||||
picked = self.pick_locked()
|
||||
if picked is None:
|
||||
continue
|
||||
key, op = picked
|
||||
self._pending.pop(key, None)
|
||||
started_at = self._clock()
|
||||
try:
|
||||
result = await self.execute_op(op)
|
||||
except RetryAfter as exc:
|
||||
self.retry_at = max(self.retry_at, self._clock() + exc.retry_after)
|
||||
async with self._cond:
|
||||
if self._closed:
|
||||
op.set_result(None)
|
||||
elif key not in self._pending:
|
||||
self._pending[key] = op
|
||||
self._cond.notify()
|
||||
else:
|
||||
op.set_result(None)
|
||||
continue
|
||||
self.next_at = started_at + self._interval_for_chat(op.chat_id)
|
||||
op.set_result(result)
|
||||
except cancel_exc:
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
async with self._cond:
|
||||
self._closed = True
|
||||
self.fail_pending()
|
||||
self._cond.notify_all()
|
||||
if self._on_outbox_error is not None:
|
||||
self._on_outbox_error(exc)
|
||||
return
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from anyio.abc import TaskGroup
|
||||
else:
|
||||
TaskGroup = object
|
||||
@@ -0,0 +1,283 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from collections.abc import AsyncIterator, Callable, Iterable
|
||||
|
||||
import anyio
|
||||
|
||||
from ..logging import get_logger
|
||||
from .api_models import Update
|
||||
from .client_api import BotClient
|
||||
from .types import (
|
||||
TelegramCallbackQuery,
|
||||
TelegramDocument,
|
||||
TelegramIncomingMessage,
|
||||
TelegramIncomingUpdate,
|
||||
TelegramVoice,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_incoming_update(
|
||||
update: Update | dict[str, Any],
|
||||
*,
|
||||
chat_id: int | None = None,
|
||||
chat_ids: set[int] | None = None,
|
||||
) -> TelegramIncomingUpdate | None:
|
||||
if isinstance(update, Update):
|
||||
msg = update.message
|
||||
callback_query = update.callback_query
|
||||
else:
|
||||
msg = update.get("message")
|
||||
callback_query = update.get("callback_query")
|
||||
|
||||
if isinstance(msg, dict):
|
||||
return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids)
|
||||
if isinstance(callback_query, dict):
|
||||
return _parse_callback_query(
|
||||
callback_query,
|
||||
chat_id=chat_id,
|
||||
chat_ids=chat_ids,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _parse_incoming_message(
|
||||
msg: dict[str, Any],
|
||||
*,
|
||||
chat_id: int | None = None,
|
||||
chat_ids: set[int] | None = None,
|
||||
) -> TelegramIncomingMessage | None:
|
||||
def _parse_document_payload(payload: dict[str, Any]) -> TelegramDocument | None:
|
||||
file_id = payload.get("file_id")
|
||||
if not isinstance(file_id, str) or not file_id:
|
||||
return None
|
||||
return TelegramDocument(
|
||||
file_id=file_id,
|
||||
file_name=payload.get("file_name")
|
||||
if isinstance(payload.get("file_name"), str)
|
||||
else None,
|
||||
mime_type=payload.get("mime_type")
|
||||
if isinstance(payload.get("mime_type"), str)
|
||||
else None,
|
||||
file_size=payload.get("file_size")
|
||||
if isinstance(payload.get("file_size"), int)
|
||||
and not isinstance(payload.get("file_size"), bool)
|
||||
else None,
|
||||
raw=payload,
|
||||
)
|
||||
|
||||
raw_text = msg.get("text")
|
||||
text = raw_text if isinstance(raw_text, str) else None
|
||||
caption = msg.get("caption")
|
||||
if text is None and isinstance(caption, str):
|
||||
text = caption
|
||||
if text is None:
|
||||
text = ""
|
||||
file_command = False
|
||||
if isinstance(text, str):
|
||||
stripped = text.lstrip()
|
||||
if stripped.startswith("/"):
|
||||
token = stripped.split(maxsplit=1)[0]
|
||||
file_command = token.startswith("/file")
|
||||
voice_payload: TelegramVoice | None = None
|
||||
voice = msg.get("voice")
|
||||
if isinstance(voice, dict):
|
||||
file_id = voice.get("file_id")
|
||||
if not isinstance(file_id, str) or not file_id:
|
||||
file_id = None
|
||||
if file_id is not None:
|
||||
voice_payload = TelegramVoice(
|
||||
file_id=file_id,
|
||||
mime_type=voice.get("mime_type")
|
||||
if isinstance(voice.get("mime_type"), str)
|
||||
else None,
|
||||
file_size=voice.get("file_size")
|
||||
if isinstance(voice.get("file_size"), int)
|
||||
and not isinstance(voice.get("file_size"), bool)
|
||||
else None,
|
||||
duration=voice.get("duration")
|
||||
if isinstance(voice.get("duration"), int)
|
||||
and not isinstance(voice.get("duration"), bool)
|
||||
else None,
|
||||
raw=voice,
|
||||
)
|
||||
if not isinstance(raw_text, str) and not isinstance(caption, str):
|
||||
text = ""
|
||||
document_payload: TelegramDocument | None = None
|
||||
document = msg.get("document")
|
||||
if isinstance(document, dict):
|
||||
document_payload = _parse_document_payload(document)
|
||||
if document_payload is None:
|
||||
video = msg.get("video")
|
||||
if isinstance(video, dict):
|
||||
document_payload = _parse_document_payload(video)
|
||||
if document_payload is None:
|
||||
photo = msg.get("photo")
|
||||
if isinstance(photo, list):
|
||||
best: dict[str, Any] | None = None
|
||||
best_score = -1
|
||||
for item in photo:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
file_id = item.get("file_id")
|
||||
if not isinstance(file_id, str) or not file_id:
|
||||
continue
|
||||
size = item.get("file_size")
|
||||
if isinstance(size, int) and not isinstance(size, bool):
|
||||
score = size
|
||||
else:
|
||||
width = item.get("width")
|
||||
height = item.get("height")
|
||||
if isinstance(width, int) and isinstance(height, int):
|
||||
score = width * height
|
||||
else:
|
||||
score = 0
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best = item
|
||||
if best is not None:
|
||||
document_payload = _parse_document_payload(best)
|
||||
if document_payload is None and file_command:
|
||||
sticker = msg.get("sticker")
|
||||
if isinstance(sticker, dict):
|
||||
document_payload = _parse_document_payload(sticker)
|
||||
has_text = isinstance(raw_text, str) or isinstance(caption, str)
|
||||
if not has_text and voice_payload is None and document_payload is None:
|
||||
return None
|
||||
chat = msg.get("chat")
|
||||
if not isinstance(chat, dict):
|
||||
return None
|
||||
msg_chat_id = chat.get("id")
|
||||
if not isinstance(msg_chat_id, int):
|
||||
return None
|
||||
chat_type = chat.get("type") if isinstance(chat.get("type"), str) else None
|
||||
is_forum = chat.get("is_forum")
|
||||
if not isinstance(is_forum, bool):
|
||||
is_forum = None
|
||||
allowed = chat_ids
|
||||
if allowed is None and chat_id is not None:
|
||||
allowed = {chat_id}
|
||||
if allowed is not None and msg_chat_id not in allowed:
|
||||
return None
|
||||
message_id = msg.get("message_id")
|
||||
if not isinstance(message_id, int):
|
||||
return None
|
||||
reply = msg.get("reply_to_message")
|
||||
reply_to_message_id = None
|
||||
reply_to_text = None
|
||||
if isinstance(reply, dict):
|
||||
reply_to_message_id = (
|
||||
reply.get("message_id")
|
||||
if isinstance(reply.get("message_id"), int)
|
||||
else None
|
||||
)
|
||||
reply_to_text = (
|
||||
reply.get("text") if isinstance(reply.get("text"), str) else None
|
||||
)
|
||||
sender = msg.get("from")
|
||||
sender_id = (
|
||||
sender.get("id")
|
||||
if isinstance(sender, dict) and isinstance(sender.get("id"), int)
|
||||
else None
|
||||
)
|
||||
media_group_id = msg.get("media_group_id")
|
||||
if not isinstance(media_group_id, str):
|
||||
media_group_id = None
|
||||
thread_id = msg.get("message_thread_id")
|
||||
if isinstance(thread_id, bool) or not isinstance(thread_id, int):
|
||||
thread_id = None
|
||||
is_topic_message = msg.get("is_topic_message")
|
||||
if not isinstance(is_topic_message, bool):
|
||||
is_topic_message = None
|
||||
return TelegramIncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=msg_chat_id,
|
||||
message_id=message_id,
|
||||
text=text,
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
reply_to_text=reply_to_text,
|
||||
sender_id=sender_id,
|
||||
media_group_id=media_group_id,
|
||||
thread_id=thread_id,
|
||||
is_topic_message=is_topic_message,
|
||||
chat_type=chat_type,
|
||||
is_forum=is_forum,
|
||||
voice=voice_payload,
|
||||
document=document_payload,
|
||||
raw=msg,
|
||||
)
|
||||
|
||||
|
||||
def _parse_callback_query(
|
||||
query: dict[str, Any],
|
||||
*,
|
||||
chat_id: int | None = None,
|
||||
chat_ids: set[int] | None = None,
|
||||
) -> TelegramCallbackQuery | None:
|
||||
callback_id = query.get("id")
|
||||
if not isinstance(callback_id, str) or not callback_id:
|
||||
return None
|
||||
msg = query.get("message")
|
||||
if not isinstance(msg, dict):
|
||||
return None
|
||||
chat = msg.get("chat")
|
||||
if not isinstance(chat, dict):
|
||||
return None
|
||||
msg_chat_id = chat.get("id")
|
||||
if not isinstance(msg_chat_id, int):
|
||||
return None
|
||||
allowed = chat_ids
|
||||
if allowed is None and chat_id is not None:
|
||||
allowed = {chat_id}
|
||||
if allowed is not None and msg_chat_id not in allowed:
|
||||
return None
|
||||
message_id = msg.get("message_id")
|
||||
if not isinstance(message_id, int):
|
||||
return None
|
||||
data = query.get("data") if isinstance(query.get("data"), str) else None
|
||||
sender = query.get("from")
|
||||
sender_id = (
|
||||
sender.get("id")
|
||||
if isinstance(sender, dict) and isinstance(sender.get("id"), int)
|
||||
else None
|
||||
)
|
||||
return TelegramCallbackQuery(
|
||||
transport="telegram",
|
||||
chat_id=msg_chat_id,
|
||||
message_id=message_id,
|
||||
callback_query_id=callback_id,
|
||||
data=data,
|
||||
sender_id=sender_id,
|
||||
raw=query,
|
||||
)
|
||||
|
||||
|
||||
async def poll_incoming(
|
||||
bot: BotClient,
|
||||
*,
|
||||
chat_id: int | None = None,
|
||||
chat_ids: Iterable[int] | Callable[[], Iterable[int]] | None = None,
|
||||
offset: int | None = None,
|
||||
) -> AsyncIterator[TelegramIncomingUpdate]:
|
||||
while True:
|
||||
updates = await bot.get_updates(
|
||||
offset=offset,
|
||||
timeout_s=50,
|
||||
allowed_updates=["message", "callback_query"],
|
||||
)
|
||||
if updates is None:
|
||||
logger.info("loop.get_updates.failed")
|
||||
await anyio.sleep(2)
|
||||
continue
|
||||
logger.debug("loop.updates", updates=updates)
|
||||
resolved_chat_ids = chat_ids() if callable(chat_ids) else chat_ids
|
||||
allowed = set(resolved_chat_ids) if resolved_chat_ids is not None else None
|
||||
if allowed is None and chat_id is not None:
|
||||
allowed = {chat_id}
|
||||
for upd in updates:
|
||||
offset = upd.update_id + 1
|
||||
msg = parse_incoming_update(upd, chat_ids=allowed)
|
||||
if msg is not None:
|
||||
yield msg
|
||||
@@ -1,14 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generic, Protocol, TypeVar
|
||||
from typing import Any, Protocol
|
||||
|
||||
import anyio
|
||||
import msgspec
|
||||
|
||||
T = TypeVar("T", bound="_VersionedState")
|
||||
from ..utils.json_state import atomic_write_json
|
||||
|
||||
|
||||
class _Logger(Protocol):
|
||||
@@ -19,7 +18,7 @@ class _VersionedState(Protocol):
|
||||
version: int
|
||||
|
||||
|
||||
class JsonStateStore(Generic[T]):
|
||||
class JsonStateStore[T: _VersionedState]:
|
||||
def __init__(
|
||||
self,
|
||||
path: Path,
|
||||
@@ -84,11 +83,6 @@ class JsonStateStore(Generic[T]):
|
||||
self._state = payload
|
||||
|
||||
def _save_locked(self) -> None:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = msgspec.to_builtins(self._state)
|
||||
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
|
||||
with open(tmp_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(payload, handle, indent=2, sort_keys=True)
|
||||
handle.write("\n")
|
||||
os.replace(tmp_path, self._path)
|
||||
atomic_write_json(self._path, payload)
|
||||
self._mtime_ns = self._stat_mtime_ns()
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol, TypeAlias
|
||||
from typing import Any, Protocol
|
||||
|
||||
ChannelId: TypeAlias = int | str
|
||||
MessageId: TypeAlias = int | str
|
||||
ThreadId: TypeAlias = int | str
|
||||
type ChannelId = int | str
|
||||
type MessageId = int | str
|
||||
type ThreadId = int | str
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Iterable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
from typing import Any, Literal
|
||||
|
||||
from .config import ConfigError, ProjectsConfig
|
||||
from .context import RunContext
|
||||
@@ -19,7 +19,7 @@ from .router import AutoRouter, EngineStatus
|
||||
from .runner import Runner
|
||||
from .worktrees import WorktreeError, resolve_run_cwd
|
||||
|
||||
ContextSource: TypeAlias = Literal[
|
||||
type ContextSource = Literal[
|
||||
"reply_ctx",
|
||||
"directives",
|
||||
"ambient",
|
||||
@@ -234,13 +234,13 @@ class TransportRuntime:
|
||||
project_key = ambient_context.project
|
||||
else:
|
||||
project_key = default_project
|
||||
if branch is None:
|
||||
if (
|
||||
ambient_context is not None
|
||||
and ambient_context.branch is not None
|
||||
and project_key == ambient_context.project
|
||||
):
|
||||
branch = ambient_context.branch
|
||||
if (
|
||||
branch is None
|
||||
and ambient_context is not None
|
||||
and ambient_context.branch is not None
|
||||
and project_key == ambient_context.project
|
||||
):
|
||||
branch = ambient_context.branch
|
||||
context: RunContext | None = None
|
||||
if project_key is not None or branch is not None:
|
||||
context = RunContext(project=project_key, branch=branch)
|
||||
|
||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Protocol, runtime_checkable
|
||||
from typing import Protocol, runtime_checkable
|
||||
from collections.abc import Iterable
|
||||
|
||||
from .backends import EngineBackend, SetupIssue
|
||||
from .plugins import TRANSPORT_GROUP, list_ids, load_plugin_backend
|
||||
@@ -34,7 +35,7 @@ class TransportBackend(Protocol):
|
||||
def interactive_setup(self, *, force: bool) -> bool: ...
|
||||
|
||||
def lock_token(
|
||||
self, *, transport_config: object, config_path: Path
|
||||
self, *, transport_config: object, _config_path: Path
|
||||
) -> str | None: ...
|
||||
|
||||
def build_and_run(
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def atomic_write_json(
|
||||
path: Path,
|
||||
payload: Any,
|
||||
*,
|
||||
indent: int = 2,
|
||||
sort_keys: bool = True,
|
||||
) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = path.with_suffix(f"{path.suffix}.tmp")
|
||||
with open(tmp_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(payload, handle, indent=indent, sort_keys=sort_keys)
|
||||
handle.write("\n")
|
||||
os.replace(tmp_path, path)
|
||||
+3
-3
@@ -14,7 +14,7 @@ from takopi.model import (
|
||||
|
||||
|
||||
def session_started(engine: str, value: str, title: str = "Codex") -> TakopiEvent:
|
||||
engine_id = EngineId(engine)
|
||||
engine_id: EngineId = engine
|
||||
return StartedEvent(
|
||||
engine=engine_id,
|
||||
resume=ResumeToken(engine=engine_id, value=value),
|
||||
@@ -29,7 +29,7 @@ def action_started(
|
||||
detail: dict[str, Any] | None = None,
|
||||
engine: str = "codex",
|
||||
) -> TakopiEvent:
|
||||
engine_id = EngineId(engine)
|
||||
engine_id: EngineId = engine
|
||||
return ActionEvent(
|
||||
engine=engine_id,
|
||||
action=Action(
|
||||
@@ -50,7 +50,7 @@ def action_completed(
|
||||
detail: dict[str, Any] | None = None,
|
||||
engine: str = "codex",
|
||||
) -> TakopiEvent:
|
||||
engine_id = EngineId(engine)
|
||||
engine_id: EngineId = engine
|
||||
return ActionEvent(
|
||||
engine=engine_id,
|
||||
action=Action(
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Iterable
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from takopi.runner_bridge import ExecBridgeConfig, IncomingMessage, handle_message
|
||||
from takopi.markdown import MarkdownParts, MarkdownPresenter
|
||||
from takopi.model import EngineId, ResumeToken, TakopiEvent
|
||||
from takopi.model import ResumeToken, TakopiEvent
|
||||
from takopi.telegram.render import prepare_telegram
|
||||
from takopi.runners.codex import CodexRunner
|
||||
from takopi.runners.mock import Advance, Emit, Raise, Return, ScriptRunner, Wait
|
||||
@@ -13,7 +13,7 @@ from takopi.settings import load_settings, require_telegram
|
||||
from takopi.transport import MessageRef, RenderedMessage, SendOptions
|
||||
from tests.factories import action_completed, action_started
|
||||
|
||||
CODEX_ENGINE = EngineId("codex")
|
||||
CODEX_ENGINE = "codex"
|
||||
|
||||
|
||||
class _FakeTransport:
|
||||
|
||||
@@ -7,14 +7,13 @@ from collections.abc import AsyncIterator
|
||||
from takopi.model import (
|
||||
ActionEvent,
|
||||
CompletedEvent,
|
||||
EngineId,
|
||||
ResumeToken,
|
||||
StartedEvent,
|
||||
TakopiEvent,
|
||||
)
|
||||
from takopi.runners.codex import CodexRunner, find_exec_only_flag
|
||||
|
||||
CODEX_ENGINE = EngineId("codex")
|
||||
CODEX_ENGINE = "codex"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@@ -45,11 +45,10 @@ def test_resolve_default_base_prefers_master_over_main(monkeypatch) -> None:
|
||||
return None
|
||||
|
||||
def _fake_ok(args, **kwargs):
|
||||
if args == ["show-ref", "--verify", "--quiet", "refs/heads/master"]:
|
||||
return True
|
||||
if args == ["show-ref", "--verify", "--quiet", "refs/heads/main"]:
|
||||
return True
|
||||
return False
|
||||
return args in (
|
||||
["show-ref", "--verify", "--quiet", "refs/heads/master"],
|
||||
["show-ref", "--verify", "--quiet", "refs/heads/main"],
|
||||
)
|
||||
|
||||
monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout)
|
||||
monkeypatch.setattr("takopi.utils.git.git_ok", _fake_ok)
|
||||
|
||||
@@ -7,7 +7,6 @@ from takopi.model import (
|
||||
Action,
|
||||
ActionEvent,
|
||||
CompletedEvent,
|
||||
EngineId,
|
||||
ResumeToken,
|
||||
StartedEvent,
|
||||
TakopiEvent,
|
||||
@@ -15,7 +14,7 @@ from takopi.model import (
|
||||
from takopi.runners.mock import Emit, Return, ScriptRunner, Wait
|
||||
from tests.factories import action_started
|
||||
|
||||
CODEX_ENGINE = EngineId("codex")
|
||||
CODEX_ENGINE = "codex"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -84,7 +83,7 @@ async def test_runner_releases_lock_when_consumer_closes() -> None:
|
||||
gate = anyio.Event()
|
||||
runner = ScriptRunner([Wait(gate)], engine=CODEX_ENGINE, resume_value="sid")
|
||||
|
||||
gen = cast(AsyncGenerator[TakopiEvent, None], runner.run("hello", None))
|
||||
gen = cast(AsyncGenerator[TakopiEvent], runner.run("hello", None))
|
||||
try:
|
||||
while True:
|
||||
evt = await anext(gen)
|
||||
@@ -94,7 +93,7 @@ async def test_runner_releases_lock_when_consumer_closes() -> None:
|
||||
await gen.aclose()
|
||||
|
||||
gen2 = cast(
|
||||
AsyncGenerator[TakopiEvent, None],
|
||||
AsyncGenerator[TakopiEvent],
|
||||
runner.run("again", ResumeToken(engine=CODEX_ENGINE, value="sid")),
|
||||
)
|
||||
try:
|
||||
|
||||
@@ -8,7 +8,6 @@ import takopi.runner as runner_module
|
||||
from takopi.model import (
|
||||
ActionEvent,
|
||||
CompletedEvent,
|
||||
EngineId,
|
||||
ResumeToken,
|
||||
StartedEvent,
|
||||
TakopiEvent,
|
||||
@@ -22,7 +21,7 @@ from takopi.runner import (
|
||||
|
||||
|
||||
class _DummyRunner(ResumeTokenMixin, BaseRunner):
|
||||
engine = EngineId("dummy")
|
||||
engine = "dummy"
|
||||
resume_re = re.compile(r"(?im)^`?dummy resume (?P<token>[^`\s]+)`?$")
|
||||
|
||||
async def run_impl(
|
||||
@@ -39,7 +38,7 @@ class _DummyRunner(ResumeTokenMixin, BaseRunner):
|
||||
|
||||
|
||||
class _DummyJsonlRunner(JsonlSubprocessRunner):
|
||||
engine = EngineId("dummy-jsonl")
|
||||
engine = "dummy-jsonl"
|
||||
|
||||
def command(self) -> str:
|
||||
return "dummy"
|
||||
@@ -67,7 +66,7 @@ class _DummyJsonlRunner(JsonlSubprocessRunner):
|
||||
|
||||
|
||||
class _BareJsonlRunner(JsonlSubprocessRunner):
|
||||
engine = EngineId("bare-jsonl")
|
||||
engine = "bare-jsonl"
|
||||
|
||||
|
||||
class _RunJsonlRunner(_DummyJsonlRunner):
|
||||
@@ -177,7 +176,7 @@ async def test_base_runner_run_locked_handles_resume() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_base_runner_rejects_wrong_resume_engine() -> None:
|
||||
runner = _DummyRunner()
|
||||
bad_resume = ResumeToken(engine=EngineId("other"), value="oops")
|
||||
bad_resume = ResumeToken(engine="other", value="oops")
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = [evt async for evt in runner.run("hello", bad_resume)]
|
||||
|
||||
@@ -185,7 +184,7 @@ async def test_base_runner_rejects_wrong_resume_engine() -> None:
|
||||
@pytest.mark.anyio
|
||||
async def test_base_runner_run_impl_not_implemented() -> None:
|
||||
class _BareRunner(BaseRunner):
|
||||
engine = EngineId("bare")
|
||||
engine = "bare"
|
||||
|
||||
runner = _BareRunner()
|
||||
with pytest.raises(NotImplementedError):
|
||||
@@ -204,7 +203,7 @@ def test_resume_token_format_and_extract() -> None:
|
||||
assert runner.extract_resume(None) is None
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
runner.format_resume(ResumeToken(engine=EngineId("other"), value="bad"))
|
||||
runner.format_resume(ResumeToken(engine="other", value="bad"))
|
||||
|
||||
|
||||
def test_session_lock_reuse() -> None:
|
||||
@@ -294,7 +293,7 @@ def test_jsonl_helpers() -> None:
|
||||
assert found == resume
|
||||
assert emit is False
|
||||
|
||||
mismatch = StartedEvent(engine=EngineId("other"), resume=resume, title="t")
|
||||
mismatch = StartedEvent(engine="other", resume=resume, title="t")
|
||||
with pytest.raises(RuntimeError):
|
||||
runner.handle_started_event(mismatch, expected_session=None, found_session=None)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any
|
||||
import pytest
|
||||
|
||||
from takopi.config import ProjectsConfig
|
||||
from takopi.model import EngineId
|
||||
from takopi.router import AutoRouter, RunnerEntry
|
||||
from takopi.runners.mock import Return, ScriptRunner
|
||||
from takopi.settings import (
|
||||
@@ -19,8 +18,8 @@ from takopi.transport_runtime import TransportRuntime
|
||||
|
||||
|
||||
def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
|
||||
codex = EngineId("codex")
|
||||
pi = EngineId("pi")
|
||||
codex = "codex"
|
||||
pi = "pi"
|
||||
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
||||
missing = ScriptRunner([Return(answer="ok")], engine=pi)
|
||||
router = AutoRouter(
|
||||
@@ -53,9 +52,9 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
|
||||
def test_build_startup_message_surfaces_unavailable_engine_reasons(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
codex = EngineId("codex")
|
||||
pi = EngineId("pi")
|
||||
claude = EngineId("claude")
|
||||
codex = "codex"
|
||||
pi = "pi"
|
||||
claude = "claude"
|
||||
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
||||
bad_cfg = ScriptRunner([Return(answer="ok")], engine=pi)
|
||||
load_err = ScriptRunner([Return(answer="ok")], engine=claude)
|
||||
@@ -100,7 +99,7 @@ def test_telegram_backend_build_and_run_wires_config(
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
codex = EngineId("codex")
|
||||
codex = "codex"
|
||||
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
||||
router = AutoRouter(
|
||||
entries=[RunnerEntry(engine=codex, runner=runner)],
|
||||
|
||||
@@ -6,8 +6,9 @@ import anyio
|
||||
import pytest
|
||||
|
||||
from takopi import commands, plugins
|
||||
from takopi.telegram.commands.executor import _CaptureTransport, _run_engine
|
||||
from takopi.telegram.commands.file_transfer import _handle_file_get, _handle_file_put
|
||||
import takopi.telegram.loop as telegram_loop
|
||||
import takopi.telegram.commands as telegram_commands
|
||||
import takopi.telegram.topics as telegram_topics
|
||||
from takopi.directives import parse_directives
|
||||
from takopi.telegram.api_models import (
|
||||
@@ -39,7 +40,7 @@ from takopi.context import RunContext
|
||||
from takopi.config import ProjectConfig, ProjectsConfig
|
||||
from takopi.runner_bridge import ExecBridgeConfig, RunningTask
|
||||
from takopi.markdown import MarkdownPresenter
|
||||
from takopi.model import EngineId, ResumeToken
|
||||
from takopi.model import ResumeToken
|
||||
from takopi.progress import ProgressTracker
|
||||
from takopi.router import AutoRouter, RunnerEntry
|
||||
from takopi.transport_runtime import TransportRuntime
|
||||
@@ -52,7 +53,7 @@ from takopi.telegram.types import (
|
||||
from takopi.transport import MessageRef, RenderedMessage, SendOptions
|
||||
from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints
|
||||
|
||||
CODEX_ENGINE = EngineId("codex")
|
||||
CODEX_ENGINE = "codex"
|
||||
|
||||
|
||||
def _empty_projects() -> ProjectsConfig:
|
||||
@@ -905,9 +906,7 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
|
||||
),
|
||||
)
|
||||
|
||||
await telegram_commands._handle_file_put(
|
||||
cfg, msg, "/proj uploads/hello.txt", None, None
|
||||
)
|
||||
await _handle_file_put(cfg, msg, "/proj uploads/hello.txt", None, None)
|
||||
|
||||
target = tmp_path / "uploads" / "hello.txt"
|
||||
assert target.read_bytes() == payload
|
||||
@@ -966,7 +965,7 @@ async def test_handle_file_get_sends_document_for_allowed_user(
|
||||
chat_type="supergroup",
|
||||
)
|
||||
|
||||
await telegram_commands._handle_file_get(cfg, msg, "/proj hello.txt", None, None)
|
||||
await _handle_file_get(cfg, msg, "/proj hello.txt", None, None)
|
||||
|
||||
assert bot.document_calls
|
||||
assert bot.document_calls[0]["filename"] == "hello.txt"
|
||||
@@ -1263,7 +1262,7 @@ async def test_send_with_resume_reports_when_missing() -> None:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_engine_hides_resume_line_in_topics() -> None:
|
||||
transport = telegram_commands._CaptureTransport()
|
||||
transport = _CaptureTransport()
|
||||
runner = ScriptRunner(
|
||||
[Return(answer="ok")],
|
||||
engine=CODEX_ENGINE,
|
||||
@@ -1279,7 +1278,7 @@ async def test_run_engine_hides_resume_line_in_topics() -> None:
|
||||
projects=_empty_projects(),
|
||||
)
|
||||
|
||||
await telegram_commands._run_engine(
|
||||
await _run_engine(
|
||||
exec_cfg=exec_cfg,
|
||||
runtime=runtime,
|
||||
running_tasks={},
|
||||
@@ -1456,14 +1455,14 @@ async def test_run_main_loop_auto_resumes_topic_default_engine(
|
||||
123, 77, ResumeToken(engine=CODEX_ENGINE, value="resume-codex")
|
||||
)
|
||||
await store.set_session_resume(
|
||||
123, 77, ResumeToken(engine=EngineId("claude"), value="resume-claude")
|
||||
123, 77, ResumeToken(engine="claude", value="resume-claude")
|
||||
)
|
||||
await store.set_default_engine(123, 77, "claude")
|
||||
|
||||
transport = _FakeTransport()
|
||||
bot = _FakeBot()
|
||||
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
|
||||
claude_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("claude"))
|
||||
claude_runner = ScriptRunner([Return(answer="ok")], engine="claude")
|
||||
router = AutoRouter(
|
||||
entries=[
|
||||
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
|
||||
@@ -1521,7 +1520,7 @@ async def test_run_main_loop_auto_resumes_topic_default_engine(
|
||||
assert codex_runner.calls == []
|
||||
assert len(claude_runner.calls) == 1
|
||||
assert claude_runner.calls[0][1] == ResumeToken(
|
||||
engine=EngineId("claude"), value="resume-claude"
|
||||
engine="claude", value="resume-claude"
|
||||
)
|
||||
|
||||
|
||||
@@ -2443,7 +2442,7 @@ async def test_run_main_loop_command_uses_project_default_engine(
|
||||
transport = _FakeTransport()
|
||||
bot = _FakeBot()
|
||||
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
|
||||
pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
|
||||
pi_runner = ScriptRunner([Return(answer="ok")], engine="pi")
|
||||
router = AutoRouter(
|
||||
entries=[
|
||||
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
|
||||
@@ -2525,7 +2524,7 @@ async def test_run_main_loop_command_defaults_to_chat_project(
|
||||
transport = _FakeTransport()
|
||||
bot = _FakeBot()
|
||||
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
|
||||
pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
|
||||
pi_runner = ScriptRunner([Return(answer="ok")], engine="pi")
|
||||
router = AutoRouter(
|
||||
entries=[
|
||||
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
|
||||
|
||||
@@ -3,6 +3,7 @@ import pytest
|
||||
|
||||
from takopi.logging import setup_logging
|
||||
from takopi.telegram.client import TelegramClient, TelegramRetryAfter
|
||||
from takopi.telegram.client_api import HttpBotClient
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -25,9 +26,9 @@ async def test_telegram_429_no_retry() -> None:
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
|
||||
with pytest.raises(TelegramRetryAfter) as exc:
|
||||
await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||
await api._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
@@ -49,8 +50,8 @@ async def test_no_token_in_logs_on_http_error(
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient(token, http_client=client)
|
||||
await tg._post("getUpdates", {"timeout": 1})
|
||||
api = HttpBotClient(token, http_client=client)
|
||||
await api._post("getUpdates", {"timeout": 1})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
@@ -79,9 +80,9 @@ async def test_telegram_429_no_retry_post_form() -> None:
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
|
||||
with pytest.raises(TelegramRetryAfter) as exc:
|
||||
await tg._post_form(
|
||||
await api._post_form(
|
||||
"sendDocument",
|
||||
{"chat_id": 1},
|
||||
files={"document": ("note.txt", b"hi")},
|
||||
@@ -102,9 +103,9 @@ async def test_telegram_429_defaults_retry_after_on_bad_body() -> None:
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
|
||||
with pytest.raises(TelegramRetryAfter) as exc:
|
||||
await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||
await api._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
@@ -124,8 +125,8 @@ async def test_telegram_ok_false_returns_none() -> None:
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||
result = await tg._post("getUpdates", {"timeout": 1})
|
||||
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
|
||||
result = await api._post("getUpdates", {"timeout": 1})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
@@ -141,8 +142,8 @@ async def test_telegram_invalid_payload_returns_none() -> None:
|
||||
|
||||
client = httpx.AsyncClient(transport=transport)
|
||||
try:
|
||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
||||
result = await tg._post("getUpdates", {"timeout": 1})
|
||||
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
|
||||
result = await api._post("getUpdates", {"timeout": 1})
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import pytest
|
||||
|
||||
from takopi.config import ProjectConfig, ProjectsConfig
|
||||
from takopi.context import RunContext
|
||||
from takopi.model import EngineId
|
||||
from takopi.router import AutoRouter, RunnerEntry
|
||||
from takopi.runners.mock import Return, ScriptRunner
|
||||
from takopi.telegram.chat_prefs import ChatPrefsStore
|
||||
@@ -15,8 +14,8 @@ from takopi.transport_runtime import TransportRuntime
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_resolve_engine_for_message_sources(tmp_path) -> None:
|
||||
codex = ScriptRunner([Return(answer="ok")], engine=EngineId("codex"))
|
||||
pi = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
|
||||
codex = ScriptRunner([Return(answer="ok")], engine="codex")
|
||||
pi = ScriptRunner([Return(answer="ok")], engine="pi")
|
||||
router = AutoRouter(
|
||||
entries=[
|
||||
RunnerEntry(engine=codex.engine, runner=codex),
|
||||
@@ -42,7 +41,7 @@ async def test_resolve_engine_for_message_sources(tmp_path) -> None:
|
||||
resolved = await resolve_engine_for_message(
|
||||
runtime=runtime,
|
||||
context=RunContext(project="proj"),
|
||||
explicit_engine=EngineId("codex"),
|
||||
explicit_engine="codex",
|
||||
chat_id=1,
|
||||
topic_key=(1, 10),
|
||||
topic_store=topic_store,
|
||||
|
||||
@@ -15,8 +15,8 @@ class DummyTransport:
|
||||
def interactive_setup(self, *, force: bool) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def lock_token(self, *, transport_config: object, config_path):
|
||||
_ = transport_config, config_path
|
||||
def lock_token(self, *, transport_config: object, _config_path):
|
||||
_ = transport_config, _config_path
|
||||
raise NotImplementedError
|
||||
|
||||
def build_and_run(
|
||||
|
||||
Reference in New Issue
Block a user