refactor: telegram modules and tighten linting (#111)

This commit is contained in:
banteg
2026-01-13 05:14:26 +04:00
committed by GitHub
parent f060d3b59c
commit c1205cd5a8
63 changed files with 3257 additions and 3073 deletions
+1 -1
View File
@@ -408,7 +408,7 @@ from ..schemas import acme as acme_schema
logger = logging.getLogger(__name__)
ENGINE: EngineId = EngineId("acme")
ENGINE: EngineId = "acme"
_RESUME_RE = re.compile(
r"(?im)^\s*`?acme\s+--resume\s+(?P<token>[^`\s]+)`?\s*$"
)
+1 -1
View File
@@ -65,7 +65,7 @@ addopts = ["--cov=takopi", "--cov-report=term-missing", "--cov-fail-under=75"]
testpaths = ["tests"]
[tool.ruff.lint]
extend-select = ["B904", "BLE001", "S110", "RUF043"]
extend-select = ["B", "BLE001", "C4", "PERF", "RUF043", "S110", "SIM", "UP"]
[tool.ty.src]
include = ["src", "tests"]
+2 -5
View File
@@ -256,7 +256,7 @@ def _run_auto_router(
)
lock_token = transport_backend.lock_token(
transport_config=transport_config,
config_path=config_path,
_config_path=config_path,
)
lock_handle = acquire_config_lock(config_path, lock_token)
runtime = spec.to_runtime(config_path=config_path)
@@ -301,10 +301,7 @@ def _default_alias_from_path(path: Path) -> str | None:
def _ensure_projects_table(config: dict, config_path: Path) -> dict:
projects = config.get("projects")
if projects is None:
projects = {}
config["projects"] = projects
projects = config.setdefault("projects", {})
if not isinstance(projects, dict):
raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.")
return projects
+1 -1
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Awaitable, Callable, Iterable
from collections.abc import Awaitable, Callable, Iterable
from watchfiles import awatch
+1 -1
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Iterable
from collections.abc import Iterable
from .backends import EngineBackend
from .config import ConfigError
+3 -3
View File
@@ -11,7 +11,7 @@ from .logging import get_logger
logger = get_logger(__name__)
@dataclass(frozen=True)
@dataclass(frozen=True, slots=True)
class LockInfo:
pid: int | None
token_fingerprint: str | None
@@ -29,7 +29,7 @@ class LockError(RuntimeError):
super().__init__(_format_lock_message(path, state))
@dataclass
@dataclass(slots=True)
class LockHandle:
path: Path
@@ -44,7 +44,7 @@ class LockHandle:
error_type=exc.__class__.__name__,
)
def __enter__(self) -> "LockHandle":
def __enter__(self) -> LockHandle:
return self
def __exit__(self, exc_type, exc, tb) -> None:
+5 -7
View File
@@ -107,9 +107,8 @@ def _redact_value(value: Any, memo: dict[int, Any]) -> Any:
def _redact_event_dict(
logger: Any, method_name: str, event_dict: dict[str, Any]
_logger: Any, _method_name: str, event_dict: dict[str, Any]
) -> dict[str, Any]:
_ = logger, method_name
return _redact_value(event_dict, memo={})
@@ -222,10 +221,7 @@ def setup_logging(
format_value = os.environ.get("TAKOPI_LOG_FORMAT", "console").strip().lower()
color_override = os.environ.get("TAKOPI_LOG_COLOR")
if color_override is None:
is_tty = sys.stdout.isatty()
else:
is_tty = _truthy(color_override)
is_tty = sys.stdout.isatty() if color_override is None else _truthy(color_override)
if format_value == "json":
renderer: Any = structlog.processors.JSONRenderer(default=str)
else:
@@ -242,7 +238,9 @@ def setup_logging(
_log_file_handle = None
if log_file:
try:
_log_file_handle = open(log_file, "a", encoding="utf-8")
_log_file_handle = open( # noqa: SIM115
log_file, "a", encoding="utf-8"
)
except OSError:
_log_file_handle = None
+7 -7
View File
@@ -120,9 +120,12 @@ def format_file_change_title(action: Action, *, command_width: int | None) -> st
was_relativized = relativized != fallback
if was_relativized:
fallback = relativized
if fallback and not (fallback.startswith("`") and fallback.endswith("`")):
if was_relativized or os.sep in fallback or "/" in fallback:
fallback = f"`{fallback}`"
if (
fallback
and not (fallback.startswith("`") and fallback.endswith("`"))
and (was_relativized or os.sep in fallback or "/" in fallback)
):
fallback = f"`{fallback}`"
return f"files: {shorten(fallback, command_width)}"
@@ -247,10 +250,7 @@ class MarkdownFormatter:
def _format_actions(self, state: ProgressState) -> list[str]:
actions = list(state.actions)
if self.max_actions == 0:
actions = []
else:
actions = actions[-self.max_actions :]
actions = [] if self.max_actions == 0 else actions[-self.max_actions :]
return [
format_action_line(
action_state.action,
+7 -7
View File
@@ -3,11 +3,11 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal, TypeAlias
from typing import Any, Literal
EngineId: TypeAlias = str
type EngineId = str
ActionKind: TypeAlias = Literal[
type ActionKind = Literal[
"command",
"tool",
"file_change",
@@ -19,14 +19,14 @@ ActionKind: TypeAlias = Literal[
"telemetry",
]
TakopiEventType: TypeAlias = Literal[
type TakopiEventType = Literal[
"started",
"action",
"completed",
]
ActionPhase: TypeAlias = Literal["started", "updated", "completed"]
ActionLevel: TypeAlias = Literal["debug", "info", "warning", "error"]
type ActionPhase = Literal["started", "updated", "completed"]
type ActionLevel = Literal["debug", "info", "warning", "error"]
@dataclass(frozen=True, slots=True)
@@ -74,4 +74,4 @@ class CompletedEvent:
usage: dict[str, Any] | None = None
TakopiEvent: TypeAlias = StartedEvent | ActionEvent | CompletedEvent
type TakopiEvent = StartedEvent | ActionEvent | CompletedEvent
+4 -8
View File
@@ -1,10 +1,11 @@
from __future__ import annotations
from collections.abc import Iterable, Mapping
from collections.abc import Iterable
from dataclasses import dataclass
from importlib.metadata import EntryPoint, entry_points
import re
from typing import Any, Callable
from typing import Any
from collections.abc import Callable
from .ids import ID_PATTERN, is_valid_id
@@ -80,12 +81,7 @@ def reset_plugin_state() -> None:
def _select_entrypoints(group: str) -> list[EntryPoint]:
eps = entry_points()
if hasattr(eps, "select"):
return list(eps.select(group=group))
if isinstance(eps, Mapping):
return list(eps.get(group, []))
return []
return list(entry_points().select(group=group))
def entrypoint_distribution_name(ep: EntryPoint) -> str | None:
+1 -1
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable
from collections.abc import Callable
from .model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent
+3 -2
View File
@@ -1,7 +1,8 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Literal, TypeAlias
from typing import Literal
from collections.abc import Iterable
from .model import EngineId, ResumeToken
from .runner import Runner
@@ -17,7 +18,7 @@ class RunnerUnavailableError(RuntimeError):
self.issue = issue
EngineStatus: TypeAlias = Literal["ok", "missing_cli", "bad_config", "load_error"]
type EngineStatus = Literal["ok", "missing_cli", "bad_config", "load_error"]
@dataclass(frozen=True, slots=True)
+5 -8
View File
@@ -37,12 +37,9 @@ def _log_runner_event(evt: TakopiEvent) -> None:
def _strip_resume_lines(text: str, *, is_resume_line: Callable[[str], bool]) -> str:
stripped_lines: list[str] = []
for line in text.splitlines():
if is_resume_line(line):
continue
stripped_lines.append(line)
prompt = "\n".join(stripped_lines).strip()
prompt = "\n".join(
line for line in text.splitlines() if not is_resume_line(line)
).strip()
return prompt or "continue"
@@ -83,14 +80,14 @@ class IncomingMessage:
thread_id: ThreadId | None = None
@dataclass(frozen=True)
@dataclass(frozen=True, slots=True)
class ExecBridgeConfig:
transport: Transport
presenter: Presenter
final_notify: bool
@dataclass
@dataclass(slots=True)
class RunningTask:
resume: ResumeToken | None = None
resume_ready: anyio.Event = field(default_factory=anyio.Event)
+5 -50
View File
@@ -14,11 +14,11 @@ from ..logging import get_logger
from ..model import Action, ActionKind, EngineId, ResumeToken, TakopiEvent
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
from ..schemas import claude as claude_schema
from ..utils.paths import relativize_command, relativize_path
from .tool_actions import tool_input_path, tool_kind_and_title
logger = get_logger(__name__)
ENGINE: EngineId = EngineId("claude")
ENGINE: EngineId = "claude"
DEFAULT_ALLOWED_TOOLS = ["Bash", "Read", "Edit", "Write"]
_RESUME_RE = re.compile(
@@ -67,55 +67,10 @@ def _coerce_comma_list(value: Any) -> str | None:
return text or None
def _tool_input_path(tool_input: dict[str, Any]) -> str | None:
for key in ("file_path", "path"):
value = tool_input.get(key)
if isinstance(value, str) and value:
return value
return None
def _tool_kind_and_title(
name: str, tool_input: dict[str, Any]
) -> tuple[ActionKind, str]:
if name in {"Bash", "Shell", "KillShell"}:
command = tool_input.get("command")
display = relativize_command(str(command or name))
return "command", display
if name in {"Edit", "Write", "NotebookEdit", "MultiEdit"}:
path = _tool_input_path(tool_input)
if path:
return "file_change", relativize_path(str(path))
return "file_change", str(name)
if name == "Read":
path = _tool_input_path(tool_input)
if path:
return "tool", f"read: `{relativize_path(str(path))}`"
return "tool", "read"
if name == "Glob":
pattern = tool_input.get("pattern")
if pattern:
return "tool", f"glob: `{pattern}`"
return "tool", "glob"
if name == "Grep":
pattern = tool_input.get("pattern")
if pattern:
return "tool", f"grep: {pattern}"
return "tool", "grep"
if name == "WebSearch":
query = tool_input.get("query")
return "web_search", str(query or "search")
if name == "WebFetch":
url = tool_input.get("url")
return "web_search", str(url or "fetch")
if name in {"TodoWrite", "TodoRead"}:
return "note", "update todos" if name == "TodoWrite" else "read todos"
if name == "AskUserQuestion":
return "note", "ask user"
if name in {"Task", "Agent"}:
desc = tool_input.get("description") or tool_input.get("prompt")
return "subagent", str(desc or name)
return "tool", name
return tool_kind_and_title(name, tool_input, path_keys=("file_path", "path"))
def _tool_action(
@@ -137,7 +92,7 @@ def _tool_action(
detail["parent_tool_use_id"] = parent_tool_use_id
if kind == "file_change":
path = _tool_input_path(tool_input)
path = tool_input_path(tool_input, path_keys=("file_path", "path"))
if path:
detail["changes"] = [{"path": path, "kind": "update"}]
@@ -321,7 +276,7 @@ def translate_claude_event(
return []
@dataclass
@dataclass(slots=True)
class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
engine: EngineId = ENGINE
resume_re: re.Pattern[str] = _RESUME_RE
+1 -1
View File
@@ -18,7 +18,7 @@ from ..utils.paths import relativize_command
logger = get_logger(__name__)
ENGINE: EngineId = EngineId("codex")
ENGINE: EngineId = "codex"
__all__ = [
"ENGINE",
+6 -7
View File
@@ -4,7 +4,6 @@ import re
import uuid
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
from dataclasses import dataclass, replace
from typing import TypeAlias
import anyio
@@ -18,7 +17,7 @@ from ..model import (
)
from ..runner import ResumeTokenMixin, Runner, SessionLockMixin
ENGINE: EngineId = EngineId("mock")
ENGINE: EngineId = "mock"
@dataclass(frozen=True, slots=True)
@@ -52,7 +51,7 @@ class Raise:
error: Exception
ScriptStep: TypeAlias = Emit | Advance | Sleep | Wait | Return | Raise
type ScriptStep = Emit | Advance | Sleep | Wait | Return | Raise
def _resume_token(engine: EngineId, value: str | None) -> ResumeToken:
@@ -108,9 +107,9 @@ class MockRunner(SessionLockMixin, ResumeTokenMixin, Runner):
if (
isinstance(event_out, ActionEvent)
and event_out.phase == "completed"
and event_out.ok is None
):
if event_out.ok is None:
event_out = replace(event_out, ok=True)
event_out = replace(event_out, ok=True)
yield event_out
await anyio.sleep(0)
@@ -187,9 +186,9 @@ class ScriptRunner(MockRunner):
if (
isinstance(event_out, ActionEvent)
and event_out.phase == "completed"
and event_out.ok is None
):
if event_out.ok is None:
event_out = replace(event_out, ok=True)
event_out = replace(event_out, ok=True)
yield event_out
await anyio.sleep(0)
continue
+14 -56
View File
@@ -35,11 +35,12 @@ from ..model import (
)
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
from ..schemas import opencode as opencode_schema
from ..utils.paths import relativize_command, relativize_path
from ..utils.paths import relativize_path
from .tool_actions import tool_input_path, tool_kind_and_title
logger = get_logger(__name__)
ENGINE: EngineId = EngineId("opencode")
ENGINE: EngineId = "opencode"
_RESUME_RE = re.compile(
r"(?im)^\s*`?opencode(?:\s+run)?\s+(?:--session|-s)\s+(?P<token>ses_[A-Za-z0-9]+)`?\s*$"
@@ -79,54 +80,12 @@ def _action_event(
def _tool_kind_and_title(
tool_name: str, tool_input: dict[str, Any]
) -> tuple[ActionKind, str]:
"""Map OpenCode tool names to Takopi action kinds and titles."""
name_lower = tool_name.lower()
if name_lower in {"bash", "shell"}:
command = tool_input.get("command")
display = relativize_command(str(command or tool_name))
return "command", display
if name_lower in {"edit", "write", "multiedit"}:
path = tool_input.get("file_path") or tool_input.get("filePath")
if path:
return "file_change", relativize_path(str(path))
return "file_change", str(tool_name)
if name_lower == "read":
path = tool_input.get("file_path") or tool_input.get("filePath")
if path:
return "tool", f"read: `{relativize_path(str(path))}`"
return "tool", "read"
if name_lower == "glob":
pattern = tool_input.get("pattern")
if pattern:
return "tool", f"glob: `{pattern}`"
return "tool", "glob"
if name_lower == "grep":
pattern = tool_input.get("pattern")
if pattern:
return "tool", f"grep: {pattern}"
return "tool", "grep"
if name_lower in {"websearch", "web_search"}:
query = tool_input.get("query")
return "web_search", str(query or "search")
if name_lower in {"webfetch", "web_fetch"}:
url = tool_input.get("url")
return "web_search", str(url or "fetch")
if name_lower in {"todowrite", "todoread"}:
return "note", "update todos" if "write" in name_lower else "read todos"
if name_lower == "task":
desc = tool_input.get("description") or tool_input.get("prompt")
return "tool", str(desc or tool_name)
return "tool", tool_name
return tool_kind_and_title(
tool_name,
tool_input,
path_keys=("file_path", "filePath"),
task_kind="tool",
)
def _normalize_tool_title(
@@ -137,10 +96,10 @@ def _normalize_tool_title(
if "`" in title:
return title
path = tool_input.get("file_path") or tool_input.get("filePath")
path = tool_input_path(tool_input, path_keys=("file_path", "filePath"))
if isinstance(path, str) and path:
rel_path = relativize_path(path)
if title == path or title == rel_path:
if title in (path, rel_path):
return f"`{rel_path}`"
return title
@@ -190,9 +149,8 @@ def translate_opencode_event(
"""Translate an OpenCode JSON event into Takopi events."""
session_id = event.sessionID
if isinstance(session_id, str) and session_id:
if state.session_id is None:
state.session_id = session_id
if isinstance(session_id, str) and session_id and state.session_id is None:
state.session_id = session_id
match event:
case opencode_schema.StepStart():
@@ -340,7 +298,7 @@ def translate_opencode_event(
return []
@dataclass
@dataclass(slots=True)
class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
"""Runner for OpenCode CLI."""
+9 -31
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
import os
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from datetime import datetime, UTC
from pathlib import Path, PurePath
from typing import Any
from uuid import uuid4
@@ -27,11 +27,12 @@ from ..model import (
)
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
from ..schemas import pi as pi_schema
from ..utils.paths import get_run_base_dir, relativize_command, relativize_path
from ..utils.paths import get_run_base_dir
from .tool_actions import tool_kind_and_title
logger = get_logger(__name__)
ENGINE: EngineId = EngineId("pi")
ENGINE: EngineId = "pi"
_RESUME_RE = re.compile(r"(?im)^\s*`?pi\s+--session\s+(?P<token>.+?)`?\s*$")
@@ -96,32 +97,7 @@ def _tool_kind_and_title(
name: str,
args: dict[str, Any],
) -> tuple[ActionKind, str]:
tool = name.lower()
if tool == "bash":
command = args.get("command")
return "command", relativize_command(str(command or "bash"))
if tool in {"edit", "write"}:
path = args.get("path")
if path:
return "file_change", relativize_path(str(path))
return "file_change", tool
if tool == "read":
path = args.get("path")
if path:
return "tool", f"read: `{relativize_path(str(path))}`"
return "tool", "read"
if tool == "grep":
pattern = args.get("pattern")
return "tool", f"grep: {pattern}" if pattern else "grep"
if tool == "find":
pattern = args.get("pattern")
return "tool", f"find: {pattern}" if pattern else "find"
if tool == "ls":
path = args.get("path")
if path:
return "tool", f"ls: `{relativize_path(str(path))}`"
return "tool", "ls"
return "tool", name
return tool_kind_and_title(name, args, path_keys=("path",))
def _last_assistant_message(messages: Any) -> dict[str, Any] | None:
@@ -418,7 +394,7 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
cwd = get_run_base_dir() or Path.cwd()
session_dir = _default_session_dir(cwd)
session_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now(timezone.utc).isoformat()
timestamp = datetime.now(UTC).isoformat()
safe_timestamp = timestamp.replace(":", "-").replace(".", "-")
token = uuid4().hex
filename = f"{safe_timestamp}_{token}.jsonl"
@@ -442,7 +418,9 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
def _default_session_dir(cwd: PurePath) -> Path:
agent_dir = os.environ.get("PI_CODING_AGENT_DIR")
base = Path(agent_dir).expanduser() if agent_dir else Path.home() / ".pi" / "agent"
safe_path = f"--{str(cwd).lstrip('/\\\\').replace('/', '-').replace('\\', '-').replace(':', '-')}--"
cwd_str = str(cwd).lstrip("/\\")
safe_path_part = cwd_str.translate(str.maketrans({"/": "-", "\\": "-", ":": "-"}))
safe_path = f"--{safe_path_part}--"
return base / "sessions" / safe_path
+90
View File
@@ -0,0 +1,90 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from ..model import ActionKind
from ..utils.paths import relativize_command, relativize_path
def tool_input_path(
tool_input: Mapping[str, Any],
*,
path_keys: Sequence[str],
) -> str | None:
for key in path_keys:
value = tool_input.get(key)
if isinstance(value, str) and value:
return value
return None
def tool_kind_and_title(
tool_name: str,
tool_input: Mapping[str, Any],
*,
path_keys: Sequence[str],
task_kind: ActionKind = "subagent",
) -> tuple[ActionKind, str]:
name_lower = tool_name.lower()
if name_lower in {"bash", "shell", "killshell"}:
command = tool_input.get("command")
display = relativize_command(str(command or tool_name))
return "command", display
if name_lower in {"edit", "write", "notebookedit", "multiedit"}:
path = tool_input_path(tool_input, path_keys=path_keys)
if path:
return "file_change", relativize_path(str(path))
return "file_change", str(tool_name)
if name_lower == "read":
path = tool_input_path(tool_input, path_keys=path_keys)
if path:
return "tool", f"read: `{relativize_path(str(path))}`"
return "tool", "read"
if name_lower == "glob":
pattern = tool_input.get("pattern")
if pattern:
return "tool", f"glob: `{pattern}`"
return "tool", "glob"
if name_lower == "grep":
pattern = tool_input.get("pattern")
if pattern:
return "tool", f"grep: {pattern}"
return "tool", "grep"
if name_lower == "find":
pattern = tool_input.get("pattern")
if pattern:
return "tool", f"find: {pattern}"
return "tool", "find"
if name_lower == "ls":
path = tool_input_path(tool_input, path_keys=path_keys)
if path:
return "tool", f"ls: `{relativize_path(str(path))}`"
return "tool", "ls"
if name_lower in {"websearch", "web_search"}:
query = tool_input.get("query")
return "web_search", str(query or "search")
if name_lower in {"webfetch", "web_fetch"}:
url = tool_input.get("url")
return "web_search", str(url or "fetch")
if name_lower in {"todowrite", "todoread"}:
return "note", "update todos" if "write" in name_lower else "read todos"
if name_lower == "askuserquestion":
return "note", "ask user"
if name_lower in {"task", "agent"}:
desc = tool_input.get("description") or tool_input.get("prompt")
return task_kind, str(desc or tool_name)
return "tool", tool_name
+2 -1
View File
@@ -3,7 +3,8 @@ from __future__ import annotations
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, Mapping
from typing import Any
from collections.abc import Iterable, Mapping
from .backends import EngineBackend
from .config import ConfigError, ProjectsConfig
+17 -2
View File
@@ -2,14 +2,18 @@ from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Protocol
from typing import Any, Protocol
from collections.abc import Awaitable, Callable
import anyio
from .context import RunContext
from .logging import get_logger
from .model import ResumeToken
from .transport import ChannelId, MessageId, ThreadId
logger = get_logger(__name__)
@dataclass(frozen=True, slots=True)
class ThreadJob:
@@ -108,7 +112,18 @@ class ThreadScheduler:
if done is not None and not done.is_set():
await done.wait()
await self._run_job(job)
try:
await self._run_job(job)
except Exception as exc: # noqa: BLE001
logger.exception(
"scheduler.job_failed",
key=key,
tag=job.resume_token.engine,
chat_id=job.chat_id,
user_msg_id=job.user_msg_id,
error=str(exc),
error_type=exc.__class__.__name__,
)
finally:
async with self._lock:
self._active_threads.discard(key)
+5 -5
View File
@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Literal, TypeAlias
from typing import Any, Literal
import msgspec
@@ -36,7 +36,7 @@ class StreamToolResultBlock(
is_error: bool | None = None
StreamContentBlock: TypeAlias = (
type StreamContentBlock = (
StreamTextBlock | StreamThinkingBlock | StreamToolUseBlock | StreamToolResultBlock
)
@@ -164,7 +164,7 @@ class ControlRewindFilesRequest(
user_message_id: str
ControlRequest: TypeAlias = (
type ControlRequest = (
ControlInterruptRequest
| ControlCanUseToolRequest
| ControlInitializeRequest
@@ -196,7 +196,7 @@ class ControlErrorResponse(
error: str
ControlResponse: TypeAlias = ControlSuccessResponse | ControlErrorResponse
type ControlResponse = ControlSuccessResponse | ControlErrorResponse
class StreamControlResponse(
@@ -217,7 +217,7 @@ class StreamControlCancelRequest(
request_id: str | None = None
StreamJsonMessage: TypeAlias = (
type StreamJsonMessage = (
StreamUserMessage
| StreamAssistantMessage
| StreamSystemMessage
+7 -7
View File
@@ -2,27 +2,27 @@ from __future__ import annotations
# Headless JSONL schema derived from tag rust-v0.77.0 (git 112f40e91c12af0f7146d7e03f20283516a8af0b).
from typing import Any, Literal, TypeAlias
from typing import Any, Literal
import msgspec
CommandExecutionStatus: TypeAlias = Literal[
type CommandExecutionStatus = Literal[
"in_progress",
"completed",
"failed",
"declined",
]
PatchApplyStatus: TypeAlias = Literal[
type PatchApplyStatus = Literal[
"in_progress",
"completed",
"failed",
]
PatchChangeKind: TypeAlias = Literal[
type PatchChangeKind = Literal[
"add",
"delete",
"update",
]
McpToolCallStatus: TypeAlias = Literal[
type McpToolCallStatus = Literal[
"in_progress",
"completed",
"failed",
@@ -127,7 +127,7 @@ class TodoListItem(msgspec.Struct, tag="todo_list", kw_only=True):
items: list[TodoItem]
ThreadItem: TypeAlias = (
type ThreadItem = (
AgentMessageItem
| ReasoningItem
| CommandExecutionItem
@@ -151,7 +151,7 @@ class ItemCompleted(msgspec.Struct, tag="item.completed", kw_only=True):
item: ThreadItem
ThreadEvent: TypeAlias = (
type ThreadEvent = (
ThreadStarted
| TurnStarted
| TurnCompleted
+2 -2
View File
@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, TypeAlias
from typing import Any
import msgspec
@@ -42,7 +42,7 @@ class Error(_Event, tag="error"):
message: Any = None
OpenCodeEvent: TypeAlias = StepStart | StepFinish | ToolUse | Text | Error
type OpenCodeEvent = StepStart | StepFinish | ToolUse | Text | Error
_DECODER = msgspec.json.Decoder(OpenCodeEvent)
+2 -2
View File
@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, TypeAlias
from typing import Any
import msgspec
@@ -84,7 +84,7 @@ class AutoRetryEnd(_Event, tag="auto_retry_end"):
finalError: str | None = None
PiEvent: TypeAlias = (
type PiEvent = (
AgentStart
| AgentEnd
| MessageStart
+2 -1
View File
@@ -1,7 +1,8 @@
from __future__ import annotations
from pathlib import Path
from typing import Annotated, Any, ClassVar, Iterable, Literal
from typing import Annotated, Any, ClassVar, Literal
from collections.abc import Iterable
from pydantic import (
BaseModel,
+2 -5
View File
@@ -50,9 +50,7 @@ def _build_startup_message(
notes.append(f"failed to load: {', '.join(failed_engines)}")
if notes:
engine_list = f"{engine_list} ({'; '.join(notes)})"
project_aliases = sorted(
{alias for alias in runtime.project_aliases()}, key=str.lower
)
project_aliases = sorted(set(runtime.project_aliases()), key=str.lower)
project_list = ", ".join(project_aliases) if project_aliases else "none"
return (
f"\N{OCTOPUS} **takopi is ready**\n\n"
@@ -78,8 +76,7 @@ class TelegramBackend(TransportBackend):
def interactive_setup(self, *, force: bool) -> bool:
return interactive_setup(force=force)
def lock_token(self, *, transport_config: object, config_path: Path) -> str | None:
_ = config_path
def lock_token(self, *, transport_config: object, _config_path: Path) -> str | None:
settings = _expect_transport_settings(transport_config)
return settings.bot_token
+1 -1
View File
@@ -111,7 +111,7 @@ def _is_cancelled_label(label: str) -> bool:
return stripped.lower() == "cancelled"
@dataclass(frozen=True)
@dataclass(frozen=True, slots=True)
class TelegramBridgeConfig:
bot: BotClient
runtime: TransportRuntime
File diff suppressed because it is too large Load Diff
+537
View File
@@ -0,0 +1,537 @@
from __future__ import annotations
from typing import Any, Protocol, TypeVar
import httpx
import msgspec
from ..logging import get_logger
from .api_models import Chat, ChatMember, File, ForumTopic, Message, Update, User
logger = get_logger(__name__)
T = TypeVar("T")
class RetryAfter(Exception):
def __init__(self, retry_after: float, description: str | None = None) -> None:
super().__init__(description or f"retry after {retry_after}")
self.retry_after = float(retry_after)
self.description = description
class TelegramRetryAfter(RetryAfter):
pass
def retry_after_from_payload(payload: dict[str, Any]) -> float | None:
params = payload.get("parameters")
if isinstance(params, dict):
retry_after = params.get("retry_after")
if isinstance(retry_after, (int, float)):
return float(retry_after)
return None
class BotClient(Protocol):
async def close(self) -> None: ...
async def get_updates(
self,
offset: int | None,
timeout_s: int = 50,
allowed_updates: list[str] | None = None,
) -> list[Update] | None: ...
async def get_file(self, file_id: str) -> File | None: ...
async def download_file(self, file_path: str) -> bytes | None: ...
async def send_message(
self,
chat_id: int,
text: str,
reply_to_message_id: int | None = None,
disable_notification: bool | None = False,
message_thread_id: int | None = None,
entities: list[dict] | None = None,
parse_mode: str | None = None,
reply_markup: dict[str, Any] | None = None,
*,
replace_message_id: int | None = None,
) -> Message | None: ...
async def send_document(
self,
chat_id: int,
filename: str,
content: bytes,
reply_to_message_id: int | None = None,
message_thread_id: int | None = None,
disable_notification: bool | None = False,
caption: str | None = None,
) -> Message | None: ...
async def edit_message_text(
self,
chat_id: int,
message_id: int,
text: str,
entities: list[dict] | None = None,
parse_mode: str | None = None,
reply_markup: dict[str, Any] | None = None,
*,
wait: bool = True,
) -> Message | None: ...
async def delete_message(
self,
chat_id: int,
message_id: int,
) -> bool: ...
async def set_my_commands(
self,
commands: list[dict[str, Any]],
*,
scope: dict[str, Any] | None = None,
language_code: str | None = None,
) -> bool: ...
async def get_me(self) -> User | None: ...
async def answer_callback_query(
self,
callback_query_id: str,
text: str | None = None,
show_alert: bool | None = None,
) -> bool: ...
async def get_chat(self, chat_id: int) -> Chat | None: ...
async def get_chat_member(
self, chat_id: int, user_id: int
) -> ChatMember | None: ...
async def create_forum_topic(
self,
chat_id: int,
name: str,
) -> ForumTopic | None: ...
async def edit_forum_topic(
self,
chat_id: int,
message_thread_id: int,
name: str,
) -> bool: ...
class HttpBotClient:
def __init__(
self,
token: str,
*,
timeout_s: float = 120,
http_client: httpx.AsyncClient | None = None,
) -> None:
if not token:
raise ValueError("Telegram token is empty")
self._base = f"https://api.telegram.org/bot{token}"
self._file_base = f"https://api.telegram.org/file/bot{token}"
self._http_client = http_client or httpx.AsyncClient(timeout=timeout_s)
self._owns_http_client = http_client is None
async def close(self) -> None:
if self._owns_http_client:
await self._http_client.aclose()
def _parse_telegram_envelope(
self,
*,
method: str,
resp: httpx.Response,
payload: Any,
) -> Any | None:
if not isinstance(payload, dict):
logger.error(
"telegram.invalid_payload",
method=method,
url=str(resp.request.url),
payload=payload,
)
return None
if not payload.get("ok"):
if payload.get("error_code") == 429:
retry_after = retry_after_from_payload(payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method=method,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after)
logger.error(
"telegram.api_error",
method=method,
url=str(resp.request.url),
payload=payload,
)
return None
logger.debug("telegram.response", method=method, payload=payload)
return payload.get("result")
async def _request(
self,
method: str,
*,
json: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
files: dict[str, Any] | None = None,
) -> Any | None:
request_payload = json if json is not None else data
logger.debug("telegram.request", method=method, payload=request_payload)
try:
if json is not None:
resp = await self._http_client.post(f"{self._base}/{method}", json=json)
else:
resp = await self._http_client.post(
f"{self._base}/{method}", data=data, files=files
)
except httpx.HTTPError as exc:
url = getattr(exc.request, "url", None)
logger.error(
"telegram.network_error",
method=method,
url=str(url) if url is not None else None,
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
try:
resp.raise_for_status()
except httpx.HTTPStatusError as exc:
if resp.status_code == 429:
retry_after: float | None = None
try:
response_payload = resp.json()
except Exception: # noqa: BLE001
response_payload = None
if isinstance(response_payload, dict):
retry_after = retry_after_from_payload(response_payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method=method,
status=resp.status_code,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after) from exc
body = resp.text
logger.error(
"telegram.http_error",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(exc),
body=body,
)
return None
try:
response_payload = resp.json()
except Exception as exc: # noqa: BLE001
body = resp.text
logger.error(
"telegram.bad_response",
method=method,
status=resp.status_code,
url=str(resp.request.url),
error=str(exc),
error_type=exc.__class__.__name__,
body=body,
)
return None
return self._parse_telegram_envelope(
method=method,
resp=resp,
payload=response_payload,
)
def _decode_result(
self,
*,
method: str,
payload: Any,
model: type[T],
) -> T | None:
if payload is None:
return None
try:
return msgspec.convert(payload, type=model)
except Exception as exc: # noqa: BLE001
logger.error(
"telegram.decode_error",
method=method,
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
return await self._request(method, json=json_data)
async def _post_form(
self,
method: str,
data: dict[str, Any],
files: dict[str, Any],
) -> Any | None:
return await self._request(method, data=data, files=files)
async def get_updates(
self,
offset: int | None,
timeout_s: int = 50,
allowed_updates: list[str] | None = None,
) -> list[Update] | None:
params: dict[str, Any] = {"timeout": timeout_s}
if offset is not None:
params["offset"] = offset
if allowed_updates is not None:
params["allowed_updates"] = allowed_updates
result = await self._post("getUpdates", params)
if result is None or not isinstance(result, list):
return None
try:
return msgspec.convert(result, type=list[Update])
except Exception as exc: # noqa: BLE001
logger.error(
"telegram.decode_error",
method="getUpdates",
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
async def get_file(self, file_id: str) -> File | None:
result = await self._post("getFile", {"file_id": file_id})
return self._decode_result(method="getFile", payload=result, model=File)
async def download_file(self, file_path: str) -> bytes | None:
url = f"{self._file_base}/{file_path}"
try:
resp = await self._http_client.get(url)
except httpx.HTTPError as exc:
request_url = getattr(exc.request, "url", None)
logger.error(
"telegram.file_network_error",
url=str(request_url) if request_url is not None else None,
error=str(exc),
error_type=exc.__class__.__name__,
)
return None
try:
resp.raise_for_status()
except httpx.HTTPStatusError as exc:
if resp.status_code == 429:
retry_after: float | None = None
try:
response_payload = resp.json()
except Exception: # noqa: BLE001
response_payload = None
if isinstance(response_payload, dict):
retry_after = retry_after_from_payload(response_payload)
retry_after = 5.0 if retry_after is None else retry_after
logger.warning(
"telegram.rate_limited",
method="download_file",
status=resp.status_code,
url=str(resp.request.url),
retry_after=retry_after,
)
raise TelegramRetryAfter(retry_after) from exc
logger.error(
"telegram.file_http_error",
status=resp.status_code,
url=str(resp.request.url),
error=str(exc),
body=resp.text,
)
return None
return resp.content
async def send_message(
self,
chat_id: int,
text: str,
reply_to_message_id: int | None = None,
disable_notification: bool | None = False,
message_thread_id: int | None = None,
entities: list[dict] | None = None,
parse_mode: str | None = None,
reply_markup: dict[str, Any] | None = None,
*,
replace_message_id: int | None = None,
) -> Message | None:
params: dict[str, Any] = {"chat_id": chat_id, "text": text}
if disable_notification is not None:
params["disable_notification"] = disable_notification
if reply_to_message_id is not None:
params["reply_to_message_id"] = reply_to_message_id
if message_thread_id is not None:
params["message_thread_id"] = message_thread_id
if entities is not None:
params["entities"] = entities
if parse_mode is not None:
params["parse_mode"] = parse_mode
if reply_markup is not None:
params["reply_markup"] = reply_markup
result = await self._post("sendMessage", params)
return self._decode_result(method="sendMessage", payload=result, model=Message)
async def send_document(
self,
chat_id: int,
filename: str,
content: bytes,
reply_to_message_id: int | None = None,
message_thread_id: int | None = None,
disable_notification: bool | None = False,
caption: str | None = None,
) -> Message | None:
params: dict[str, Any] = {"chat_id": chat_id}
if disable_notification is not None:
params["disable_notification"] = disable_notification
if reply_to_message_id is not None:
params["reply_to_message_id"] = reply_to_message_id
if message_thread_id is not None:
params["message_thread_id"] = message_thread_id
if caption is not None:
params["caption"] = caption
result = await self._post_form(
"sendDocument",
params,
files={"document": (filename, content)},
)
return self._decode_result(method="sendDocument", payload=result, model=Message)
async def edit_message_text(
self,
chat_id: int,
message_id: int,
text: str,
entities: list[dict] | None = None,
parse_mode: str | None = None,
reply_markup: dict[str, Any] | None = None,
*,
wait: bool = True,
) -> Message | None:
params: dict[str, Any] = {
"chat_id": chat_id,
"message_id": message_id,
"text": text,
}
if entities is not None:
params["entities"] = entities
if parse_mode is not None:
params["parse_mode"] = parse_mode
if reply_markup is not None:
params["reply_markup"] = reply_markup
result = await self._post("editMessageText", params)
return self._decode_result(
method="editMessageText",
payload=result,
model=Message,
)
async def delete_message(
self,
chat_id: int,
message_id: int,
) -> bool:
result = await self._post(
"deleteMessage",
{"chat_id": chat_id, "message_id": message_id},
)
return bool(result)
async def set_my_commands(
self,
commands: list[dict[str, Any]],
*,
scope: dict[str, Any] | None = None,
language_code: str | None = None,
) -> bool:
params: dict[str, Any] = {"commands": commands}
if scope is not None:
params["scope"] = scope
if language_code is not None:
params["language_code"] = language_code
result = await self._post("setMyCommands", params)
return bool(result)
async def get_me(self) -> User | None:
result = await self._post("getMe", {})
return self._decode_result(method="getMe", payload=result, model=User)
async def answer_callback_query(
self,
callback_query_id: str,
text: str | None = None,
show_alert: bool | None = None,
) -> bool:
params: dict[str, Any] = {"callback_query_id": callback_query_id}
if text is not None:
params["text"] = text
if show_alert is not None:
params["show_alert"] = show_alert
result = await self._post("answerCallbackQuery", params)
return bool(result)
async def get_chat(self, chat_id: int) -> Chat | None:
result = await self._post("getChat", {"chat_id": chat_id})
return self._decode_result(method="getChat", payload=result, model=Chat)
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
result = await self._post(
"getChatMember", {"chat_id": chat_id, "user_id": user_id}
)
return self._decode_result(
method="getChatMember",
payload=result,
model=ChatMember,
)
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
result = await self._post(
"createForumTopic", {"chat_id": chat_id, "name": name}
)
return self._decode_result(
method="createForumTopic",
payload=result,
model=ForumTopic,
)
async def edit_forum_topic(
self,
chat_id: int,
message_thread_id: int,
name: str,
) -> bool:
result = await self._post(
"editForumTopic",
{
"chat_id": chat_id,
"message_thread_id": message_thread_id,
"name": name,
},
)
return bool(result)
File diff suppressed because it is too large Load Diff
+12
View File
@@ -0,0 +1,12 @@
from __future__ import annotations
from .cancel import handle_callback_cancel, handle_cancel
from .menu import build_bot_commands
from .parse import is_cancel_command
__all__ = [
"build_bot_commands",
"handle_callback_cancel",
"handle_cancel",
"is_cancel_command",
]
+160
View File
@@ -0,0 +1,160 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from ...context import RunContext
from ...directives import DirectiveError
from ..chat_prefs import ChatPrefsStore
from ..engine_defaults import resolve_engine_for_message
from ..files import split_command_args
from ..topic_state import TopicStateStore
from ..topics import _topic_key
from ..types import TelegramIncomingMessage
from .reply import make_reply
if TYPE_CHECKING:
from ..bridge import TelegramBridgeConfig
AGENT_USAGE = "usage: `/agent`, `/agent set <engine>`, or `/agent clear`"
async def _check_agent_permissions(
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
) -> bool:
reply = make_reply(cfg, msg)
sender_id = msg.sender_id
if sender_id is None:
await reply(text="cannot verify sender for agent defaults.")
return False
is_private = msg.chat_type == "private"
if msg.chat_type is None:
is_private = msg.chat_id > 0
if is_private:
return True
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
if member is None:
await reply(text="failed to verify agent permissions.")
return False
if member.status in {"creator", "administrator"}:
return True
await reply(text="changing default agents is restricted to group admins.")
return False
async def _handle_agent_command(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
chat_prefs: ChatPrefsStore | None,
*,
resolved_scope: str | None = None,
scope_chat_ids: frozenset[int] | None = None,
) -> None:
reply = make_reply(cfg, msg)
tkey = (
_topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
if topic_store is not None
else None
)
tokens = split_command_args(args_text)
action = tokens[0].lower() if tokens else "show"
if action in {"show", ""}:
try:
resolved = cfg.runtime.resolve_message(
text="",
reply_text=msg.reply_to_text,
ambient_context=ambient_context,
chat_id=msg.chat_id,
)
except DirectiveError as exc:
await reply(text=f"error:\n{exc}")
return
selection = await resolve_engine_for_message(
runtime=cfg.runtime,
context=resolved.context,
explicit_engine=None,
chat_id=msg.chat_id,
topic_key=tkey,
topic_store=topic_store,
chat_prefs=chat_prefs,
)
source_labels = {
"directive": "directive",
"topic_default": "topic default",
"chat_default": "chat default",
"project_default": "project default",
"global_default": "global default",
}
agent_line = f"agent: {selection.engine} ({source_labels[selection.source]})"
topic_default = selection.topic_default or "none"
if tkey is None:
topic_default = "none"
if chat_prefs is None:
chat_default = "unavailable"
else:
chat_default = selection.chat_default or "none"
project_default = (
selection.project_default
if selection.project_default is not None
else "none"
)
defaults_line = (
"defaults: "
f"topic: {topic_default}, "
f"chat: {chat_default}, "
f"project: {project_default}, "
f"global: {cfg.runtime.default_engine}"
)
available = ", ".join(cfg.runtime.engine_ids)
available_line = f"available: {available}"
await reply(text="\n\n".join([agent_line, defaults_line, available_line]))
return
if action == "set":
if len(tokens) < 2:
await reply(text=AGENT_USAGE)
return
if not await _check_agent_permissions(cfg, msg):
return
engine = tokens[1].strip().lower()
if engine not in cfg.runtime.engine_ids:
available = ", ".join(cfg.runtime.engine_ids)
await reply(
text=f"unknown engine `{engine}`.\navailable agents: `{available}`",
)
return
if tkey is not None:
if topic_store is None:
await reply(text="topic defaults are unavailable.")
return
await topic_store.set_default_engine(tkey[0], tkey[1], engine)
await reply(text=f"topic default agent set to `{engine}`")
return
if chat_prefs is None:
await reply(text="chat defaults are unavailable (no config path).")
return
await chat_prefs.set_default_engine(msg.chat_id, engine)
await reply(text=f"chat default agent set to `{engine}`")
return
if action == "clear":
if not await _check_agent_permissions(cfg, msg):
return
if tkey is not None:
if topic_store is None:
await reply(text="topic defaults are unavailable.")
return
await topic_store.clear_default_engine(tkey[0], tkey[1])
await reply(text="topic default agent cleared.")
return
if chat_prefs is None:
await reply(text="chat defaults are unavailable (no config path).")
return
await chat_prefs.clear_default_engine(msg.chat_id)
await reply(text="chat default agent cleared.")
return
await reply(text=AGENT_USAGE)
+69
View File
@@ -0,0 +1,69 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from ...logging import get_logger
from ...runner_bridge import RunningTasks
from ...transport import MessageRef
from ..types import TelegramCallbackQuery, TelegramIncomingMessage
from .reply import make_reply
if TYPE_CHECKING:
from ..bridge import TelegramBridgeConfig
logger = get_logger(__name__)
async def handle_cancel(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
running_tasks: RunningTasks,
) -> None:
reply = make_reply(cfg, msg)
chat_id = msg.chat_id
reply_id = msg.reply_to_message_id
if reply_id is None:
if msg.reply_to_text:
await reply(text="nothing is currently running for that message.")
return
await reply(text="reply to the progress message to cancel.")
return
progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id)
running_task = running_tasks.get(progress_ref)
if running_task is None:
await reply(text="nothing is currently running for that message.")
return
logger.info(
"cancel.requested",
chat_id=chat_id,
progress_message_id=reply_id,
)
running_task.cancel_requested.set()
async def handle_callback_cancel(
cfg: TelegramBridgeConfig,
query: TelegramCallbackQuery,
running_tasks: RunningTasks,
) -> None:
progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id)
running_task = running_tasks.get(progress_ref)
if running_task is None:
await cfg.bot.answer_callback_query(
callback_query_id=query.callback_query_id,
text="nothing is currently running for that message.",
)
return
logger.info(
"cancel.requested",
chat_id=query.chat_id,
progress_message_id=query.message_id,
)
running_task.cancel_requested.set()
await cfg.bot.answer_callback_query(
callback_query_id=query.callback_query_id,
text="cancelling...",
)
+107
View File
@@ -0,0 +1,107 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING
import anyio
from ...commands import CommandContext, get_command
from ...config import ConfigError
from ...logging import get_logger
from ...model import EngineId, ResumeToken
from ...runner_bridge import RunningTasks
from ...scheduler import ThreadScheduler
from ...transport import MessageRef
from ..files import split_command_args
from ..types import TelegramIncomingMessage
from .executor import _TelegramCommandExecutor
if TYPE_CHECKING:
from ..bridge import TelegramBridgeConfig
logger = get_logger(__name__)
async def _dispatch_command(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
text: str,
command_id: str,
args_text: str,
running_tasks: RunningTasks,
scheduler: ThreadScheduler,
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
stateful_mode: bool,
default_engine_override: EngineId | None,
) -> None:
allowlist = cfg.runtime.allowlist
chat_id = msg.chat_id
user_msg_id = msg.message_id
reply_ref = (
MessageRef(
channel_id=chat_id,
message_id=msg.reply_to_message_id,
thread_id=msg.thread_id,
)
if msg.reply_to_message_id is not None
else None
)
executor = _TelegramCommandExecutor(
exec_cfg=cfg.exec_cfg,
runtime=cfg.runtime,
running_tasks=running_tasks,
scheduler=scheduler,
on_thread_known=on_thread_known,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=msg.thread_id,
show_resume_line=cfg.show_resume_line,
stateful_mode=stateful_mode,
default_engine_override=default_engine_override,
)
message_ref = MessageRef(
channel_id=chat_id,
message_id=user_msg_id,
thread_id=msg.thread_id,
sender_id=msg.sender_id,
raw=msg.raw,
)
try:
backend = get_command(command_id, allowlist=allowlist, required=False)
except ConfigError as exc:
await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True)
return
if backend is None:
return
try:
plugin_config = cfg.runtime.plugin_config(command_id)
except ConfigError as exc:
await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True)
return
ctx = CommandContext(
command=command_id,
text=text,
args_text=args_text,
args=split_command_args(args_text),
message=message_ref,
reply_to=reply_ref,
reply_text=msg.reply_to_text,
config_path=cfg.runtime.config_path,
plugin_config=plugin_config,
runtime=cfg.runtime,
executor=executor,
)
try:
result = await backend.handle(ctx)
except Exception as exc:
logger.exception(
"command.failed",
command=command_id,
error=str(exc),
error_type=exc.__class__.__name__,
)
await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True)
return
if result is not None:
reply_to = message_ref if result.reply_to is None else result.reply_to
await executor.send(result.text, reply_to=reply_to, notify=result.notify)
+382
View File
@@ -0,0 +1,382 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
from dataclasses import dataclass
from functools import partial
from typing import cast
import anyio
from ...commands import CommandExecutor, RunMode, RunRequest, RunResult
from ...config import ConfigError
from ...context import RunContext
from ...logging import bind_run_context, clear_context, get_logger
from ...model import EngineId, ResumeToken, TakopiEvent
from ...progress import ProgressTracker
from ...router import RunnerUnavailableError
from ...runner import Runner
from ...runner_bridge import (
ExecBridgeConfig,
IncomingMessage as RunnerIncomingMessage,
RunningTasks,
handle_message,
)
from ...scheduler import ThreadScheduler
from ...transport import MessageRef, RenderedMessage, SendOptions
from ...transport_runtime import TransportRuntime
from ...utils.paths import reset_run_base_dir, set_run_base_dir
from ..bridge import send_plain
logger = get_logger(__name__)
@dataclass(slots=True)
class _ResumeLineProxy:
runner: Runner
@property
def engine(self) -> str:
return self.runner.engine
def is_resume_line(self, line: str) -> bool:
return self.runner.is_resume_line(line)
def format_resume(self, _: ResumeToken) -> str:
return ""
def extract_resume(self, text: str | None) -> ResumeToken | None:
return self.runner.extract_resume(text)
def run(
self, prompt: str, resume: ResumeToken | None
) -> AsyncIterator[TakopiEvent]:
return self.runner.run(prompt, resume)
def _should_show_resume_line(
*,
show_resume_line: bool,
stateful_mode: bool,
context: RunContext | None,
) -> bool:
if show_resume_line or not stateful_mode:
return True
return context is None or context.project is None
async def _send_runner_unavailable(
exec_cfg: ExecBridgeConfig,
*,
chat_id: int,
user_msg_id: int,
resume_token: ResumeToken | None,
runner: Runner,
reason: str,
thread_id: int | None = None,
) -> None:
tracker = ProgressTracker(engine=runner.engine)
tracker.set_resume(resume_token)
state = tracker.snapshot(resume_formatter=runner.format_resume)
message = exec_cfg.presenter.render_final(
state,
elapsed_s=0.0,
status="error",
answer=f"error:\n{reason}",
)
reply_to = MessageRef(channel_id=chat_id, message_id=user_msg_id)
await exec_cfg.transport.send(
channel_id=chat_id,
message=message,
options=SendOptions(reply_to=reply_to, notify=True, thread_id=thread_id),
)
async def _run_engine(
*,
exec_cfg: ExecBridgeConfig,
runtime: TransportRuntime,
running_tasks: RunningTasks | None,
chat_id: int,
user_msg_id: int,
text: str,
resume_token: ResumeToken | None,
context: RunContext | None,
reply_ref: MessageRef | None = None,
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
| None = None,
engine_override: EngineId | None = None,
thread_id: int | None = None,
show_resume_line: bool = True,
) -> None:
reply = partial(
send_plain,
exec_cfg.transport,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=thread_id,
)
try:
try:
entry = runtime.resolve_runner(
resume_token=resume_token,
engine_override=engine_override,
)
except RunnerUnavailableError as exc:
await reply(text=f"error:\n{exc}")
return
runner: Runner = entry.runner
if not show_resume_line:
runner = cast(Runner, _ResumeLineProxy(runner))
if not entry.available:
reason = entry.issue or "engine unavailable"
await _send_runner_unavailable(
exec_cfg,
chat_id=chat_id,
user_msg_id=user_msg_id,
resume_token=resume_token,
runner=runner,
reason=reason,
thread_id=thread_id,
)
return
try:
cwd = runtime.resolve_run_cwd(context)
except ConfigError as exc:
await reply(text=f"error:\n{exc}")
return
run_base_token = set_run_base_dir(cwd)
try:
run_fields = {
"chat_id": chat_id,
"user_msg_id": user_msg_id,
"engine": runner.engine,
"resume": resume_token.value if resume_token else None,
}
if context is not None:
run_fields["project"] = context.project
run_fields["branch"] = context.branch
if cwd is not None:
run_fields["cwd"] = str(cwd)
bind_run_context(**run_fields)
context_line = runtime.format_context_line(context)
incoming = RunnerIncomingMessage(
channel_id=chat_id,
message_id=user_msg_id,
text=text,
reply_to=reply_ref,
thread_id=thread_id,
)
await handle_message(
exec_cfg,
runner=runner,
incoming=incoming,
resume_token=resume_token,
context=context,
context_line=context_line,
strip_resume_line=runtime.is_resume_line,
running_tasks=running_tasks,
on_thread_known=on_thread_known,
)
finally:
reset_run_base_dir(run_base_token)
except Exception as exc:
logger.exception(
"handle.worker_failed",
error=str(exc),
error_type=exc.__class__.__name__,
)
finally:
clear_context()
class _CaptureTransport:
def __init__(self) -> None:
self._next_id = 1
self.last_message: RenderedMessage | None = None
async def send(
self,
*,
channel_id: int | str,
message: RenderedMessage,
options: SendOptions | None = None,
) -> MessageRef:
thread_id = options.thread_id if options is not None else None
ref = MessageRef(channel_id=channel_id, message_id=self._next_id)
self._next_id += 1
self.last_message = message
return MessageRef(
channel_id=ref.channel_id,
message_id=ref.message_id,
thread_id=thread_id,
)
async def edit(
self, *, ref: MessageRef, message: RenderedMessage, wait: bool = True
) -> MessageRef:
self.last_message = message
return ref
async def delete(self, *, ref: MessageRef) -> bool:
return True
async def close(self) -> None:
return None
class _TelegramCommandExecutor(CommandExecutor):
def __init__(
self,
*,
exec_cfg: ExecBridgeConfig,
runtime: TransportRuntime,
running_tasks: RunningTasks,
scheduler: ThreadScheduler,
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
chat_id: int,
user_msg_id: int,
thread_id: int | None,
show_resume_line: bool,
stateful_mode: bool,
default_engine_override: EngineId | None,
) -> None:
self._exec_cfg = exec_cfg
self._runtime = runtime
self._running_tasks = running_tasks
self._scheduler = scheduler
self._on_thread_known = on_thread_known
self._chat_id = chat_id
self._user_msg_id = user_msg_id
self._thread_id = thread_id
self._show_resume_line = show_resume_line
self._stateful_mode = stateful_mode
self._default_engine_override = default_engine_override
self._reply_ref = MessageRef(
channel_id=chat_id,
message_id=user_msg_id,
thread_id=thread_id,
)
def _apply_default_context(self, request: RunRequest) -> RunRequest:
if request.context is not None:
return request
context = self._runtime.default_context_for_chat(self._chat_id)
if context is None:
return request
return RunRequest(
prompt=request.prompt,
engine=request.engine,
context=context,
)
def _apply_default_engine(self, request: RunRequest) -> RunRequest:
if request.engine is not None or self._default_engine_override is None:
return request
return RunRequest(
prompt=request.prompt,
engine=self._default_engine_override,
context=request.context,
)
async def send(
self,
message: RenderedMessage | str,
*,
reply_to: MessageRef | None = None,
notify: bool = True,
) -> MessageRef | None:
rendered = (
message
if isinstance(message, RenderedMessage)
else RenderedMessage(text=message)
)
reply_ref = self._reply_ref if reply_to is None else reply_to
return await self._exec_cfg.transport.send(
channel_id=self._chat_id,
message=rendered,
options=SendOptions(
reply_to=reply_ref,
notify=notify,
thread_id=self._thread_id,
),
)
async def run_one(
self, request: RunRequest, *, mode: RunMode = "emit"
) -> RunResult:
request = self._apply_default_context(request)
request = self._apply_default_engine(request)
effective_show_resume_line = _should_show_resume_line(
show_resume_line=self._show_resume_line,
stateful_mode=self._stateful_mode,
context=request.context,
)
engine = self._runtime.resolve_engine(
engine_override=request.engine,
context=request.context,
)
on_thread_known = (
self._scheduler.note_thread_known
if self._on_thread_known is None
else self._on_thread_known
)
if mode == "capture":
capture = _CaptureTransport()
exec_cfg = ExecBridgeConfig(
transport=capture,
presenter=self._exec_cfg.presenter,
final_notify=False,
)
await _run_engine(
exec_cfg=exec_cfg,
runtime=self._runtime,
running_tasks={},
chat_id=self._chat_id,
user_msg_id=self._user_msg_id,
text=request.prompt,
resume_token=None,
context=request.context,
reply_ref=self._reply_ref,
on_thread_known=on_thread_known,
engine_override=engine,
thread_id=self._thread_id,
show_resume_line=effective_show_resume_line,
)
return RunResult(engine=engine, message=capture.last_message)
await _run_engine(
exec_cfg=self._exec_cfg,
runtime=self._runtime,
running_tasks=self._running_tasks,
chat_id=self._chat_id,
user_msg_id=self._user_msg_id,
text=request.prompt,
resume_token=None,
context=request.context,
reply_ref=self._reply_ref,
on_thread_known=on_thread_known,
engine_override=engine,
thread_id=self._thread_id,
show_resume_line=effective_show_resume_line,
)
return RunResult(engine=engine, message=None)
async def run_many(
self,
requests: Sequence[RunRequest],
*,
mode: RunMode = "emit",
parallel: bool = False,
) -> list[RunResult]:
if not parallel:
return [await self.run_one(request, mode=mode) for request in requests]
results: list[RunResult | None] = [None] * len(requests)
async with anyio.create_task_group() as tg:
async def run_idx(idx: int, request: RunRequest) -> None:
results[idx] = await self.run_one(request, mode=mode)
for idx, request in enumerate(requests):
tg.start_soon(run_idx, idx, request)
return [result for result in results if result is not None]
@@ -0,0 +1,589 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from ...config import ConfigError
from ...context import RunContext
from ...directives import DirectiveError
from ...transport_runtime import ResolvedMessage
from ..context import _format_context
from ..files import (
default_upload_name,
default_upload_path,
deny_reason,
format_bytes,
normalize_relative_path,
parse_file_command,
parse_file_prompt,
resolve_path_within_root,
write_bytes_atomic,
ZipTooLargeError,
zip_directory,
)
from ..topic_state import TopicStateStore
from ..topics import _maybe_update_topic_context, _topic_key
from ..types import TelegramDocument, TelegramIncomingMessage
from .reply import make_reply
if TYPE_CHECKING:
from ..bridge import TelegramBridgeConfig
FILE_PUT_USAGE = "usage: `/file put <path>`"
FILE_GET_USAGE = "usage: `/file get <path>`"
@dataclass(slots=True)
class _FilePutPlan:
resolved: ResolvedMessage
run_root: Path
path_value: str | None
force: bool
@dataclass(slots=True)
class _FilePutResult:
name: str
rel_path: Path | None
size: int | None
error: str | None
@dataclass(slots=True)
class _SavedFilePut:
context: RunContext | None
rel_path: Path
size: int
@dataclass(slots=True)
class _SavedFilePutGroup:
context: RunContext | None
base_dir: Path | None
saved: list[_FilePutResult]
failed: list[_FilePutResult]
def resolve_file_put_paths(
plan: _FilePutPlan,
*,
cfg: TelegramBridgeConfig,
require_dir: bool,
) -> tuple[Path | None, Path | None, str | None]:
path_value = plan.path_value
if not path_value:
return None, None, None
if require_dir or path_value.endswith("/"):
base_dir = normalize_relative_path(path_value)
if base_dir is None:
return None, None, "invalid upload path."
deny_rule = deny_reason(base_dir, cfg.files.deny_globs)
if deny_rule is not None:
return None, None, f"path denied by rule: {deny_rule}"
base_target = resolve_path_within_root(plan.run_root, base_dir)
if base_target is None:
return None, None, "upload path escapes the repo root."
if base_target.exists() and not base_target.is_dir():
return None, None, "upload path is a file."
return base_dir, None, None
rel_path = normalize_relative_path(path_value)
if rel_path is None:
return None, None, "invalid upload path."
return None, rel_path, None
async def _check_file_permissions(
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
) -> bool:
reply = make_reply(cfg, msg)
sender_id = msg.sender_id
if sender_id is None:
await reply(text="cannot verify sender for file transfer.")
return False
if cfg.files.allowed_user_ids:
if sender_id not in cfg.files.allowed_user_ids:
await reply(text="file transfer is not allowed for this user.")
return False
return True
is_private = msg.chat_type == "private"
if msg.chat_type is None:
is_private = msg.chat_id > 0
if is_private:
return True
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
if member is None:
await reply(text="failed to verify file transfer permissions.")
return False
if member.status in {"creator", "administrator"}:
return True
await reply(text="file transfer is restricted to group admins.")
return False
async def _prepare_file_put_plan(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
) -> _FilePutPlan | None:
reply = make_reply(cfg, msg)
if not await _check_file_permissions(cfg, msg):
return None
try:
resolved = cfg.runtime.resolve_message(
text=args_text,
reply_text=msg.reply_to_text,
ambient_context=ambient_context,
chat_id=msg.chat_id,
)
except DirectiveError as exc:
await reply(text=f"error:\n{exc}")
return None
topic_key = _topic_key(msg, cfg) if topic_store is not None else None
await _maybe_update_topic_context(
cfg=cfg,
topic_store=topic_store,
topic_key=topic_key,
context=resolved.context,
context_source=resolved.context_source,
)
if resolved.context is None or resolved.context.project is None:
await reply(text="no project context available for file upload.")
return None
try:
run_root = cfg.runtime.resolve_run_cwd(resolved.context)
except ConfigError as exc:
await reply(text=f"error:\n{exc}")
return None
if run_root is None:
await reply(text="no project context available for file upload.")
return None
path_value, force, error = parse_file_prompt(resolved.prompt, allow_empty=True)
if error is not None:
await reply(text=error)
return None
return _FilePutPlan(
resolved=resolved,
run_root=run_root,
path_value=path_value,
force=force,
)
def _format_file_put_failures(failed: Sequence[_FilePutResult]) -> str | None:
if not failed:
return None
errors = ", ".join(
f"`{item.name}` ({item.error})" for item in failed if item.error is not None
)
if not errors:
return None
return f"failed: {errors}"
async def _save_document_payload(
cfg: TelegramBridgeConfig,
*,
document: TelegramDocument,
run_root: Path,
rel_path: Path | None,
base_dir: Path | None,
force: bool,
) -> _FilePutResult:
name = default_upload_name(document.file_name, None)
if (
document.file_size is not None
and document.file_size > cfg.files.max_upload_bytes
):
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error="file is too large to upload.",
)
file_info = await cfg.bot.get_file(document.file_id)
if file_info is None:
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error="failed to fetch file metadata.",
)
file_path = file_info.file_path
name = default_upload_name(document.file_name, file_path)
resolved_path = rel_path
if resolved_path is None:
if base_dir is None:
resolved_path = default_upload_path(
cfg.files.uploads_dir, document.file_name, file_path
)
else:
resolved_path = base_dir / name
deny_rule = deny_reason(resolved_path, cfg.files.deny_globs)
if deny_rule is not None:
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error=f"path denied by rule: {deny_rule}",
)
target = resolve_path_within_root(run_root, resolved_path)
if target is None:
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error="upload path escapes the repo root.",
)
if target.exists():
if target.is_dir():
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error="upload target is a directory.",
)
if not force:
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error="file already exists; use --force to overwrite.",
)
payload = await cfg.bot.download_file(file_path)
if payload is None:
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error="failed to download file.",
)
if len(payload) > cfg.files.max_upload_bytes:
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error="file is too large to upload.",
)
try:
write_bytes_atomic(target, payload)
except OSError as exc:
return _FilePutResult(
name=name,
rel_path=None,
size=None,
error=f"failed to write file: {exc}",
)
return _FilePutResult(
name=name,
rel_path=resolved_path,
size=len(payload),
error=None,
)
async def _handle_file_command(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
) -> None:
reply = make_reply(cfg, msg)
command, rest, error = parse_file_command(args_text)
if error is not None:
await reply(text=error)
return
if command == "put":
await _handle_file_put(cfg, msg, rest, ambient_context, topic_store)
else:
await _handle_file_get(cfg, msg, rest, ambient_context, topic_store)
async def _handle_file_put_default(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
) -> None:
await _handle_file_put(cfg, msg, "", ambient_context, topic_store)
async def _save_file_put(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
) -> _SavedFilePut | None:
reply = make_reply(cfg, msg)
document = msg.document
if document is None:
await reply(text=FILE_PUT_USAGE)
return None
plan = await _prepare_file_put_plan(
cfg,
msg,
args_text,
ambient_context,
topic_store,
)
if plan is None:
return None
base_dir, rel_path, error = resolve_file_put_paths(
plan,
cfg=cfg,
require_dir=False,
)
if error is not None:
await reply(text=error)
return None
result = await _save_document_payload(
cfg,
document=document,
run_root=plan.run_root,
rel_path=rel_path,
base_dir=base_dir,
force=plan.force,
)
if result.error is not None:
await reply(text=result.error)
return None
if result.rel_path is None or result.size is None:
await reply(text="failed to save file.")
return None
return _SavedFilePut(
context=plan.resolved.context,
rel_path=result.rel_path,
size=result.size,
)
async def _handle_file_put(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
) -> None:
reply = make_reply(cfg, msg)
saved = await _save_file_put(
cfg,
msg,
args_text,
ambient_context,
topic_store,
)
if saved is None:
return
context_label = _format_context(cfg.runtime, saved.context)
await reply(
text=(
f"saved `{saved.rel_path.as_posix()}` "
f"in `{context_label}` ({format_bytes(saved.size)})"
),
)
async def _handle_file_put_group(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
messages: Sequence[TelegramIncomingMessage],
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
) -> None:
reply = make_reply(cfg, msg)
saved_group = await _save_file_put_group(
cfg,
msg,
args_text,
messages,
ambient_context,
topic_store,
)
if saved_group is None:
return
context_label = _format_context(cfg.runtime, saved_group.context)
total_bytes = sum(item.size or 0 for item in saved_group.saved)
dir_label: Path | None = saved_group.base_dir
if dir_label is None and saved_group.saved:
first_path = saved_group.saved[0].rel_path
if first_path is not None:
dir_label = first_path.parent
if saved_group.saved:
saved_names = ", ".join(f"`{item.name}`" for item in saved_group.saved)
if dir_label is not None:
dir_text = dir_label.as_posix()
if not dir_text.endswith("/"):
dir_text = f"{dir_text}/"
text = (
f"saved {saved_names} to `{dir_text}` "
f"in `{context_label}` ({format_bytes(total_bytes)})"
)
else:
text = (
f"saved {saved_names} in `{context_label}` "
f"({format_bytes(total_bytes)})"
)
else:
text = "failed to upload files."
failure_text = _format_file_put_failures(saved_group.failed)
if failure_text is not None:
text = f"{text}\n\n{failure_text}"
await reply(text=text)
async def _save_file_put_group(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
messages: Sequence[TelegramIncomingMessage],
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
) -> _SavedFilePutGroup | None:
reply = make_reply(cfg, msg)
documents = [item.document for item in messages if item.document is not None]
if not documents:
await reply(text=FILE_PUT_USAGE)
return None
plan = await _prepare_file_put_plan(
cfg,
msg,
args_text,
ambient_context,
topic_store,
)
if plan is None:
return None
base_dir, _, error = resolve_file_put_paths(
plan,
cfg=cfg,
require_dir=True,
)
if error is not None:
await reply(text=error)
return None
saved: list[_FilePutResult] = []
failed: list[_FilePutResult] = []
for document in documents:
result = await _save_document_payload(
cfg,
document=document,
run_root=plan.run_root,
rel_path=None,
base_dir=base_dir,
force=plan.force,
)
if result.error is None:
saved.append(result)
else:
failed.append(result)
return _SavedFilePutGroup(
context=plan.resolved.context,
base_dir=base_dir,
saved=saved,
failed=failed,
)
async def _handle_file_get(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
) -> None:
reply = make_reply(cfg, msg)
if not await _check_file_permissions(cfg, msg):
return
try:
resolved = cfg.runtime.resolve_message(
text=args_text,
reply_text=msg.reply_to_text,
ambient_context=ambient_context,
chat_id=msg.chat_id,
)
except DirectiveError as exc:
await reply(text=f"error:\n{exc}")
return
topic_key = _topic_key(msg, cfg) if topic_store is not None else None
await _maybe_update_topic_context(
cfg=cfg,
topic_store=topic_store,
topic_key=topic_key,
context=resolved.context,
context_source=resolved.context_source,
)
if resolved.context is None or resolved.context.project is None:
await reply(text="no project context available for file download.")
return
try:
run_root = cfg.runtime.resolve_run_cwd(resolved.context)
except ConfigError as exc:
await reply(text=f"error:\n{exc}")
return
if run_root is None:
await reply(text="no project context available for file download.")
return
path_value = resolved.prompt
if not path_value.strip():
await reply(text=FILE_GET_USAGE)
return
rel_path = normalize_relative_path(path_value)
if rel_path is None:
await reply(text="invalid download path.")
return
deny_rule = deny_reason(rel_path, cfg.files.deny_globs)
if deny_rule is not None:
await reply(text=f"path denied by rule: {deny_rule}")
return
target = resolve_path_within_root(run_root, rel_path)
if target is None:
await reply(text="download path escapes the repo root.")
return
if not target.exists():
await reply(text="file does not exist.")
return
if target.is_dir():
try:
payload = zip_directory(
run_root,
rel_path,
cfg.files.deny_globs,
max_bytes=cfg.files.max_download_bytes,
)
except ZipTooLargeError:
await reply(text="file is too large to send.")
return
except OSError as exc:
await reply(text=f"failed to read directory: {exc}")
return
filename = f"{rel_path.name or 'archive'}.zip"
else:
try:
size = target.stat().st_size
if size > cfg.files.max_download_bytes:
await reply(text="file is too large to send.")
return
payload = target.read_bytes()
except OSError as exc:
await reply(text=f"failed to read file: {exc}")
return
filename = target.name
if len(payload) > cfg.files.max_download_bytes:
await reply(text="file is too large to send.")
return
sent = await cfg.bot.send_document(
chat_id=msg.chat_id,
filename=filename,
content=payload,
reply_to_message_id=msg.message_id,
message_thread_id=msg.thread_id,
)
if sent is None:
await reply(text="failed to send file.")
return
+143
View File
@@ -0,0 +1,143 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable, Sequence
from typing import TYPE_CHECKING
from ...context import RunContext
from ...directives import DirectiveError
from ...transport_runtime import ResolvedMessage
from ..context import _merge_topic_context
from ..files import parse_file_command
from ..topic_state import TopicStateStore
from ..topics import _topic_key, _topics_chat_project
from ..types import TelegramIncomingMessage
from .file_transfer import (
FILE_PUT_USAGE,
_format_file_put_failures,
_handle_file_put_group,
_save_file_put_group,
)
from .parse import _parse_slash_command
from .reply import make_reply
if TYPE_CHECKING:
from ..bridge import TelegramBridgeConfig
async def _handle_media_group(
cfg: TelegramBridgeConfig,
messages: Sequence[TelegramIncomingMessage],
topic_store: TopicStateStore | None,
run_prompt: Callable[
[TelegramIncomingMessage, str, ResolvedMessage], Awaitable[None]
]
| None = None,
resolve_prompt: Callable[
[TelegramIncomingMessage, str, RunContext | None],
Awaitable[ResolvedMessage | None],
]
| None = None,
) -> None:
if not messages:
return
ordered = sorted(messages, key=lambda item: item.message_id)
command_msg = next(
(item for item in ordered if item.text.strip()),
ordered[0],
)
reply = make_reply(cfg, command_msg)
topic_key = _topic_key(command_msg, cfg) if topic_store is not None else None
chat_project = _topics_chat_project(cfg, command_msg.chat_id)
bound_context = (
await topic_store.get_context(*topic_key)
if topic_store is not None and topic_key is not None
else None
)
ambient_context = _merge_topic_context(
chat_project=chat_project, bound=bound_context
)
command_id, args_text = _parse_slash_command(command_msg.text)
if command_id == "file":
command, rest, error = parse_file_command(args_text)
if error is not None:
await reply(text=error)
return
if command == "put":
await _handle_file_put_group(
cfg,
command_msg,
rest,
ordered,
ambient_context,
topic_store,
)
return
if cfg.files.enabled and cfg.files.auto_put:
caption_text = command_msg.text.strip()
if cfg.files.auto_put_mode == "prompt" and caption_text:
if resolve_prompt is None:
try:
resolved = cfg.runtime.resolve_message(
text=caption_text,
reply_text=command_msg.reply_to_text,
ambient_context=ambient_context,
chat_id=command_msg.chat_id,
)
except DirectiveError as exc:
await reply(text=f"error:\n{exc}")
return
else:
resolved = await resolve_prompt(
command_msg, caption_text, ambient_context
)
if resolved is None:
return
saved_group = await _save_file_put_group(
cfg,
command_msg,
"",
ordered,
resolved.context,
topic_store,
)
if saved_group is None:
return
if not saved_group.saved:
failure_text = _format_file_put_failures(saved_group.failed)
text = "failed to upload files."
if failure_text is not None:
text = f"{text}\n\n{failure_text}"
await reply(text=text)
return
if saved_group.failed:
failure_text = _format_file_put_failures(saved_group.failed)
if failure_text is not None:
await reply(text=f"some files failed to upload.\n\n{failure_text}")
if run_prompt is None:
await reply(text=FILE_PUT_USAGE)
return
paths = [
item.rel_path.as_posix()
for item in saved_group.saved
if item.rel_path is not None
]
files_text = "\n".join(f"- {path}" for path in paths)
prompt_base = resolved.prompt
annotation = f"[uploaded files]\n{files_text}"
if prompt_base and prompt_base.strip():
prompt = f"{prompt_base}\n\n{annotation}"
else:
prompt = annotation
await run_prompt(command_msg, prompt, resolved)
return
if not caption_text:
await _handle_file_put_group(
cfg,
command_msg,
"",
ordered,
ambient_context,
topic_store,
)
return
await reply(text=FILE_PUT_USAGE)
+114
View File
@@ -0,0 +1,114 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from ...commands import get_command
from ...config import ConfigError
from ...ids import RESERVED_COMMAND_IDS, is_valid_id
from ...logging import get_logger
from ...plugins import COMMAND_GROUP, list_entrypoints
from ...transport_runtime import TransportRuntime
if TYPE_CHECKING:
from ..bridge import TelegramBridgeConfig
logger = get_logger(__name__)
_MAX_BOT_COMMANDS = 100
def build_bot_commands(
runtime: TransportRuntime, *, include_file: bool = True
) -> list[dict[str, str]]:
commands: list[dict[str, str]] = []
seen: set[str] = set()
for engine_id in runtime.available_engine_ids():
cmd = engine_id.lower()
if cmd in seen:
continue
commands.append({"command": cmd, "description": f"use agent: {cmd}"})
seen.add(cmd)
for alias in runtime.project_aliases():
cmd = alias.lower()
if cmd in seen:
continue
if not is_valid_id(cmd):
logger.debug(
"startup.command_menu.skip_project",
alias=alias,
)
continue
commands.append({"command": cmd, "description": f"work on: {cmd}"})
seen.add(cmd)
allowlist = runtime.allowlist
for ep in list_entrypoints(
COMMAND_GROUP,
allowlist=allowlist,
reserved_ids=RESERVED_COMMAND_IDS,
):
try:
backend = get_command(ep.name, allowlist=allowlist)
except ConfigError as exc:
logger.info(
"startup.command_menu.skip_command",
command=ep.name,
error=str(exc),
)
continue
cmd = backend.id.lower()
if cmd in seen:
continue
if not is_valid_id(cmd):
logger.debug(
"startup.command_menu.skip_command_id",
command=cmd,
)
continue
description = backend.description or f"command: {cmd}"
commands.append({"command": cmd, "description": description})
seen.add(cmd)
if include_file and "file" not in seen:
commands.append({"command": "file", "description": "upload or fetch files"})
seen.add("file")
if "cancel" not in seen:
commands.append({"command": "cancel", "description": "cancel run"})
if len(commands) > _MAX_BOT_COMMANDS:
logger.warning(
"startup.command_menu.too_many",
count=len(commands),
limit=_MAX_BOT_COMMANDS,
)
commands = commands[:_MAX_BOT_COMMANDS]
if not any(cmd["command"] == "cancel" for cmd in commands):
commands[-1] = {"command": "cancel", "description": "cancel run"}
return commands
def _reserved_commands(runtime: TransportRuntime) -> set[str]:
return {
*{engine.lower() for engine in runtime.engine_ids},
*{alias.lower() for alias in runtime.project_aliases()},
*RESERVED_COMMAND_IDS,
}
async def _set_command_menu(cfg: TelegramBridgeConfig) -> None:
commands = build_bot_commands(cfg.runtime, include_file=cfg.files.enabled)
if not commands:
return
try:
ok = await cfg.bot.set_my_commands(commands)
except Exception as exc: # noqa: BLE001
logger.info(
"startup.command_menu.failed",
error=str(exc),
error_type=exc.__class__.__name__,
)
return
if not ok:
logger.info("startup.command_menu.rejected")
return
logger.info(
"startup.command_menu.updated",
commands=[cmd["command"] for cmd in commands],
)
+30
View File
@@ -0,0 +1,30 @@
from __future__ import annotations
def is_cancel_command(text: str) -> bool:
stripped = text.strip()
if not stripped:
return False
command = stripped.split(maxsplit=1)[0]
return command == "/cancel" or command.startswith("/cancel@")
def _parse_slash_command(text: str) -> tuple[str | None, str]:
stripped = text.lstrip()
if not stripped.startswith("/"):
return None, text
lines = stripped.splitlines()
if not lines:
return None, text
first_line = lines[0]
token, _, rest = first_line.partition(" ")
command = token[1:]
if not command:
return None, text
if "@" in command:
command = command.split("@", 1)[0]
args_text = rest
if len(lines) > 1:
tail = "\n".join(lines[1:])
args_text = f"{args_text}\n{tail}" if args_text else tail
return command.lower(), args_text
+23
View File
@@ -0,0 +1,23 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from functools import partial
from typing import TYPE_CHECKING
from ..bridge import send_plain
from ..types import TelegramIncomingMessage
if TYPE_CHECKING:
from ..bridge import TelegramBridgeConfig
def make_reply(
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
) -> Callable[..., Awaitable[None]]:
return partial(
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
+220
View File
@@ -0,0 +1,220 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from ...markdown import MarkdownParts
from ...transport import RenderedMessage, SendOptions
from ..chat_sessions import ChatSessionStore
from ..context import (
_format_context,
_format_ctx_status,
_merge_topic_context,
_parse_project_branch_args,
_usage_ctx_set,
_usage_topic,
)
from ..files import split_command_args
from ..render import prepare_telegram
from ..topic_state import TopicStateStore
from ..topics import (
_maybe_rename_topic,
_topic_key,
_topic_title,
_topics_chat_project,
_topics_command_error,
)
from ..types import TelegramIncomingMessage
from .reply import make_reply
if TYPE_CHECKING:
from ..bridge import TelegramBridgeConfig
async def _handle_ctx_command(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
store: TopicStateStore,
*,
resolved_scope: str | None = None,
scope_chat_ids: frozenset[int] | None = None,
) -> None:
reply = make_reply(cfg, msg)
error = _topics_command_error(
cfg,
msg.chat_id,
resolved_scope=resolved_scope,
scope_chat_ids=scope_chat_ids,
)
if error is not None:
await reply(text=error)
return
chat_project = _topics_chat_project(cfg, msg.chat_id)
tkey = _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
if tkey is None:
await reply(text="this command only works inside a topic.")
return
tokens = split_command_args(args_text)
action = tokens[0].lower() if tokens else "show"
if action in {"show", ""}:
snapshot = await store.get_thread(*tkey)
bound = snapshot.context if snapshot is not None else None
ambient = _merge_topic_context(chat_project=chat_project, bound=bound)
resolved = cfg.runtime.resolve_message(
text="",
reply_text=msg.reply_to_text,
chat_id=msg.chat_id,
ambient_context=ambient,
)
text = _format_ctx_status(
cfg=cfg,
runtime=cfg.runtime,
bound=bound,
resolved=resolved.context,
context_source=resolved.context_source,
snapshot=snapshot,
chat_project=chat_project,
)
await reply(text=text)
return
if action == "set":
rest = " ".join(tokens[1:])
context, error = _parse_project_branch_args(
rest,
runtime=cfg.runtime,
require_branch=False,
chat_project=chat_project,
)
if error is not None:
await reply(
text=f"error:\n{error}\n{_usage_ctx_set(chat_project=chat_project)}",
)
return
if context is None:
await reply(
text=f"error:\n{_usage_ctx_set(chat_project=chat_project)}",
)
return
await store.set_context(*tkey, context)
await _maybe_rename_topic(
cfg,
store,
chat_id=tkey[0],
thread_id=tkey[1],
context=context,
)
await reply(
text=f"topic bound to `{_format_context(cfg.runtime, context)}`",
)
return
if action == "clear":
await store.clear_context(*tkey)
await reply(text="topic binding cleared.")
return
await reply(
text="unknown `/ctx` command. use `/ctx`, `/ctx set`, or `/ctx clear`.",
)
async def _handle_new_command(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
store: TopicStateStore,
*,
resolved_scope: str | None = None,
scope_chat_ids: frozenset[int] | None = None,
) -> None:
reply = make_reply(cfg, msg)
error = _topics_command_error(
cfg,
msg.chat_id,
resolved_scope=resolved_scope,
scope_chat_ids=scope_chat_ids,
)
if error is not None:
await reply(text=error)
return
tkey = _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
if tkey is None:
await reply(text="this command only works inside a topic.")
return
await store.clear_sessions(*tkey)
await reply(text="cleared stored sessions for this topic.")
async def _handle_chat_new_command(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
store: ChatSessionStore,
session_key: tuple[int, int | None] | None,
) -> None:
reply = make_reply(cfg, msg)
if session_key is None:
await reply(text="no stored sessions to clear for this chat.")
return
await store.clear_sessions(session_key[0], session_key[1])
if msg.chat_type == "private":
text = "cleared stored sessions for this chat."
else:
text = "cleared stored sessions for you in this chat."
await reply(text=text)
async def _handle_topic_command(
cfg: TelegramBridgeConfig,
msg: TelegramIncomingMessage,
args_text: str,
store: TopicStateStore,
*,
resolved_scope: str | None = None,
scope_chat_ids: frozenset[int] | None = None,
) -> None:
reply = make_reply(cfg, msg)
error = _topics_command_error(
cfg,
msg.chat_id,
resolved_scope=resolved_scope,
scope_chat_ids=scope_chat_ids,
)
if error is not None:
await reply(text=error)
return
chat_project = _topics_chat_project(cfg, msg.chat_id)
context, error = _parse_project_branch_args(
args_text,
runtime=cfg.runtime,
require_branch=True,
chat_project=chat_project,
)
if error is not None or context is None:
usage = _usage_topic(chat_project=chat_project)
text = f"error:\n{error}\n{usage}" if error else usage
await reply(text=text)
return
existing = await store.find_thread_for_context(msg.chat_id, context)
if existing is not None:
await reply(
text=f"topic already exists for {_format_context(cfg.runtime, context)} "
"in this chat.",
)
return
title = _topic_title(runtime=cfg.runtime, context=context)
created = await cfg.bot.create_forum_topic(msg.chat_id, title)
if created is None:
await reply(text="failed to create topic.")
return
thread_id = created.message_thread_id
await store.set_context(
msg.chat_id,
thread_id,
context,
topic_title=title,
)
await reply(text=f"created topic `{title}`.")
bound_text = f"topic bound to `{_format_context(cfg.runtime, context)}`"
rendered_text, entities = prepare_telegram(MarkdownParts(header=bound_text))
await cfg.exec_cfg.transport.send(
channel_id=msg.chat_id,
message=RenderedMessage(text=rendered_text, extra={"entities": entities}),
options=SendOptions(thread_id=thread_id),
)
+16 -29
View File
@@ -20,26 +20,25 @@ from ..transport import MessageRef
from ..transport_runtime import ResolvedMessage
from ..context import RunContext
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
from .commands import (
from .commands.agent import _handle_agent_command
from .commands.cancel import handle_callback_cancel, handle_cancel
from .commands.dispatch import _dispatch_command
from .commands.executor import _run_engine, _should_show_resume_line
from .commands.file_transfer import (
FILE_PUT_USAGE,
_dispatch_command,
_handle_agent_command,
_handle_chat_new_command,
_handle_ctx_command,
_handle_file_command,
_handle_file_put_default,
_handle_media_group,
_save_file_put,
)
from .commands.media import _handle_media_group
from .commands.menu import _reserved_commands, _set_command_menu
from .commands.parse import _parse_slash_command, is_cancel_command
from .commands.reply import make_reply
from .commands.topics import (
_handle_chat_new_command,
_handle_ctx_command,
_handle_new_command,
_handle_topic_command,
_parse_slash_command,
_reserved_commands,
_save_file_put,
_should_show_resume_line,
_run_engine,
_set_command_menu,
handle_callback_cancel,
handle_cancel,
is_cancel_command,
)
from .context import _merge_topic_context, _usage_ctx_set, _usage_topic
from .topics import (
@@ -519,13 +518,7 @@ async def run_main_loop(
text: str,
ambient_context: RunContext | None,
) -> ResolvedMessage | None:
reply = partial(
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
reply = make_reply(cfg, msg)
try:
resolved = cfg.runtime.resolve_message(
text=text,
@@ -757,13 +750,7 @@ async def run_main_loop(
if reply_id is not None
else None
)
reply = partial(
send_plain,
cfg.exec_cfg.transport,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=msg.thread_id,
)
reply = make_reply(cfg, msg)
text = msg.text
if msg.voice is not None:
text = await transcribe_voice(
+2 -2
View File
@@ -286,8 +286,8 @@ def _confirm(message: str, *, default: bool = True) -> bool | None:
exit_with_result(event)
@bindings.add(Keys.Any)
def other(event):
_ = event
def other(_event):
return None
question = Question(
PromptSession(get_prompt_tokens, key_bindings=bindings, style=merged_style).app
+177
View File
@@ -0,0 +1,177 @@
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING
from collections.abc import Awaitable, Callable, Hashable
import anyio
from .client_api import RetryAfter
SEND_PRIORITY = 0
DELETE_PRIORITY = 1
EDIT_PRIORITY = 2
@dataclass(slots=True)
class OutboxOp:
execute: Callable[[], Awaitable[Any]]
priority: int
queued_at: float
chat_id: int | None
label: str | None = None
done: anyio.Event = field(default_factory=anyio.Event)
result: Any = None
def set_result(self, result: Any) -> None:
if self.done.is_set():
return
self.result = result
self.done.set()
class TelegramOutbox:
def __init__(
self,
*,
interval_for_chat: Callable[[int | None], float],
clock: Callable[[], float] = time.monotonic,
sleep: Callable[[float], Awaitable[None]] = anyio.sleep,
on_error: Callable[[OutboxOp, Exception], None] | None = None,
on_outbox_error: Callable[[Exception], None] | None = None,
) -> None:
self._interval_for_chat = interval_for_chat
self._clock = clock
self._sleep = sleep
self._on_error = on_error
self._on_outbox_error = on_outbox_error
self._pending: dict[Hashable, OutboxOp] = {}
self._cond = anyio.Condition()
self._start_lock = anyio.Lock()
self._closed = False
self._tg: TaskGroup | None = None
self.next_at = 0.0
self.retry_at = 0.0
async def ensure_worker(self) -> None:
async with self._start_lock:
if self._tg is not None or self._closed:
return
self._tg = await anyio.create_task_group().__aenter__()
self._tg.start_soon(self.run)
async def enqueue(self, *, key: Hashable, op: OutboxOp, wait: bool = True) -> Any:
await self.ensure_worker()
async with self._cond:
if self._closed:
op.set_result(None)
return op.result
previous = self._pending.get(key)
if previous is not None:
op.queued_at = previous.queued_at
previous.set_result(None)
self._pending[key] = op
self._cond.notify()
if not wait:
return None
await op.done.wait()
return op.result
async def drop_pending(self, *, key: Hashable) -> None:
async with self._cond:
pending = self._pending.pop(key, None)
if pending is not None:
pending.set_result(None)
self._cond.notify()
async def close(self) -> None:
async with self._cond:
self._closed = True
self.fail_pending()
self._cond.notify_all()
if self._tg is not None:
await self._tg.__aexit__(None, None, None)
self._tg = None
def fail_pending(self) -> None:
for pending in list(self._pending.values()):
pending.set_result(None)
self._pending.clear()
def pick_locked(self) -> tuple[Hashable, OutboxOp] | None:
if not self._pending:
return None
return min(
self._pending.items(),
key=lambda item: (item[1].priority, item[1].queued_at),
)
async def execute_op(self, op: OutboxOp) -> Any:
try:
return await op.execute()
except Exception as exc: # noqa: BLE001
if isinstance(exc, RetryAfter):
raise
if self._on_error is not None:
self._on_error(op, exc)
return None
async def sleep_until(self, deadline: float) -> None:
delay = deadline - self._clock()
if delay > 0:
await self._sleep(delay)
async def run(self) -> None:
cancel_exc = anyio.get_cancelled_exc_class()
try:
while True:
async with self._cond:
while not self._pending and not self._closed:
await self._cond.wait()
if self._closed and not self._pending:
return
blocked_until = max(self.next_at, self.retry_at)
if self._clock() < blocked_until:
await self.sleep_until(blocked_until)
continue
async with self._cond:
if self._closed and not self._pending:
return
picked = self.pick_locked()
if picked is None:
continue
key, op = picked
self._pending.pop(key, None)
started_at = self._clock()
try:
result = await self.execute_op(op)
except RetryAfter as exc:
self.retry_at = max(self.retry_at, self._clock() + exc.retry_after)
async with self._cond:
if self._closed:
op.set_result(None)
elif key not in self._pending:
self._pending[key] = op
self._cond.notify()
else:
op.set_result(None)
continue
self.next_at = started_at + self._interval_for_chat(op.chat_id)
op.set_result(result)
except cancel_exc:
return
except Exception as exc: # noqa: BLE001
async with self._cond:
self._closed = True
self.fail_pending()
self._cond.notify_all()
if self._on_outbox_error is not None:
self._on_outbox_error(exc)
return
if TYPE_CHECKING:
from anyio.abc import TaskGroup
else:
TaskGroup = object
+283
View File
@@ -0,0 +1,283 @@
from __future__ import annotations
from typing import Any
from collections.abc import AsyncIterator, Callable, Iterable
import anyio
from ..logging import get_logger
from .api_models import Update
from .client_api import BotClient
from .types import (
TelegramCallbackQuery,
TelegramDocument,
TelegramIncomingMessage,
TelegramIncomingUpdate,
TelegramVoice,
)
logger = get_logger(__name__)
def parse_incoming_update(
update: Update | dict[str, Any],
*,
chat_id: int | None = None,
chat_ids: set[int] | None = None,
) -> TelegramIncomingUpdate | None:
if isinstance(update, Update):
msg = update.message
callback_query = update.callback_query
else:
msg = update.get("message")
callback_query = update.get("callback_query")
if isinstance(msg, dict):
return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids)
if isinstance(callback_query, dict):
return _parse_callback_query(
callback_query,
chat_id=chat_id,
chat_ids=chat_ids,
)
return None
def _parse_incoming_message(
msg: dict[str, Any],
*,
chat_id: int | None = None,
chat_ids: set[int] | None = None,
) -> TelegramIncomingMessage | None:
def _parse_document_payload(payload: dict[str, Any]) -> TelegramDocument | None:
file_id = payload.get("file_id")
if not isinstance(file_id, str) or not file_id:
return None
return TelegramDocument(
file_id=file_id,
file_name=payload.get("file_name")
if isinstance(payload.get("file_name"), str)
else None,
mime_type=payload.get("mime_type")
if isinstance(payload.get("mime_type"), str)
else None,
file_size=payload.get("file_size")
if isinstance(payload.get("file_size"), int)
and not isinstance(payload.get("file_size"), bool)
else None,
raw=payload,
)
raw_text = msg.get("text")
text = raw_text if isinstance(raw_text, str) else None
caption = msg.get("caption")
if text is None and isinstance(caption, str):
text = caption
if text is None:
text = ""
file_command = False
if isinstance(text, str):
stripped = text.lstrip()
if stripped.startswith("/"):
token = stripped.split(maxsplit=1)[0]
file_command = token.startswith("/file")
voice_payload: TelegramVoice | None = None
voice = msg.get("voice")
if isinstance(voice, dict):
file_id = voice.get("file_id")
if not isinstance(file_id, str) or not file_id:
file_id = None
if file_id is not None:
voice_payload = TelegramVoice(
file_id=file_id,
mime_type=voice.get("mime_type")
if isinstance(voice.get("mime_type"), str)
else None,
file_size=voice.get("file_size")
if isinstance(voice.get("file_size"), int)
and not isinstance(voice.get("file_size"), bool)
else None,
duration=voice.get("duration")
if isinstance(voice.get("duration"), int)
and not isinstance(voice.get("duration"), bool)
else None,
raw=voice,
)
if not isinstance(raw_text, str) and not isinstance(caption, str):
text = ""
document_payload: TelegramDocument | None = None
document = msg.get("document")
if isinstance(document, dict):
document_payload = _parse_document_payload(document)
if document_payload is None:
video = msg.get("video")
if isinstance(video, dict):
document_payload = _parse_document_payload(video)
if document_payload is None:
photo = msg.get("photo")
if isinstance(photo, list):
best: dict[str, Any] | None = None
best_score = -1
for item in photo:
if not isinstance(item, dict):
continue
file_id = item.get("file_id")
if not isinstance(file_id, str) or not file_id:
continue
size = item.get("file_size")
if isinstance(size, int) and not isinstance(size, bool):
score = size
else:
width = item.get("width")
height = item.get("height")
if isinstance(width, int) and isinstance(height, int):
score = width * height
else:
score = 0
if score > best_score:
best_score = score
best = item
if best is not None:
document_payload = _parse_document_payload(best)
if document_payload is None and file_command:
sticker = msg.get("sticker")
if isinstance(sticker, dict):
document_payload = _parse_document_payload(sticker)
has_text = isinstance(raw_text, str) or isinstance(caption, str)
if not has_text and voice_payload is None and document_payload is None:
return None
chat = msg.get("chat")
if not isinstance(chat, dict):
return None
msg_chat_id = chat.get("id")
if not isinstance(msg_chat_id, int):
return None
chat_type = chat.get("type") if isinstance(chat.get("type"), str) else None
is_forum = chat.get("is_forum")
if not isinstance(is_forum, bool):
is_forum = None
allowed = chat_ids
if allowed is None and chat_id is not None:
allowed = {chat_id}
if allowed is not None and msg_chat_id not in allowed:
return None
message_id = msg.get("message_id")
if not isinstance(message_id, int):
return None
reply = msg.get("reply_to_message")
reply_to_message_id = None
reply_to_text = None
if isinstance(reply, dict):
reply_to_message_id = (
reply.get("message_id")
if isinstance(reply.get("message_id"), int)
else None
)
reply_to_text = (
reply.get("text") if isinstance(reply.get("text"), str) else None
)
sender = msg.get("from")
sender_id = (
sender.get("id")
if isinstance(sender, dict) and isinstance(sender.get("id"), int)
else None
)
media_group_id = msg.get("media_group_id")
if not isinstance(media_group_id, str):
media_group_id = None
thread_id = msg.get("message_thread_id")
if isinstance(thread_id, bool) or not isinstance(thread_id, int):
thread_id = None
is_topic_message = msg.get("is_topic_message")
if not isinstance(is_topic_message, bool):
is_topic_message = None
return TelegramIncomingMessage(
transport="telegram",
chat_id=msg_chat_id,
message_id=message_id,
text=text,
reply_to_message_id=reply_to_message_id,
reply_to_text=reply_to_text,
sender_id=sender_id,
media_group_id=media_group_id,
thread_id=thread_id,
is_topic_message=is_topic_message,
chat_type=chat_type,
is_forum=is_forum,
voice=voice_payload,
document=document_payload,
raw=msg,
)
def _parse_callback_query(
query: dict[str, Any],
*,
chat_id: int | None = None,
chat_ids: set[int] | None = None,
) -> TelegramCallbackQuery | None:
callback_id = query.get("id")
if not isinstance(callback_id, str) or not callback_id:
return None
msg = query.get("message")
if not isinstance(msg, dict):
return None
chat = msg.get("chat")
if not isinstance(chat, dict):
return None
msg_chat_id = chat.get("id")
if not isinstance(msg_chat_id, int):
return None
allowed = chat_ids
if allowed is None and chat_id is not None:
allowed = {chat_id}
if allowed is not None and msg_chat_id not in allowed:
return None
message_id = msg.get("message_id")
if not isinstance(message_id, int):
return None
data = query.get("data") if isinstance(query.get("data"), str) else None
sender = query.get("from")
sender_id = (
sender.get("id")
if isinstance(sender, dict) and isinstance(sender.get("id"), int)
else None
)
return TelegramCallbackQuery(
transport="telegram",
chat_id=msg_chat_id,
message_id=message_id,
callback_query_id=callback_id,
data=data,
sender_id=sender_id,
raw=query,
)
async def poll_incoming(
bot: BotClient,
*,
chat_id: int | None = None,
chat_ids: Iterable[int] | Callable[[], Iterable[int]] | None = None,
offset: int | None = None,
) -> AsyncIterator[TelegramIncomingUpdate]:
while True:
updates = await bot.get_updates(
offset=offset,
timeout_s=50,
allowed_updates=["message", "callback_query"],
)
if updates is None:
logger.info("loop.get_updates.failed")
await anyio.sleep(2)
continue
logger.debug("loop.updates", updates=updates)
resolved_chat_ids = chat_ids() if callable(chat_ids) else chat_ids
allowed = set(resolved_chat_ids) if resolved_chat_ids is not None else None
if allowed is None and chat_id is not None:
allowed = {chat_id}
for upd in updates:
offset = upd.update_id + 1
msg = parse_incoming_update(upd, chat_ids=allowed)
if msg is not None:
yield msg
+5 -11
View File
@@ -1,14 +1,13 @@
from __future__ import annotations
import json
import os
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable, Generic, Protocol, TypeVar
from typing import Any, Protocol
import anyio
import msgspec
T = TypeVar("T", bound="_VersionedState")
from ..utils.json_state import atomic_write_json
class _Logger(Protocol):
@@ -19,7 +18,7 @@ class _VersionedState(Protocol):
version: int
class JsonStateStore(Generic[T]):
class JsonStateStore[T: _VersionedState]:
def __init__(
self,
path: Path,
@@ -84,11 +83,6 @@ class JsonStateStore(Generic[T]):
self._state = payload
def _save_locked(self) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True)
payload = msgspec.to_builtins(self._state)
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2, sort_keys=True)
handle.write("\n")
os.replace(tmp_path, self._path)
atomic_write_json(self._path, payload)
self._mtime_ns = self._stat_mtime_ns()
+4 -4
View File
@@ -1,11 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol, TypeAlias
from typing import Any, Protocol
ChannelId: TypeAlias = int | str
MessageId: TypeAlias = int | str
ThreadId: TypeAlias = int | str
type ChannelId = int | str
type MessageId = int | str
type ThreadId = int | str
@dataclass(frozen=True, slots=True)
+9 -9
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, TypeAlias
from typing import Any, Literal
from .config import ConfigError, ProjectsConfig
from .context import RunContext
@@ -19,7 +19,7 @@ from .router import AutoRouter, EngineStatus
from .runner import Runner
from .worktrees import WorktreeError, resolve_run_cwd
ContextSource: TypeAlias = Literal[
type ContextSource = Literal[
"reply_ctx",
"directives",
"ambient",
@@ -234,13 +234,13 @@ class TransportRuntime:
project_key = ambient_context.project
else:
project_key = default_project
if branch is None:
if (
ambient_context is not None
and ambient_context.branch is not None
and project_key == ambient_context.project
):
branch = ambient_context.branch
if (
branch is None
and ambient_context is not None
and ambient_context.branch is not None
and project_key == ambient_context.project
):
branch = ambient_context.branch
context: RunContext | None = None
if project_key is not None or branch is not None:
context = RunContext(project=project_key, branch=branch)
+3 -2
View File
@@ -2,7 +2,8 @@ from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
from collections.abc import Iterable
from .backends import EngineBackend, SetupIssue
from .plugins import TRANSPORT_GROUP, list_ids, load_plugin_backend
@@ -34,7 +35,7 @@ class TransportBackend(Protocol):
def interactive_setup(self, *, force: bool) -> bool: ...
def lock_token(
self, *, transport_config: object, config_path: Path
self, *, transport_config: object, _config_path: Path
) -> str | None: ...
def build_and_run(
+21
View File
@@ -0,0 +1,21 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
def atomic_write_json(
path: Path,
payload: Any,
*,
indent: int = 2,
sort_keys: bool = True,
) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = path.with_suffix(f"{path.suffix}.tmp")
with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=indent, sort_keys=sort_keys)
handle.write("\n")
os.replace(tmp_path, path)
+3 -3
View File
@@ -14,7 +14,7 @@ from takopi.model import (
def session_started(engine: str, value: str, title: str = "Codex") -> TakopiEvent:
engine_id = EngineId(engine)
engine_id: EngineId = engine
return StartedEvent(
engine=engine_id,
resume=ResumeToken(engine=engine_id, value=value),
@@ -29,7 +29,7 @@ def action_started(
detail: dict[str, Any] | None = None,
engine: str = "codex",
) -> TakopiEvent:
engine_id = EngineId(engine)
engine_id: EngineId = engine
return ActionEvent(
engine=engine_id,
action=Action(
@@ -50,7 +50,7 @@ def action_completed(
detail: dict[str, Any] | None = None,
engine: str = "codex",
) -> TakopiEvent:
engine_id = EngineId(engine)
engine_id: EngineId = engine
return ActionEvent(
engine=engine_id,
action=Action(
+2 -1
View File
@@ -1,7 +1,8 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Iterable
from typing import Any
from collections.abc import Callable, Iterable
@dataclass(frozen=True, slots=True)
+2 -2
View File
@@ -5,7 +5,7 @@ import pytest
from takopi.runner_bridge import ExecBridgeConfig, IncomingMessage, handle_message
from takopi.markdown import MarkdownParts, MarkdownPresenter
from takopi.model import EngineId, ResumeToken, TakopiEvent
from takopi.model import ResumeToken, TakopiEvent
from takopi.telegram.render import prepare_telegram
from takopi.runners.codex import CodexRunner
from takopi.runners.mock import Advance, Emit, Raise, Return, ScriptRunner, Wait
@@ -13,7 +13,7 @@ from takopi.settings import load_settings, require_telegram
from takopi.transport import MessageRef, RenderedMessage, SendOptions
from tests.factories import action_completed, action_started
CODEX_ENGINE = EngineId("codex")
CODEX_ENGINE = "codex"
class _FakeTransport:
+1 -2
View File
@@ -7,14 +7,13 @@ from collections.abc import AsyncIterator
from takopi.model import (
ActionEvent,
CompletedEvent,
EngineId,
ResumeToken,
StartedEvent,
TakopiEvent,
)
from takopi.runners.codex import CodexRunner, find_exec_only_flag
CODEX_ENGINE = EngineId("codex")
CODEX_ENGINE = "codex"
@pytest.mark.anyio
+4 -5
View File
@@ -45,11 +45,10 @@ def test_resolve_default_base_prefers_master_over_main(monkeypatch) -> None:
return None
def _fake_ok(args, **kwargs):
if args == ["show-ref", "--verify", "--quiet", "refs/heads/master"]:
return True
if args == ["show-ref", "--verify", "--quiet", "refs/heads/main"]:
return True
return False
return args in (
["show-ref", "--verify", "--quiet", "refs/heads/master"],
["show-ref", "--verify", "--quiet", "refs/heads/main"],
)
monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout)
monkeypatch.setattr("takopi.utils.git.git_ok", _fake_ok)
+3 -4
View File
@@ -7,7 +7,6 @@ from takopi.model import (
Action,
ActionEvent,
CompletedEvent,
EngineId,
ResumeToken,
StartedEvent,
TakopiEvent,
@@ -15,7 +14,7 @@ from takopi.model import (
from takopi.runners.mock import Emit, Return, ScriptRunner, Wait
from tests.factories import action_started
CODEX_ENGINE = EngineId("codex")
CODEX_ENGINE = "codex"
@pytest.mark.anyio
@@ -84,7 +83,7 @@ async def test_runner_releases_lock_when_consumer_closes() -> None:
gate = anyio.Event()
runner = ScriptRunner([Wait(gate)], engine=CODEX_ENGINE, resume_value="sid")
gen = cast(AsyncGenerator[TakopiEvent, None], runner.run("hello", None))
gen = cast(AsyncGenerator[TakopiEvent], runner.run("hello", None))
try:
while True:
evt = await anext(gen)
@@ -94,7 +93,7 @@ async def test_runner_releases_lock_when_consumer_closes() -> None:
await gen.aclose()
gen2 = cast(
AsyncGenerator[TakopiEvent, None],
AsyncGenerator[TakopiEvent],
runner.run("again", ResumeToken(engine=CODEX_ENGINE, value="sid")),
)
try:
+7 -8
View File
@@ -8,7 +8,6 @@ import takopi.runner as runner_module
from takopi.model import (
ActionEvent,
CompletedEvent,
EngineId,
ResumeToken,
StartedEvent,
TakopiEvent,
@@ -22,7 +21,7 @@ from takopi.runner import (
class _DummyRunner(ResumeTokenMixin, BaseRunner):
engine = EngineId("dummy")
engine = "dummy"
resume_re = re.compile(r"(?im)^`?dummy resume (?P<token>[^`\s]+)`?$")
async def run_impl(
@@ -39,7 +38,7 @@ class _DummyRunner(ResumeTokenMixin, BaseRunner):
class _DummyJsonlRunner(JsonlSubprocessRunner):
engine = EngineId("dummy-jsonl")
engine = "dummy-jsonl"
def command(self) -> str:
return "dummy"
@@ -67,7 +66,7 @@ class _DummyJsonlRunner(JsonlSubprocessRunner):
class _BareJsonlRunner(JsonlSubprocessRunner):
engine = EngineId("bare-jsonl")
engine = "bare-jsonl"
class _RunJsonlRunner(_DummyJsonlRunner):
@@ -177,7 +176,7 @@ async def test_base_runner_run_locked_handles_resume() -> None:
@pytest.mark.anyio
async def test_base_runner_rejects_wrong_resume_engine() -> None:
runner = _DummyRunner()
bad_resume = ResumeToken(engine=EngineId("other"), value="oops")
bad_resume = ResumeToken(engine="other", value="oops")
with pytest.raises(RuntimeError):
_ = [evt async for evt in runner.run("hello", bad_resume)]
@@ -185,7 +184,7 @@ async def test_base_runner_rejects_wrong_resume_engine() -> None:
@pytest.mark.anyio
async def test_base_runner_run_impl_not_implemented() -> None:
class _BareRunner(BaseRunner):
engine = EngineId("bare")
engine = "bare"
runner = _BareRunner()
with pytest.raises(NotImplementedError):
@@ -204,7 +203,7 @@ def test_resume_token_format_and_extract() -> None:
assert runner.extract_resume(None) is None
with pytest.raises(RuntimeError):
runner.format_resume(ResumeToken(engine=EngineId("other"), value="bad"))
runner.format_resume(ResumeToken(engine="other", value="bad"))
def test_session_lock_reuse() -> None:
@@ -294,7 +293,7 @@ def test_jsonl_helpers() -> None:
assert found == resume
assert emit is False
mismatch = StartedEvent(engine=EngineId("other"), resume=resume, title="t")
mismatch = StartedEvent(engine="other", resume=resume, title="t")
with pytest.raises(RuntimeError):
runner.handle_started_event(mismatch, expected_session=None, found_session=None)
+6 -7
View File
@@ -6,7 +6,6 @@ from typing import Any
import pytest
from takopi.config import ProjectsConfig
from takopi.model import EngineId
from takopi.router import AutoRouter, RunnerEntry
from takopi.runners.mock import Return, ScriptRunner
from takopi.settings import (
@@ -19,8 +18,8 @@ from takopi.transport_runtime import TransportRuntime
def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
codex = EngineId("codex")
pi = EngineId("pi")
codex = "codex"
pi = "pi"
runner = ScriptRunner([Return(answer="ok")], engine=codex)
missing = ScriptRunner([Return(answer="ok")], engine=pi)
router = AutoRouter(
@@ -53,9 +52,9 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
def test_build_startup_message_surfaces_unavailable_engine_reasons(
tmp_path: Path,
) -> None:
codex = EngineId("codex")
pi = EngineId("pi")
claude = EngineId("claude")
codex = "codex"
pi = "pi"
claude = "claude"
runner = ScriptRunner([Return(answer="ok")], engine=codex)
bad_cfg = ScriptRunner([Return(answer="ok")], engine=pi)
load_err = ScriptRunner([Return(answer="ok")], engine=claude)
@@ -100,7 +99,7 @@ def test_telegram_backend_build_and_run_wires_config(
encoding="utf-8",
)
codex = EngineId("codex")
codex = "codex"
runner = ScriptRunner([Return(answer="ok")], engine=codex)
router = AutoRouter(
entries=[RunnerEntry(engine=codex, runner=runner)],
+13 -14
View File
@@ -6,8 +6,9 @@ import anyio
import pytest
from takopi import commands, plugins
from takopi.telegram.commands.executor import _CaptureTransport, _run_engine
from takopi.telegram.commands.file_transfer import _handle_file_get, _handle_file_put
import takopi.telegram.loop as telegram_loop
import takopi.telegram.commands as telegram_commands
import takopi.telegram.topics as telegram_topics
from takopi.directives import parse_directives
from takopi.telegram.api_models import (
@@ -39,7 +40,7 @@ from takopi.context import RunContext
from takopi.config import ProjectConfig, ProjectsConfig
from takopi.runner_bridge import ExecBridgeConfig, RunningTask
from takopi.markdown import MarkdownPresenter
from takopi.model import EngineId, ResumeToken
from takopi.model import ResumeToken
from takopi.progress import ProgressTracker
from takopi.router import AutoRouter, RunnerEntry
from takopi.transport_runtime import TransportRuntime
@@ -52,7 +53,7 @@ from takopi.telegram.types import (
from takopi.transport import MessageRef, RenderedMessage, SendOptions
from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints
CODEX_ENGINE = EngineId("codex")
CODEX_ENGINE = "codex"
def _empty_projects() -> ProjectsConfig:
@@ -905,9 +906,7 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
),
)
await telegram_commands._handle_file_put(
cfg, msg, "/proj uploads/hello.txt", None, None
)
await _handle_file_put(cfg, msg, "/proj uploads/hello.txt", None, None)
target = tmp_path / "uploads" / "hello.txt"
assert target.read_bytes() == payload
@@ -966,7 +965,7 @@ async def test_handle_file_get_sends_document_for_allowed_user(
chat_type="supergroup",
)
await telegram_commands._handle_file_get(cfg, msg, "/proj hello.txt", None, None)
await _handle_file_get(cfg, msg, "/proj hello.txt", None, None)
assert bot.document_calls
assert bot.document_calls[0]["filename"] == "hello.txt"
@@ -1263,7 +1262,7 @@ async def test_send_with_resume_reports_when_missing() -> None:
@pytest.mark.anyio
async def test_run_engine_hides_resume_line_in_topics() -> None:
transport = telegram_commands._CaptureTransport()
transport = _CaptureTransport()
runner = ScriptRunner(
[Return(answer="ok")],
engine=CODEX_ENGINE,
@@ -1279,7 +1278,7 @@ async def test_run_engine_hides_resume_line_in_topics() -> None:
projects=_empty_projects(),
)
await telegram_commands._run_engine(
await _run_engine(
exec_cfg=exec_cfg,
runtime=runtime,
running_tasks={},
@@ -1456,14 +1455,14 @@ async def test_run_main_loop_auto_resumes_topic_default_engine(
123, 77, ResumeToken(engine=CODEX_ENGINE, value="resume-codex")
)
await store.set_session_resume(
123, 77, ResumeToken(engine=EngineId("claude"), value="resume-claude")
123, 77, ResumeToken(engine="claude", value="resume-claude")
)
await store.set_default_engine(123, 77, "claude")
transport = _FakeTransport()
bot = _FakeBot()
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
claude_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("claude"))
claude_runner = ScriptRunner([Return(answer="ok")], engine="claude")
router = AutoRouter(
entries=[
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
@@ -1521,7 +1520,7 @@ async def test_run_main_loop_auto_resumes_topic_default_engine(
assert codex_runner.calls == []
assert len(claude_runner.calls) == 1
assert claude_runner.calls[0][1] == ResumeToken(
engine=EngineId("claude"), value="resume-claude"
engine="claude", value="resume-claude"
)
@@ -2443,7 +2442,7 @@ async def test_run_main_loop_command_uses_project_default_engine(
transport = _FakeTransport()
bot = _FakeBot()
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
pi_runner = ScriptRunner([Return(answer="ok")], engine="pi")
router = AutoRouter(
entries=[
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
@@ -2525,7 +2524,7 @@ async def test_run_main_loop_command_defaults_to_chat_project(
transport = _FakeTransport()
bot = _FakeBot()
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
pi_runner = ScriptRunner([Return(answer="ok")], engine="pi")
router = AutoRouter(
entries=[
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
+13 -12
View File
@@ -3,6 +3,7 @@ import pytest
from takopi.logging import setup_logging
from takopi.telegram.client import TelegramClient, TelegramRetryAfter
from takopi.telegram.client_api import HttpBotClient
@pytest.mark.anyio
@@ -25,9 +26,9 @@ async def test_telegram_429_no_retry() -> None:
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
with pytest.raises(TelegramRetryAfter) as exc:
await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
await api._post("sendMessage", {"chat_id": 1, "text": "hi"})
finally:
await client.aclose()
@@ -49,8 +50,8 @@ async def test_no_token_in_logs_on_http_error(
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient(token, http_client=client)
await tg._post("getUpdates", {"timeout": 1})
api = HttpBotClient(token, http_client=client)
await api._post("getUpdates", {"timeout": 1})
finally:
await client.aclose()
@@ -79,9 +80,9 @@ async def test_telegram_429_no_retry_post_form() -> None:
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
with pytest.raises(TelegramRetryAfter) as exc:
await tg._post_form(
await api._post_form(
"sendDocument",
{"chat_id": 1},
files={"document": ("note.txt", b"hi")},
@@ -102,9 +103,9 @@ async def test_telegram_429_defaults_retry_after_on_bad_body() -> None:
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
with pytest.raises(TelegramRetryAfter) as exc:
await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
await api._post("sendMessage", {"chat_id": 1, "text": "hi"})
finally:
await client.aclose()
@@ -124,8 +125,8 @@ async def test_telegram_ok_false_returns_none() -> None:
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
result = await tg._post("getUpdates", {"timeout": 1})
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
result = await api._post("getUpdates", {"timeout": 1})
finally:
await client.aclose()
@@ -141,8 +142,8 @@ async def test_telegram_invalid_payload_returns_none() -> None:
client = httpx.AsyncClient(transport=transport)
try:
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
result = await tg._post("getUpdates", {"timeout": 1})
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
result = await api._post("getUpdates", {"timeout": 1})
finally:
await client.aclose()
+3 -4
View File
@@ -4,7 +4,6 @@ import pytest
from takopi.config import ProjectConfig, ProjectsConfig
from takopi.context import RunContext
from takopi.model import EngineId
from takopi.router import AutoRouter, RunnerEntry
from takopi.runners.mock import Return, ScriptRunner
from takopi.telegram.chat_prefs import ChatPrefsStore
@@ -15,8 +14,8 @@ from takopi.transport_runtime import TransportRuntime
@pytest.mark.anyio
async def test_resolve_engine_for_message_sources(tmp_path) -> None:
codex = ScriptRunner([Return(answer="ok")], engine=EngineId("codex"))
pi = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
codex = ScriptRunner([Return(answer="ok")], engine="codex")
pi = ScriptRunner([Return(answer="ok")], engine="pi")
router = AutoRouter(
entries=[
RunnerEntry(engine=codex.engine, runner=codex),
@@ -42,7 +41,7 @@ async def test_resolve_engine_for_message_sources(tmp_path) -> None:
resolved = await resolve_engine_for_message(
runtime=runtime,
context=RunContext(project="proj"),
explicit_engine=EngineId("codex"),
explicit_engine="codex",
chat_id=1,
topic_key=(1, 10),
topic_store=topic_store,
+2 -2
View File
@@ -15,8 +15,8 @@ class DummyTransport:
def interactive_setup(self, *, force: bool) -> bool:
raise NotImplementedError
def lock_token(self, *, transport_config: object, config_path):
_ = transport_config, config_path
def lock_token(self, *, transport_config: object, _config_path):
_ = transport_config, _config_path
raise NotImplementedError
def build_and_run(