refactor: telegram modules and tighten linting (#111)
This commit is contained in:
@@ -408,7 +408,7 @@ from ..schemas import acme as acme_schema
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ENGINE: EngineId = EngineId("acme")
|
ENGINE: EngineId = "acme"
|
||||||
_RESUME_RE = re.compile(
|
_RESUME_RE = re.compile(
|
||||||
r"(?im)^\s*`?acme\s+--resume\s+(?P<token>[^`\s]+)`?\s*$"
|
r"(?im)^\s*`?acme\s+--resume\s+(?P<token>[^`\s]+)`?\s*$"
|
||||||
)
|
)
|
||||||
|
|||||||
+1
-1
@@ -65,7 +65,7 @@ addopts = ["--cov=takopi", "--cov-report=term-missing", "--cov-fail-under=75"]
|
|||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
extend-select = ["B904", "BLE001", "S110", "RUF043"]
|
extend-select = ["B", "BLE001", "C4", "PERF", "RUF043", "S110", "SIM", "UP"]
|
||||||
|
|
||||||
[tool.ty.src]
|
[tool.ty.src]
|
||||||
include = ["src", "tests"]
|
include = ["src", "tests"]
|
||||||
|
|||||||
+2
-5
@@ -256,7 +256,7 @@ def _run_auto_router(
|
|||||||
)
|
)
|
||||||
lock_token = transport_backend.lock_token(
|
lock_token = transport_backend.lock_token(
|
||||||
transport_config=transport_config,
|
transport_config=transport_config,
|
||||||
config_path=config_path,
|
_config_path=config_path,
|
||||||
)
|
)
|
||||||
lock_handle = acquire_config_lock(config_path, lock_token)
|
lock_handle = acquire_config_lock(config_path, lock_token)
|
||||||
runtime = spec.to_runtime(config_path=config_path)
|
runtime = spec.to_runtime(config_path=config_path)
|
||||||
@@ -301,10 +301,7 @@ def _default_alias_from_path(path: Path) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def _ensure_projects_table(config: dict, config_path: Path) -> dict:
|
def _ensure_projects_table(config: dict, config_path: Path) -> dict:
|
||||||
projects = config.get("projects")
|
projects = config.setdefault("projects", {})
|
||||||
if projects is None:
|
|
||||||
projects = {}
|
|
||||||
config["projects"] = projects
|
|
||||||
if not isinstance(projects, dict):
|
if not isinstance(projects, dict):
|
||||||
raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.")
|
raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.")
|
||||||
return projects
|
return projects
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Awaitable, Callable, Iterable
|
from collections.abc import Awaitable, Callable, Iterable
|
||||||
|
|
||||||
from watchfiles import awatch
|
from watchfiles import awatch
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from .backends import EngineBackend
|
from .backends import EngineBackend
|
||||||
from .config import ConfigError
|
from .config import ConfigError
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from .logging import get_logger
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class LockInfo:
|
class LockInfo:
|
||||||
pid: int | None
|
pid: int | None
|
||||||
token_fingerprint: str | None
|
token_fingerprint: str | None
|
||||||
@@ -29,7 +29,7 @@ class LockError(RuntimeError):
|
|||||||
super().__init__(_format_lock_message(path, state))
|
super().__init__(_format_lock_message(path, state))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(slots=True)
|
||||||
class LockHandle:
|
class LockHandle:
|
||||||
path: Path
|
path: Path
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ class LockHandle:
|
|||||||
error_type=exc.__class__.__name__,
|
error_type=exc.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __enter__(self) -> "LockHandle":
|
def __enter__(self) -> LockHandle:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb) -> None:
|
def __exit__(self, exc_type, exc, tb) -> None:
|
||||||
|
|||||||
@@ -107,9 +107,8 @@ def _redact_value(value: Any, memo: dict[int, Any]) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
def _redact_event_dict(
|
def _redact_event_dict(
|
||||||
logger: Any, method_name: str, event_dict: dict[str, Any]
|
_logger: Any, _method_name: str, event_dict: dict[str, Any]
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
_ = logger, method_name
|
|
||||||
return _redact_value(event_dict, memo={})
|
return _redact_value(event_dict, memo={})
|
||||||
|
|
||||||
|
|
||||||
@@ -222,10 +221,7 @@ def setup_logging(
|
|||||||
|
|
||||||
format_value = os.environ.get("TAKOPI_LOG_FORMAT", "console").strip().lower()
|
format_value = os.environ.get("TAKOPI_LOG_FORMAT", "console").strip().lower()
|
||||||
color_override = os.environ.get("TAKOPI_LOG_COLOR")
|
color_override = os.environ.get("TAKOPI_LOG_COLOR")
|
||||||
if color_override is None:
|
is_tty = sys.stdout.isatty() if color_override is None else _truthy(color_override)
|
||||||
is_tty = sys.stdout.isatty()
|
|
||||||
else:
|
|
||||||
is_tty = _truthy(color_override)
|
|
||||||
if format_value == "json":
|
if format_value == "json":
|
||||||
renderer: Any = structlog.processors.JSONRenderer(default=str)
|
renderer: Any = structlog.processors.JSONRenderer(default=str)
|
||||||
else:
|
else:
|
||||||
@@ -242,7 +238,9 @@ def setup_logging(
|
|||||||
_log_file_handle = None
|
_log_file_handle = None
|
||||||
if log_file:
|
if log_file:
|
||||||
try:
|
try:
|
||||||
_log_file_handle = open(log_file, "a", encoding="utf-8")
|
_log_file_handle = open( # noqa: SIM115
|
||||||
|
log_file, "a", encoding="utf-8"
|
||||||
|
)
|
||||||
except OSError:
|
except OSError:
|
||||||
_log_file_handle = None
|
_log_file_handle = None
|
||||||
|
|
||||||
|
|||||||
@@ -120,9 +120,12 @@ def format_file_change_title(action: Action, *, command_width: int | None) -> st
|
|||||||
was_relativized = relativized != fallback
|
was_relativized = relativized != fallback
|
||||||
if was_relativized:
|
if was_relativized:
|
||||||
fallback = relativized
|
fallback = relativized
|
||||||
if fallback and not (fallback.startswith("`") and fallback.endswith("`")):
|
if (
|
||||||
if was_relativized or os.sep in fallback or "/" in fallback:
|
fallback
|
||||||
fallback = f"`{fallback}`"
|
and not (fallback.startswith("`") and fallback.endswith("`"))
|
||||||
|
and (was_relativized or os.sep in fallback or "/" in fallback)
|
||||||
|
):
|
||||||
|
fallback = f"`{fallback}`"
|
||||||
return f"files: {shorten(fallback, command_width)}"
|
return f"files: {shorten(fallback, command_width)}"
|
||||||
|
|
||||||
|
|
||||||
@@ -247,10 +250,7 @@ class MarkdownFormatter:
|
|||||||
|
|
||||||
def _format_actions(self, state: ProgressState) -> list[str]:
|
def _format_actions(self, state: ProgressState) -> list[str]:
|
||||||
actions = list(state.actions)
|
actions = list(state.actions)
|
||||||
if self.max_actions == 0:
|
actions = [] if self.max_actions == 0 else actions[-self.max_actions :]
|
||||||
actions = []
|
|
||||||
else:
|
|
||||||
actions = actions[-self.max_actions :]
|
|
||||||
return [
|
return [
|
||||||
format_action_line(
|
format_action_line(
|
||||||
action_state.action,
|
action_state.action,
|
||||||
|
|||||||
+7
-7
@@ -3,11 +3,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Literal, TypeAlias
|
from typing import Any, Literal
|
||||||
|
|
||||||
EngineId: TypeAlias = str
|
type EngineId = str
|
||||||
|
|
||||||
ActionKind: TypeAlias = Literal[
|
type ActionKind = Literal[
|
||||||
"command",
|
"command",
|
||||||
"tool",
|
"tool",
|
||||||
"file_change",
|
"file_change",
|
||||||
@@ -19,14 +19,14 @@ ActionKind: TypeAlias = Literal[
|
|||||||
"telemetry",
|
"telemetry",
|
||||||
]
|
]
|
||||||
|
|
||||||
TakopiEventType: TypeAlias = Literal[
|
type TakopiEventType = Literal[
|
||||||
"started",
|
"started",
|
||||||
"action",
|
"action",
|
||||||
"completed",
|
"completed",
|
||||||
]
|
]
|
||||||
|
|
||||||
ActionPhase: TypeAlias = Literal["started", "updated", "completed"]
|
type ActionPhase = Literal["started", "updated", "completed"]
|
||||||
ActionLevel: TypeAlias = Literal["debug", "info", "warning", "error"]
|
type ActionLevel = Literal["debug", "info", "warning", "error"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
@@ -74,4 +74,4 @@ class CompletedEvent:
|
|||||||
usage: dict[str, Any] | None = None
|
usage: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
TakopiEvent: TypeAlias = StartedEvent | ActionEvent | CompletedEvent
|
type TakopiEvent = StartedEvent | ActionEvent | CompletedEvent
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from importlib.metadata import EntryPoint, entry_points
|
from importlib.metadata import EntryPoint, entry_points
|
||||||
import re
|
import re
|
||||||
from typing import Any, Callable
|
from typing import Any
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
from .ids import ID_PATTERN, is_valid_id
|
from .ids import ID_PATTERN, is_valid_id
|
||||||
|
|
||||||
@@ -80,12 +81,7 @@ def reset_plugin_state() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _select_entrypoints(group: str) -> list[EntryPoint]:
|
def _select_entrypoints(group: str) -> list[EntryPoint]:
|
||||||
eps = entry_points()
|
return list(entry_points().select(group=group))
|
||||||
if hasattr(eps, "select"):
|
|
||||||
return list(eps.select(group=group))
|
|
||||||
if isinstance(eps, Mapping):
|
|
||||||
return list(eps.get(group, []))
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def entrypoint_distribution_name(ep: EntryPoint) -> str | None:
|
def entrypoint_distribution_name(ep: EntryPoint) -> str | None:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
from .model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent
|
from .model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterable, Literal, TypeAlias
|
from typing import Literal
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from .model import EngineId, ResumeToken
|
from .model import EngineId, ResumeToken
|
||||||
from .runner import Runner
|
from .runner import Runner
|
||||||
@@ -17,7 +18,7 @@ class RunnerUnavailableError(RuntimeError):
|
|||||||
self.issue = issue
|
self.issue = issue
|
||||||
|
|
||||||
|
|
||||||
EngineStatus: TypeAlias = Literal["ok", "missing_cli", "bad_config", "load_error"]
|
type EngineStatus = Literal["ok", "missing_cli", "bad_config", "load_error"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
|
|||||||
@@ -37,12 +37,9 @@ def _log_runner_event(evt: TakopiEvent) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _strip_resume_lines(text: str, *, is_resume_line: Callable[[str], bool]) -> str:
|
def _strip_resume_lines(text: str, *, is_resume_line: Callable[[str], bool]) -> str:
|
||||||
stripped_lines: list[str] = []
|
prompt = "\n".join(
|
||||||
for line in text.splitlines():
|
line for line in text.splitlines() if not is_resume_line(line)
|
||||||
if is_resume_line(line):
|
).strip()
|
||||||
continue
|
|
||||||
stripped_lines.append(line)
|
|
||||||
prompt = "\n".join(stripped_lines).strip()
|
|
||||||
return prompt or "continue"
|
return prompt or "continue"
|
||||||
|
|
||||||
|
|
||||||
@@ -83,14 +80,14 @@ class IncomingMessage:
|
|||||||
thread_id: ThreadId | None = None
|
thread_id: ThreadId | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class ExecBridgeConfig:
|
class ExecBridgeConfig:
|
||||||
transport: Transport
|
transport: Transport
|
||||||
presenter: Presenter
|
presenter: Presenter
|
||||||
final_notify: bool
|
final_notify: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(slots=True)
|
||||||
class RunningTask:
|
class RunningTask:
|
||||||
resume: ResumeToken | None = None
|
resume: ResumeToken | None = None
|
||||||
resume_ready: anyio.Event = field(default_factory=anyio.Event)
|
resume_ready: anyio.Event = field(default_factory=anyio.Event)
|
||||||
|
|||||||
@@ -14,11 +14,11 @@ from ..logging import get_logger
|
|||||||
from ..model import Action, ActionKind, EngineId, ResumeToken, TakopiEvent
|
from ..model import Action, ActionKind, EngineId, ResumeToken, TakopiEvent
|
||||||
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
||||||
from ..schemas import claude as claude_schema
|
from ..schemas import claude as claude_schema
|
||||||
from ..utils.paths import relativize_command, relativize_path
|
from .tool_actions import tool_input_path, tool_kind_and_title
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
ENGINE: EngineId = EngineId("claude")
|
ENGINE: EngineId = "claude"
|
||||||
DEFAULT_ALLOWED_TOOLS = ["Bash", "Read", "Edit", "Write"]
|
DEFAULT_ALLOWED_TOOLS = ["Bash", "Read", "Edit", "Write"]
|
||||||
|
|
||||||
_RESUME_RE = re.compile(
|
_RESUME_RE = re.compile(
|
||||||
@@ -67,55 +67,10 @@ def _coerce_comma_list(value: Any) -> str | None:
|
|||||||
return text or None
|
return text or None
|
||||||
|
|
||||||
|
|
||||||
def _tool_input_path(tool_input: dict[str, Any]) -> str | None:
|
|
||||||
for key in ("file_path", "path"):
|
|
||||||
value = tool_input.get(key)
|
|
||||||
if isinstance(value, str) and value:
|
|
||||||
return value
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _tool_kind_and_title(
|
def _tool_kind_and_title(
|
||||||
name: str, tool_input: dict[str, Any]
|
name: str, tool_input: dict[str, Any]
|
||||||
) -> tuple[ActionKind, str]:
|
) -> tuple[ActionKind, str]:
|
||||||
if name in {"Bash", "Shell", "KillShell"}:
|
return tool_kind_and_title(name, tool_input, path_keys=("file_path", "path"))
|
||||||
command = tool_input.get("command")
|
|
||||||
display = relativize_command(str(command or name))
|
|
||||||
return "command", display
|
|
||||||
if name in {"Edit", "Write", "NotebookEdit", "MultiEdit"}:
|
|
||||||
path = _tool_input_path(tool_input)
|
|
||||||
if path:
|
|
||||||
return "file_change", relativize_path(str(path))
|
|
||||||
return "file_change", str(name)
|
|
||||||
if name == "Read":
|
|
||||||
path = _tool_input_path(tool_input)
|
|
||||||
if path:
|
|
||||||
return "tool", f"read: `{relativize_path(str(path))}`"
|
|
||||||
return "tool", "read"
|
|
||||||
if name == "Glob":
|
|
||||||
pattern = tool_input.get("pattern")
|
|
||||||
if pattern:
|
|
||||||
return "tool", f"glob: `{pattern}`"
|
|
||||||
return "tool", "glob"
|
|
||||||
if name == "Grep":
|
|
||||||
pattern = tool_input.get("pattern")
|
|
||||||
if pattern:
|
|
||||||
return "tool", f"grep: {pattern}"
|
|
||||||
return "tool", "grep"
|
|
||||||
if name == "WebSearch":
|
|
||||||
query = tool_input.get("query")
|
|
||||||
return "web_search", str(query or "search")
|
|
||||||
if name == "WebFetch":
|
|
||||||
url = tool_input.get("url")
|
|
||||||
return "web_search", str(url or "fetch")
|
|
||||||
if name in {"TodoWrite", "TodoRead"}:
|
|
||||||
return "note", "update todos" if name == "TodoWrite" else "read todos"
|
|
||||||
if name == "AskUserQuestion":
|
|
||||||
return "note", "ask user"
|
|
||||||
if name in {"Task", "Agent"}:
|
|
||||||
desc = tool_input.get("description") or tool_input.get("prompt")
|
|
||||||
return "subagent", str(desc or name)
|
|
||||||
return "tool", name
|
|
||||||
|
|
||||||
|
|
||||||
def _tool_action(
|
def _tool_action(
|
||||||
@@ -137,7 +92,7 @@ def _tool_action(
|
|||||||
detail["parent_tool_use_id"] = parent_tool_use_id
|
detail["parent_tool_use_id"] = parent_tool_use_id
|
||||||
|
|
||||||
if kind == "file_change":
|
if kind == "file_change":
|
||||||
path = _tool_input_path(tool_input)
|
path = tool_input_path(tool_input, path_keys=("file_path", "path"))
|
||||||
if path:
|
if path:
|
||||||
detail["changes"] = [{"path": path, "kind": "update"}]
|
detail["changes"] = [{"path": path, "kind": "update"}]
|
||||||
|
|
||||||
@@ -321,7 +276,7 @@ def translate_claude_event(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(slots=True)
|
||||||
class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
class ClaudeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||||
engine: EngineId = ENGINE
|
engine: EngineId = ENGINE
|
||||||
resume_re: re.Pattern[str] = _RESUME_RE
|
resume_re: re.Pattern[str] = _RESUME_RE
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from ..utils.paths import relativize_command
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
ENGINE: EngineId = EngineId("codex")
|
ENGINE: EngineId = "codex"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ENGINE",
|
"ENGINE",
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass, replace
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
@@ -18,7 +17,7 @@ from ..model import (
|
|||||||
)
|
)
|
||||||
from ..runner import ResumeTokenMixin, Runner, SessionLockMixin
|
from ..runner import ResumeTokenMixin, Runner, SessionLockMixin
|
||||||
|
|
||||||
ENGINE: EngineId = EngineId("mock")
|
ENGINE: EngineId = "mock"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
@@ -52,7 +51,7 @@ class Raise:
|
|||||||
error: Exception
|
error: Exception
|
||||||
|
|
||||||
|
|
||||||
ScriptStep: TypeAlias = Emit | Advance | Sleep | Wait | Return | Raise
|
type ScriptStep = Emit | Advance | Sleep | Wait | Return | Raise
|
||||||
|
|
||||||
|
|
||||||
def _resume_token(engine: EngineId, value: str | None) -> ResumeToken:
|
def _resume_token(engine: EngineId, value: str | None) -> ResumeToken:
|
||||||
@@ -108,9 +107,9 @@ class MockRunner(SessionLockMixin, ResumeTokenMixin, Runner):
|
|||||||
if (
|
if (
|
||||||
isinstance(event_out, ActionEvent)
|
isinstance(event_out, ActionEvent)
|
||||||
and event_out.phase == "completed"
|
and event_out.phase == "completed"
|
||||||
|
and event_out.ok is None
|
||||||
):
|
):
|
||||||
if event_out.ok is None:
|
event_out = replace(event_out, ok=True)
|
||||||
event_out = replace(event_out, ok=True)
|
|
||||||
yield event_out
|
yield event_out
|
||||||
await anyio.sleep(0)
|
await anyio.sleep(0)
|
||||||
|
|
||||||
@@ -187,9 +186,9 @@ class ScriptRunner(MockRunner):
|
|||||||
if (
|
if (
|
||||||
isinstance(event_out, ActionEvent)
|
isinstance(event_out, ActionEvent)
|
||||||
and event_out.phase == "completed"
|
and event_out.phase == "completed"
|
||||||
|
and event_out.ok is None
|
||||||
):
|
):
|
||||||
if event_out.ok is None:
|
event_out = replace(event_out, ok=True)
|
||||||
event_out = replace(event_out, ok=True)
|
|
||||||
yield event_out
|
yield event_out
|
||||||
await anyio.sleep(0)
|
await anyio.sleep(0)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -35,11 +35,12 @@ from ..model import (
|
|||||||
)
|
)
|
||||||
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
||||||
from ..schemas import opencode as opencode_schema
|
from ..schemas import opencode as opencode_schema
|
||||||
from ..utils.paths import relativize_command, relativize_path
|
from ..utils.paths import relativize_path
|
||||||
|
from .tool_actions import tool_input_path, tool_kind_and_title
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
ENGINE: EngineId = EngineId("opencode")
|
ENGINE: EngineId = "opencode"
|
||||||
|
|
||||||
_RESUME_RE = re.compile(
|
_RESUME_RE = re.compile(
|
||||||
r"(?im)^\s*`?opencode(?:\s+run)?\s+(?:--session|-s)\s+(?P<token>ses_[A-Za-z0-9]+)`?\s*$"
|
r"(?im)^\s*`?opencode(?:\s+run)?\s+(?:--session|-s)\s+(?P<token>ses_[A-Za-z0-9]+)`?\s*$"
|
||||||
@@ -79,54 +80,12 @@ def _action_event(
|
|||||||
def _tool_kind_and_title(
|
def _tool_kind_and_title(
|
||||||
tool_name: str, tool_input: dict[str, Any]
|
tool_name: str, tool_input: dict[str, Any]
|
||||||
) -> tuple[ActionKind, str]:
|
) -> tuple[ActionKind, str]:
|
||||||
"""Map OpenCode tool names to Takopi action kinds and titles."""
|
return tool_kind_and_title(
|
||||||
name_lower = tool_name.lower()
|
tool_name,
|
||||||
|
tool_input,
|
||||||
if name_lower in {"bash", "shell"}:
|
path_keys=("file_path", "filePath"),
|
||||||
command = tool_input.get("command")
|
task_kind="tool",
|
||||||
display = relativize_command(str(command or tool_name))
|
)
|
||||||
return "command", display
|
|
||||||
|
|
||||||
if name_lower in {"edit", "write", "multiedit"}:
|
|
||||||
path = tool_input.get("file_path") or tool_input.get("filePath")
|
|
||||||
if path:
|
|
||||||
return "file_change", relativize_path(str(path))
|
|
||||||
return "file_change", str(tool_name)
|
|
||||||
|
|
||||||
if name_lower == "read":
|
|
||||||
path = tool_input.get("file_path") or tool_input.get("filePath")
|
|
||||||
if path:
|
|
||||||
return "tool", f"read: `{relativize_path(str(path))}`"
|
|
||||||
return "tool", "read"
|
|
||||||
|
|
||||||
if name_lower == "glob":
|
|
||||||
pattern = tool_input.get("pattern")
|
|
||||||
if pattern:
|
|
||||||
return "tool", f"glob: `{pattern}`"
|
|
||||||
return "tool", "glob"
|
|
||||||
|
|
||||||
if name_lower == "grep":
|
|
||||||
pattern = tool_input.get("pattern")
|
|
||||||
if pattern:
|
|
||||||
return "tool", f"grep: {pattern}"
|
|
||||||
return "tool", "grep"
|
|
||||||
|
|
||||||
if name_lower in {"websearch", "web_search"}:
|
|
||||||
query = tool_input.get("query")
|
|
||||||
return "web_search", str(query or "search")
|
|
||||||
|
|
||||||
if name_lower in {"webfetch", "web_fetch"}:
|
|
||||||
url = tool_input.get("url")
|
|
||||||
return "web_search", str(url or "fetch")
|
|
||||||
|
|
||||||
if name_lower in {"todowrite", "todoread"}:
|
|
||||||
return "note", "update todos" if "write" in name_lower else "read todos"
|
|
||||||
|
|
||||||
if name_lower == "task":
|
|
||||||
desc = tool_input.get("description") or tool_input.get("prompt")
|
|
||||||
return "tool", str(desc or tool_name)
|
|
||||||
|
|
||||||
return "tool", tool_name
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_tool_title(
|
def _normalize_tool_title(
|
||||||
@@ -137,10 +96,10 @@ def _normalize_tool_title(
|
|||||||
if "`" in title:
|
if "`" in title:
|
||||||
return title
|
return title
|
||||||
|
|
||||||
path = tool_input.get("file_path") or tool_input.get("filePath")
|
path = tool_input_path(tool_input, path_keys=("file_path", "filePath"))
|
||||||
if isinstance(path, str) and path:
|
if isinstance(path, str) and path:
|
||||||
rel_path = relativize_path(path)
|
rel_path = relativize_path(path)
|
||||||
if title == path or title == rel_path:
|
if title in (path, rel_path):
|
||||||
return f"`{rel_path}`"
|
return f"`{rel_path}`"
|
||||||
|
|
||||||
return title
|
return title
|
||||||
@@ -190,9 +149,8 @@ def translate_opencode_event(
|
|||||||
"""Translate an OpenCode JSON event into Takopi events."""
|
"""Translate an OpenCode JSON event into Takopi events."""
|
||||||
session_id = event.sessionID
|
session_id = event.sessionID
|
||||||
|
|
||||||
if isinstance(session_id, str) and session_id:
|
if isinstance(session_id, str) and session_id and state.session_id is None:
|
||||||
if state.session_id is None:
|
state.session_id = session_id
|
||||||
state.session_id = session_id
|
|
||||||
|
|
||||||
match event:
|
match event:
|
||||||
case opencode_schema.StepStart():
|
case opencode_schema.StepStart():
|
||||||
@@ -340,7 +298,7 @@ def translate_opencode_event(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(slots=True)
|
||||||
class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
class OpenCodeRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
||||||
"""Runner for OpenCode CLI."""
|
"""Runner for OpenCode CLI."""
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, UTC
|
||||||
from pathlib import Path, PurePath
|
from pathlib import Path, PurePath
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
@@ -27,11 +27,12 @@ from ..model import (
|
|||||||
)
|
)
|
||||||
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
from ..runner import JsonlSubprocessRunner, ResumeTokenMixin, Runner
|
||||||
from ..schemas import pi as pi_schema
|
from ..schemas import pi as pi_schema
|
||||||
from ..utils.paths import get_run_base_dir, relativize_command, relativize_path
|
from ..utils.paths import get_run_base_dir
|
||||||
|
from .tool_actions import tool_kind_and_title
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
ENGINE: EngineId = EngineId("pi")
|
ENGINE: EngineId = "pi"
|
||||||
|
|
||||||
_RESUME_RE = re.compile(r"(?im)^\s*`?pi\s+--session\s+(?P<token>.+?)`?\s*$")
|
_RESUME_RE = re.compile(r"(?im)^\s*`?pi\s+--session\s+(?P<token>.+?)`?\s*$")
|
||||||
|
|
||||||
@@ -96,32 +97,7 @@ def _tool_kind_and_title(
|
|||||||
name: str,
|
name: str,
|
||||||
args: dict[str, Any],
|
args: dict[str, Any],
|
||||||
) -> tuple[ActionKind, str]:
|
) -> tuple[ActionKind, str]:
|
||||||
tool = name.lower()
|
return tool_kind_and_title(name, args, path_keys=("path",))
|
||||||
if tool == "bash":
|
|
||||||
command = args.get("command")
|
|
||||||
return "command", relativize_command(str(command or "bash"))
|
|
||||||
if tool in {"edit", "write"}:
|
|
||||||
path = args.get("path")
|
|
||||||
if path:
|
|
||||||
return "file_change", relativize_path(str(path))
|
|
||||||
return "file_change", tool
|
|
||||||
if tool == "read":
|
|
||||||
path = args.get("path")
|
|
||||||
if path:
|
|
||||||
return "tool", f"read: `{relativize_path(str(path))}`"
|
|
||||||
return "tool", "read"
|
|
||||||
if tool == "grep":
|
|
||||||
pattern = args.get("pattern")
|
|
||||||
return "tool", f"grep: {pattern}" if pattern else "grep"
|
|
||||||
if tool == "find":
|
|
||||||
pattern = args.get("pattern")
|
|
||||||
return "tool", f"find: {pattern}" if pattern else "find"
|
|
||||||
if tool == "ls":
|
|
||||||
path = args.get("path")
|
|
||||||
if path:
|
|
||||||
return "tool", f"ls: `{relativize_path(str(path))}`"
|
|
||||||
return "tool", "ls"
|
|
||||||
return "tool", name
|
|
||||||
|
|
||||||
|
|
||||||
def _last_assistant_message(messages: Any) -> dict[str, Any] | None:
|
def _last_assistant_message(messages: Any) -> dict[str, Any] | None:
|
||||||
@@ -418,7 +394,7 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
|||||||
cwd = get_run_base_dir() or Path.cwd()
|
cwd = get_run_base_dir() or Path.cwd()
|
||||||
session_dir = _default_session_dir(cwd)
|
session_dir = _default_session_dir(cwd)
|
||||||
session_dir.mkdir(parents=True, exist_ok=True)
|
session_dir.mkdir(parents=True, exist_ok=True)
|
||||||
timestamp = datetime.now(timezone.utc).isoformat()
|
timestamp = datetime.now(UTC).isoformat()
|
||||||
safe_timestamp = timestamp.replace(":", "-").replace(".", "-")
|
safe_timestamp = timestamp.replace(":", "-").replace(".", "-")
|
||||||
token = uuid4().hex
|
token = uuid4().hex
|
||||||
filename = f"{safe_timestamp}_{token}.jsonl"
|
filename = f"{safe_timestamp}_{token}.jsonl"
|
||||||
@@ -442,7 +418,9 @@ class PiRunner(ResumeTokenMixin, JsonlSubprocessRunner):
|
|||||||
def _default_session_dir(cwd: PurePath) -> Path:
|
def _default_session_dir(cwd: PurePath) -> Path:
|
||||||
agent_dir = os.environ.get("PI_CODING_AGENT_DIR")
|
agent_dir = os.environ.get("PI_CODING_AGENT_DIR")
|
||||||
base = Path(agent_dir).expanduser() if agent_dir else Path.home() / ".pi" / "agent"
|
base = Path(agent_dir).expanduser() if agent_dir else Path.home() / ".pi" / "agent"
|
||||||
safe_path = f"--{str(cwd).lstrip('/\\\\').replace('/', '-').replace('\\', '-').replace(':', '-')}--"
|
cwd_str = str(cwd).lstrip("/\\")
|
||||||
|
safe_path_part = cwd_str.translate(str.maketrans({"/": "-", "\\": "-", ":": "-"}))
|
||||||
|
safe_path = f"--{safe_path_part}--"
|
||||||
return base / "sessions" / safe_path
|
return base / "sessions" / safe_path
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,90 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..model import ActionKind
|
||||||
|
from ..utils.paths import relativize_command, relativize_path
|
||||||
|
|
||||||
|
|
||||||
|
def tool_input_path(
|
||||||
|
tool_input: Mapping[str, Any],
|
||||||
|
*,
|
||||||
|
path_keys: Sequence[str],
|
||||||
|
) -> str | None:
|
||||||
|
for key in path_keys:
|
||||||
|
value = tool_input.get(key)
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def tool_kind_and_title(
|
||||||
|
tool_name: str,
|
||||||
|
tool_input: Mapping[str, Any],
|
||||||
|
*,
|
||||||
|
path_keys: Sequence[str],
|
||||||
|
task_kind: ActionKind = "subagent",
|
||||||
|
) -> tuple[ActionKind, str]:
|
||||||
|
name_lower = tool_name.lower()
|
||||||
|
|
||||||
|
if name_lower in {"bash", "shell", "killshell"}:
|
||||||
|
command = tool_input.get("command")
|
||||||
|
display = relativize_command(str(command or tool_name))
|
||||||
|
return "command", display
|
||||||
|
|
||||||
|
if name_lower in {"edit", "write", "notebookedit", "multiedit"}:
|
||||||
|
path = tool_input_path(tool_input, path_keys=path_keys)
|
||||||
|
if path:
|
||||||
|
return "file_change", relativize_path(str(path))
|
||||||
|
return "file_change", str(tool_name)
|
||||||
|
|
||||||
|
if name_lower == "read":
|
||||||
|
path = tool_input_path(tool_input, path_keys=path_keys)
|
||||||
|
if path:
|
||||||
|
return "tool", f"read: `{relativize_path(str(path))}`"
|
||||||
|
return "tool", "read"
|
||||||
|
|
||||||
|
if name_lower == "glob":
|
||||||
|
pattern = tool_input.get("pattern")
|
||||||
|
if pattern:
|
||||||
|
return "tool", f"glob: `{pattern}`"
|
||||||
|
return "tool", "glob"
|
||||||
|
|
||||||
|
if name_lower == "grep":
|
||||||
|
pattern = tool_input.get("pattern")
|
||||||
|
if pattern:
|
||||||
|
return "tool", f"grep: {pattern}"
|
||||||
|
return "tool", "grep"
|
||||||
|
|
||||||
|
if name_lower == "find":
|
||||||
|
pattern = tool_input.get("pattern")
|
||||||
|
if pattern:
|
||||||
|
return "tool", f"find: {pattern}"
|
||||||
|
return "tool", "find"
|
||||||
|
|
||||||
|
if name_lower == "ls":
|
||||||
|
path = tool_input_path(tool_input, path_keys=path_keys)
|
||||||
|
if path:
|
||||||
|
return "tool", f"ls: `{relativize_path(str(path))}`"
|
||||||
|
return "tool", "ls"
|
||||||
|
|
||||||
|
if name_lower in {"websearch", "web_search"}:
|
||||||
|
query = tool_input.get("query")
|
||||||
|
return "web_search", str(query or "search")
|
||||||
|
|
||||||
|
if name_lower in {"webfetch", "web_fetch"}:
|
||||||
|
url = tool_input.get("url")
|
||||||
|
return "web_search", str(url or "fetch")
|
||||||
|
|
||||||
|
if name_lower in {"todowrite", "todoread"}:
|
||||||
|
return "note", "update todos" if "write" in name_lower else "read todos"
|
||||||
|
|
||||||
|
if name_lower == "askuserquestion":
|
||||||
|
return "note", "ask user"
|
||||||
|
|
||||||
|
if name_lower in {"task", "agent"}:
|
||||||
|
desc = tool_input.get("description") or tool_input.get("prompt")
|
||||||
|
return task_kind, str(desc or tool_name)
|
||||||
|
|
||||||
|
return "tool", tool_name
|
||||||
@@ -3,7 +3,8 @@ from __future__ import annotations
|
|||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Iterable, Mapping
|
from typing import Any
|
||||||
|
from collections.abc import Iterable, Mapping
|
||||||
|
|
||||||
from .backends import EngineBackend
|
from .backends import EngineBackend
|
||||||
from .config import ConfigError, ProjectsConfig
|
from .config import ConfigError, ProjectsConfig
|
||||||
|
|||||||
+17
-2
@@ -2,14 +2,18 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Awaitable, Callable, Protocol
|
from typing import Any, Protocol
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
from .context import RunContext
|
from .context import RunContext
|
||||||
|
from .logging import get_logger
|
||||||
from .model import ResumeToken
|
from .model import ResumeToken
|
||||||
from .transport import ChannelId, MessageId, ThreadId
|
from .transport import ChannelId, MessageId, ThreadId
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class ThreadJob:
|
class ThreadJob:
|
||||||
@@ -108,7 +112,18 @@ class ThreadScheduler:
|
|||||||
if done is not None and not done.is_set():
|
if done is not None and not done.is_set():
|
||||||
await done.wait()
|
await done.wait()
|
||||||
|
|
||||||
await self._run_job(job)
|
try:
|
||||||
|
await self._run_job(job)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.exception(
|
||||||
|
"scheduler.job_failed",
|
||||||
|
key=key,
|
||||||
|
tag=job.resume_token.engine,
|
||||||
|
chat_id=job.chat_id,
|
||||||
|
user_msg_id=job.user_msg_id,
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._active_threads.discard(key)
|
self._active_threads.discard(key)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Literal, TypeAlias
|
from typing import Any, Literal
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ class StreamToolResultBlock(
|
|||||||
is_error: bool | None = None
|
is_error: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
StreamContentBlock: TypeAlias = (
|
type StreamContentBlock = (
|
||||||
StreamTextBlock | StreamThinkingBlock | StreamToolUseBlock | StreamToolResultBlock
|
StreamTextBlock | StreamThinkingBlock | StreamToolUseBlock | StreamToolResultBlock
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -164,7 +164,7 @@ class ControlRewindFilesRequest(
|
|||||||
user_message_id: str
|
user_message_id: str
|
||||||
|
|
||||||
|
|
||||||
ControlRequest: TypeAlias = (
|
type ControlRequest = (
|
||||||
ControlInterruptRequest
|
ControlInterruptRequest
|
||||||
| ControlCanUseToolRequest
|
| ControlCanUseToolRequest
|
||||||
| ControlInitializeRequest
|
| ControlInitializeRequest
|
||||||
@@ -196,7 +196,7 @@ class ControlErrorResponse(
|
|||||||
error: str
|
error: str
|
||||||
|
|
||||||
|
|
||||||
ControlResponse: TypeAlias = ControlSuccessResponse | ControlErrorResponse
|
type ControlResponse = ControlSuccessResponse | ControlErrorResponse
|
||||||
|
|
||||||
|
|
||||||
class StreamControlResponse(
|
class StreamControlResponse(
|
||||||
@@ -217,7 +217,7 @@ class StreamControlCancelRequest(
|
|||||||
request_id: str | None = None
|
request_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
StreamJsonMessage: TypeAlias = (
|
type StreamJsonMessage = (
|
||||||
StreamUserMessage
|
StreamUserMessage
|
||||||
| StreamAssistantMessage
|
| StreamAssistantMessage
|
||||||
| StreamSystemMessage
|
| StreamSystemMessage
|
||||||
|
|||||||
@@ -2,27 +2,27 @@ from __future__ import annotations
|
|||||||
|
|
||||||
# Headless JSONL schema derived from tag rust-v0.77.0 (git 112f40e91c12af0f7146d7e03f20283516a8af0b).
|
# Headless JSONL schema derived from tag rust-v0.77.0 (git 112f40e91c12af0f7146d7e03f20283516a8af0b).
|
||||||
|
|
||||||
from typing import Any, Literal, TypeAlias
|
from typing import Any, Literal
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
|
|
||||||
CommandExecutionStatus: TypeAlias = Literal[
|
type CommandExecutionStatus = Literal[
|
||||||
"in_progress",
|
"in_progress",
|
||||||
"completed",
|
"completed",
|
||||||
"failed",
|
"failed",
|
||||||
"declined",
|
"declined",
|
||||||
]
|
]
|
||||||
PatchApplyStatus: TypeAlias = Literal[
|
type PatchApplyStatus = Literal[
|
||||||
"in_progress",
|
"in_progress",
|
||||||
"completed",
|
"completed",
|
||||||
"failed",
|
"failed",
|
||||||
]
|
]
|
||||||
PatchChangeKind: TypeAlias = Literal[
|
type PatchChangeKind = Literal[
|
||||||
"add",
|
"add",
|
||||||
"delete",
|
"delete",
|
||||||
"update",
|
"update",
|
||||||
]
|
]
|
||||||
McpToolCallStatus: TypeAlias = Literal[
|
type McpToolCallStatus = Literal[
|
||||||
"in_progress",
|
"in_progress",
|
||||||
"completed",
|
"completed",
|
||||||
"failed",
|
"failed",
|
||||||
@@ -127,7 +127,7 @@ class TodoListItem(msgspec.Struct, tag="todo_list", kw_only=True):
|
|||||||
items: list[TodoItem]
|
items: list[TodoItem]
|
||||||
|
|
||||||
|
|
||||||
ThreadItem: TypeAlias = (
|
type ThreadItem = (
|
||||||
AgentMessageItem
|
AgentMessageItem
|
||||||
| ReasoningItem
|
| ReasoningItem
|
||||||
| CommandExecutionItem
|
| CommandExecutionItem
|
||||||
@@ -151,7 +151,7 @@ class ItemCompleted(msgspec.Struct, tag="item.completed", kw_only=True):
|
|||||||
item: ThreadItem
|
item: ThreadItem
|
||||||
|
|
||||||
|
|
||||||
ThreadEvent: TypeAlias = (
|
type ThreadEvent = (
|
||||||
ThreadStarted
|
ThreadStarted
|
||||||
| TurnStarted
|
| TurnStarted
|
||||||
| TurnCompleted
|
| TurnCompleted
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, TypeAlias
|
from typing import Any
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ class Error(_Event, tag="error"):
|
|||||||
message: Any = None
|
message: Any = None
|
||||||
|
|
||||||
|
|
||||||
OpenCodeEvent: TypeAlias = StepStart | StepFinish | ToolUse | Text | Error
|
type OpenCodeEvent = StepStart | StepFinish | ToolUse | Text | Error
|
||||||
|
|
||||||
_DECODER = msgspec.json.Decoder(OpenCodeEvent)
|
_DECODER = msgspec.json.Decoder(OpenCodeEvent)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, TypeAlias
|
from typing import Any
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
|
|
||||||
@@ -84,7 +84,7 @@ class AutoRetryEnd(_Event, tag="auto_retry_end"):
|
|||||||
finalError: str | None = None
|
finalError: str | None = None
|
||||||
|
|
||||||
|
|
||||||
PiEvent: TypeAlias = (
|
type PiEvent = (
|
||||||
AgentStart
|
AgentStart
|
||||||
| AgentEnd
|
| AgentEnd
|
||||||
| MessageStart
|
| MessageStart
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, ClassVar, Iterable, Literal
|
from typing import Annotated, Any, ClassVar, Literal
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
|
|||||||
@@ -50,9 +50,7 @@ def _build_startup_message(
|
|||||||
notes.append(f"failed to load: {', '.join(failed_engines)}")
|
notes.append(f"failed to load: {', '.join(failed_engines)}")
|
||||||
if notes:
|
if notes:
|
||||||
engine_list = f"{engine_list} ({'; '.join(notes)})"
|
engine_list = f"{engine_list} ({'; '.join(notes)})"
|
||||||
project_aliases = sorted(
|
project_aliases = sorted(set(runtime.project_aliases()), key=str.lower)
|
||||||
{alias for alias in runtime.project_aliases()}, key=str.lower
|
|
||||||
)
|
|
||||||
project_list = ", ".join(project_aliases) if project_aliases else "none"
|
project_list = ", ".join(project_aliases) if project_aliases else "none"
|
||||||
return (
|
return (
|
||||||
f"\N{OCTOPUS} **takopi is ready**\n\n"
|
f"\N{OCTOPUS} **takopi is ready**\n\n"
|
||||||
@@ -78,8 +76,7 @@ class TelegramBackend(TransportBackend):
|
|||||||
def interactive_setup(self, *, force: bool) -> bool:
|
def interactive_setup(self, *, force: bool) -> bool:
|
||||||
return interactive_setup(force=force)
|
return interactive_setup(force=force)
|
||||||
|
|
||||||
def lock_token(self, *, transport_config: object, config_path: Path) -> str | None:
|
def lock_token(self, *, transport_config: object, _config_path: Path) -> str | None:
|
||||||
_ = config_path
|
|
||||||
settings = _expect_transport_settings(transport_config)
|
settings = _expect_transport_settings(transport_config)
|
||||||
return settings.bot_token
|
return settings.bot_token
|
||||||
|
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ def _is_cancelled_label(label: str) -> bool:
|
|||||||
return stripped.lower() == "cancelled"
|
return stripped.lower() == "cancelled"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class TelegramBridgeConfig:
|
class TelegramBridgeConfig:
|
||||||
bot: BotClient
|
bot: BotClient
|
||||||
runtime: TransportRuntime
|
runtime: TransportRuntime
|
||||||
|
|||||||
+86
-985
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,537 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Protocol, TypeVar
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import msgspec
|
||||||
|
|
||||||
|
from ..logging import get_logger
|
||||||
|
from .api_models import Chat, ChatMember, File, ForumTopic, Message, Update, User
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class RetryAfter(Exception):
|
||||||
|
def __init__(self, retry_after: float, description: str | None = None) -> None:
|
||||||
|
super().__init__(description or f"retry after {retry_after}")
|
||||||
|
self.retry_after = float(retry_after)
|
||||||
|
self.description = description
|
||||||
|
|
||||||
|
|
||||||
|
class TelegramRetryAfter(RetryAfter):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def retry_after_from_payload(payload: dict[str, Any]) -> float | None:
|
||||||
|
params = payload.get("parameters")
|
||||||
|
if isinstance(params, dict):
|
||||||
|
retry_after = params.get("retry_after")
|
||||||
|
if isinstance(retry_after, (int, float)):
|
||||||
|
return float(retry_after)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class BotClient(Protocol):
|
||||||
|
async def close(self) -> None: ...
|
||||||
|
|
||||||
|
async def get_updates(
|
||||||
|
self,
|
||||||
|
offset: int | None,
|
||||||
|
timeout_s: int = 50,
|
||||||
|
allowed_updates: list[str] | None = None,
|
||||||
|
) -> list[Update] | None: ...
|
||||||
|
|
||||||
|
async def get_file(self, file_id: str) -> File | None: ...
|
||||||
|
|
||||||
|
async def download_file(self, file_path: str) -> bytes | None: ...
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
text: str,
|
||||||
|
reply_to_message_id: int | None = None,
|
||||||
|
disable_notification: bool | None = False,
|
||||||
|
message_thread_id: int | None = None,
|
||||||
|
entities: list[dict] | None = None,
|
||||||
|
parse_mode: str | None = None,
|
||||||
|
reply_markup: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
replace_message_id: int | None = None,
|
||||||
|
) -> Message | None: ...
|
||||||
|
|
||||||
|
async def send_document(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
filename: str,
|
||||||
|
content: bytes,
|
||||||
|
reply_to_message_id: int | None = None,
|
||||||
|
message_thread_id: int | None = None,
|
||||||
|
disable_notification: bool | None = False,
|
||||||
|
caption: str | None = None,
|
||||||
|
) -> Message | None: ...
|
||||||
|
|
||||||
|
async def edit_message_text(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
message_id: int,
|
||||||
|
text: str,
|
||||||
|
entities: list[dict] | None = None,
|
||||||
|
parse_mode: str | None = None,
|
||||||
|
reply_markup: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
wait: bool = True,
|
||||||
|
) -> Message | None: ...
|
||||||
|
|
||||||
|
async def delete_message(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
message_id: int,
|
||||||
|
) -> bool: ...
|
||||||
|
|
||||||
|
async def set_my_commands(
|
||||||
|
self,
|
||||||
|
commands: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
scope: dict[str, Any] | None = None,
|
||||||
|
language_code: str | None = None,
|
||||||
|
) -> bool: ...
|
||||||
|
|
||||||
|
async def get_me(self) -> User | None: ...
|
||||||
|
|
||||||
|
async def answer_callback_query(
|
||||||
|
self,
|
||||||
|
callback_query_id: str,
|
||||||
|
text: str | None = None,
|
||||||
|
show_alert: bool | None = None,
|
||||||
|
) -> bool: ...
|
||||||
|
|
||||||
|
async def get_chat(self, chat_id: int) -> Chat | None: ...
|
||||||
|
|
||||||
|
async def get_chat_member(
|
||||||
|
self, chat_id: int, user_id: int
|
||||||
|
) -> ChatMember | None: ...
|
||||||
|
|
||||||
|
async def create_forum_topic(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
name: str,
|
||||||
|
) -> ForumTopic | None: ...
|
||||||
|
|
||||||
|
async def edit_forum_topic(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
message_thread_id: int,
|
||||||
|
name: str,
|
||||||
|
) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class HttpBotClient:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
token: str,
|
||||||
|
*,
|
||||||
|
timeout_s: float = 120,
|
||||||
|
http_client: httpx.AsyncClient | None = None,
|
||||||
|
) -> None:
|
||||||
|
if not token:
|
||||||
|
raise ValueError("Telegram token is empty")
|
||||||
|
self._base = f"https://api.telegram.org/bot{token}"
|
||||||
|
self._file_base = f"https://api.telegram.org/file/bot{token}"
|
||||||
|
self._http_client = http_client or httpx.AsyncClient(timeout=timeout_s)
|
||||||
|
self._owns_http_client = http_client is None
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._owns_http_client:
|
||||||
|
await self._http_client.aclose()
|
||||||
|
|
||||||
|
def _parse_telegram_envelope(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
method: str,
|
||||||
|
resp: httpx.Response,
|
||||||
|
payload: Any,
|
||||||
|
) -> Any | None:
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
logger.error(
|
||||||
|
"telegram.invalid_payload",
|
||||||
|
method=method,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not payload.get("ok"):
|
||||||
|
if payload.get("error_code") == 429:
|
||||||
|
retry_after = retry_after_from_payload(payload)
|
||||||
|
retry_after = 5.0 if retry_after is None else retry_after
|
||||||
|
logger.warning(
|
||||||
|
"telegram.rate_limited",
|
||||||
|
method=method,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
retry_after=retry_after,
|
||||||
|
)
|
||||||
|
raise TelegramRetryAfter(retry_after)
|
||||||
|
logger.error(
|
||||||
|
"telegram.api_error",
|
||||||
|
method=method,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug("telegram.response", method=method, payload=payload)
|
||||||
|
return payload.get("result")
|
||||||
|
|
||||||
|
async def _request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
*,
|
||||||
|
json: dict[str, Any] | None = None,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
files: dict[str, Any] | None = None,
|
||||||
|
) -> Any | None:
|
||||||
|
request_payload = json if json is not None else data
|
||||||
|
logger.debug("telegram.request", method=method, payload=request_payload)
|
||||||
|
try:
|
||||||
|
if json is not None:
|
||||||
|
resp = await self._http_client.post(f"{self._base}/{method}", json=json)
|
||||||
|
else:
|
||||||
|
resp = await self._http_client.post(
|
||||||
|
f"{self._base}/{method}", data=data, files=files
|
||||||
|
)
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
url = getattr(exc.request, "url", None)
|
||||||
|
logger.error(
|
||||||
|
"telegram.network_error",
|
||||||
|
method=method,
|
||||||
|
url=str(url) if url is not None else None,
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if resp.status_code == 429:
|
||||||
|
retry_after: float | None = None
|
||||||
|
try:
|
||||||
|
response_payload = resp.json()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
response_payload = None
|
||||||
|
if isinstance(response_payload, dict):
|
||||||
|
retry_after = retry_after_from_payload(response_payload)
|
||||||
|
retry_after = 5.0 if retry_after is None else retry_after
|
||||||
|
logger.warning(
|
||||||
|
"telegram.rate_limited",
|
||||||
|
method=method,
|
||||||
|
status=resp.status_code,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
retry_after=retry_after,
|
||||||
|
)
|
||||||
|
raise TelegramRetryAfter(retry_after) from exc
|
||||||
|
body = resp.text
|
||||||
|
logger.error(
|
||||||
|
"telegram.http_error",
|
||||||
|
method=method,
|
||||||
|
status=resp.status_code,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
error=str(exc),
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
response_payload = resp.json()
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
body = resp.text
|
||||||
|
logger.error(
|
||||||
|
"telegram.bad_response",
|
||||||
|
method=method,
|
||||||
|
status=resp.status_code,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._parse_telegram_envelope(
|
||||||
|
method=method,
|
||||||
|
resp=resp,
|
||||||
|
payload=response_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _decode_result(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
method: str,
|
||||||
|
payload: Any,
|
||||||
|
model: type[T],
|
||||||
|
) -> T | None:
|
||||||
|
if payload is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return msgspec.convert(payload, type=model)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.error(
|
||||||
|
"telegram.decode_error",
|
||||||
|
method=method,
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _post(self, method: str, json_data: dict[str, Any]) -> Any | None:
|
||||||
|
return await self._request(method, json=json_data)
|
||||||
|
|
||||||
|
async def _post_form(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
data: dict[str, Any],
|
||||||
|
files: dict[str, Any],
|
||||||
|
) -> Any | None:
|
||||||
|
return await self._request(method, data=data, files=files)
|
||||||
|
|
||||||
|
async def get_updates(
|
||||||
|
self,
|
||||||
|
offset: int | None,
|
||||||
|
timeout_s: int = 50,
|
||||||
|
allowed_updates: list[str] | None = None,
|
||||||
|
) -> list[Update] | None:
|
||||||
|
params: dict[str, Any] = {"timeout": timeout_s}
|
||||||
|
if offset is not None:
|
||||||
|
params["offset"] = offset
|
||||||
|
if allowed_updates is not None:
|
||||||
|
params["allowed_updates"] = allowed_updates
|
||||||
|
result = await self._post("getUpdates", params)
|
||||||
|
if result is None or not isinstance(result, list):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return msgspec.convert(result, type=list[Update])
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.error(
|
||||||
|
"telegram.decode_error",
|
||||||
|
method="getUpdates",
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_file(self, file_id: str) -> File | None:
|
||||||
|
result = await self._post("getFile", {"file_id": file_id})
|
||||||
|
return self._decode_result(method="getFile", payload=result, model=File)
|
||||||
|
|
||||||
|
async def download_file(self, file_path: str) -> bytes | None:
|
||||||
|
url = f"{self._file_base}/{file_path}"
|
||||||
|
try:
|
||||||
|
resp = await self._http_client.get(url)
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
request_url = getattr(exc.request, "url", None)
|
||||||
|
logger.error(
|
||||||
|
"telegram.file_network_error",
|
||||||
|
url=str(request_url) if request_url is not None else None,
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
resp.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if resp.status_code == 429:
|
||||||
|
retry_after: float | None = None
|
||||||
|
try:
|
||||||
|
response_payload = resp.json()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
response_payload = None
|
||||||
|
if isinstance(response_payload, dict):
|
||||||
|
retry_after = retry_after_from_payload(response_payload)
|
||||||
|
retry_after = 5.0 if retry_after is None else retry_after
|
||||||
|
logger.warning(
|
||||||
|
"telegram.rate_limited",
|
||||||
|
method="download_file",
|
||||||
|
status=resp.status_code,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
retry_after=retry_after,
|
||||||
|
)
|
||||||
|
raise TelegramRetryAfter(retry_after) from exc
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
"telegram.file_http_error",
|
||||||
|
status=resp.status_code,
|
||||||
|
url=str(resp.request.url),
|
||||||
|
error=str(exc),
|
||||||
|
body=resp.text,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return resp.content
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
text: str,
|
||||||
|
reply_to_message_id: int | None = None,
|
||||||
|
disable_notification: bool | None = False,
|
||||||
|
message_thread_id: int | None = None,
|
||||||
|
entities: list[dict] | None = None,
|
||||||
|
parse_mode: str | None = None,
|
||||||
|
reply_markup: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
replace_message_id: int | None = None,
|
||||||
|
) -> Message | None:
|
||||||
|
params: dict[str, Any] = {"chat_id": chat_id, "text": text}
|
||||||
|
if disable_notification is not None:
|
||||||
|
params["disable_notification"] = disable_notification
|
||||||
|
if reply_to_message_id is not None:
|
||||||
|
params["reply_to_message_id"] = reply_to_message_id
|
||||||
|
if message_thread_id is not None:
|
||||||
|
params["message_thread_id"] = message_thread_id
|
||||||
|
if entities is not None:
|
||||||
|
params["entities"] = entities
|
||||||
|
if parse_mode is not None:
|
||||||
|
params["parse_mode"] = parse_mode
|
||||||
|
if reply_markup is not None:
|
||||||
|
params["reply_markup"] = reply_markup
|
||||||
|
result = await self._post("sendMessage", params)
|
||||||
|
return self._decode_result(method="sendMessage", payload=result, model=Message)
|
||||||
|
|
||||||
|
async def send_document(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
filename: str,
|
||||||
|
content: bytes,
|
||||||
|
reply_to_message_id: int | None = None,
|
||||||
|
message_thread_id: int | None = None,
|
||||||
|
disable_notification: bool | None = False,
|
||||||
|
caption: str | None = None,
|
||||||
|
) -> Message | None:
|
||||||
|
params: dict[str, Any] = {"chat_id": chat_id}
|
||||||
|
if disable_notification is not None:
|
||||||
|
params["disable_notification"] = disable_notification
|
||||||
|
if reply_to_message_id is not None:
|
||||||
|
params["reply_to_message_id"] = reply_to_message_id
|
||||||
|
if message_thread_id is not None:
|
||||||
|
params["message_thread_id"] = message_thread_id
|
||||||
|
if caption is not None:
|
||||||
|
params["caption"] = caption
|
||||||
|
result = await self._post_form(
|
||||||
|
"sendDocument",
|
||||||
|
params,
|
||||||
|
files={"document": (filename, content)},
|
||||||
|
)
|
||||||
|
return self._decode_result(method="sendDocument", payload=result, model=Message)
|
||||||
|
|
||||||
|
async def edit_message_text(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
message_id: int,
|
||||||
|
text: str,
|
||||||
|
entities: list[dict] | None = None,
|
||||||
|
parse_mode: str | None = None,
|
||||||
|
reply_markup: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
wait: bool = True,
|
||||||
|
) -> Message | None:
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
"text": text,
|
||||||
|
}
|
||||||
|
if entities is not None:
|
||||||
|
params["entities"] = entities
|
||||||
|
if parse_mode is not None:
|
||||||
|
params["parse_mode"] = parse_mode
|
||||||
|
if reply_markup is not None:
|
||||||
|
params["reply_markup"] = reply_markup
|
||||||
|
result = await self._post("editMessageText", params)
|
||||||
|
return self._decode_result(
|
||||||
|
method="editMessageText",
|
||||||
|
payload=result,
|
||||||
|
model=Message,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_message(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
message_id: int,
|
||||||
|
) -> bool:
|
||||||
|
result = await self._post(
|
||||||
|
"deleteMessage",
|
||||||
|
{"chat_id": chat_id, "message_id": message_id},
|
||||||
|
)
|
||||||
|
return bool(result)
|
||||||
|
|
||||||
|
async def set_my_commands(
|
||||||
|
self,
|
||||||
|
commands: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
scope: dict[str, Any] | None = None,
|
||||||
|
language_code: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
params: dict[str, Any] = {"commands": commands}
|
||||||
|
if scope is not None:
|
||||||
|
params["scope"] = scope
|
||||||
|
if language_code is not None:
|
||||||
|
params["language_code"] = language_code
|
||||||
|
result = await self._post("setMyCommands", params)
|
||||||
|
return bool(result)
|
||||||
|
|
||||||
|
async def get_me(self) -> User | None:
|
||||||
|
result = await self._post("getMe", {})
|
||||||
|
return self._decode_result(method="getMe", payload=result, model=User)
|
||||||
|
|
||||||
|
async def answer_callback_query(
|
||||||
|
self,
|
||||||
|
callback_query_id: str,
|
||||||
|
text: str | None = None,
|
||||||
|
show_alert: bool | None = None,
|
||||||
|
) -> bool:
|
||||||
|
params: dict[str, Any] = {"callback_query_id": callback_query_id}
|
||||||
|
if text is not None:
|
||||||
|
params["text"] = text
|
||||||
|
if show_alert is not None:
|
||||||
|
params["show_alert"] = show_alert
|
||||||
|
result = await self._post("answerCallbackQuery", params)
|
||||||
|
return bool(result)
|
||||||
|
|
||||||
|
async def get_chat(self, chat_id: int) -> Chat | None:
|
||||||
|
result = await self._post("getChat", {"chat_id": chat_id})
|
||||||
|
return self._decode_result(method="getChat", payload=result, model=Chat)
|
||||||
|
|
||||||
|
async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
|
||||||
|
result = await self._post(
|
||||||
|
"getChatMember", {"chat_id": chat_id, "user_id": user_id}
|
||||||
|
)
|
||||||
|
return self._decode_result(
|
||||||
|
method="getChatMember",
|
||||||
|
payload=result,
|
||||||
|
model=ChatMember,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
|
||||||
|
result = await self._post(
|
||||||
|
"createForumTopic", {"chat_id": chat_id, "name": name}
|
||||||
|
)
|
||||||
|
return self._decode_result(
|
||||||
|
method="createForumTopic",
|
||||||
|
payload=result,
|
||||||
|
model=ForumTopic,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def edit_forum_topic(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
message_thread_id: int,
|
||||||
|
name: str,
|
||||||
|
) -> bool:
|
||||||
|
result = await self._post(
|
||||||
|
"editForumTopic",
|
||||||
|
{
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"message_thread_id": message_thread_id,
|
||||||
|
"name": name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return bool(result)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,12 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .cancel import handle_callback_cancel, handle_cancel
|
||||||
|
from .menu import build_bot_commands
|
||||||
|
from .parse import is_cancel_command
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"build_bot_commands",
|
||||||
|
"handle_callback_cancel",
|
||||||
|
"handle_cancel",
|
||||||
|
"is_cancel_command",
|
||||||
|
]
|
||||||
@@ -0,0 +1,160 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...context import RunContext
|
||||||
|
from ...directives import DirectiveError
|
||||||
|
from ..chat_prefs import ChatPrefsStore
|
||||||
|
from ..engine_defaults import resolve_engine_for_message
|
||||||
|
from ..files import split_command_args
|
||||||
|
from ..topic_state import TopicStateStore
|
||||||
|
from ..topics import _topic_key
|
||||||
|
from ..types import TelegramIncomingMessage
|
||||||
|
from .reply import make_reply
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..bridge import TelegramBridgeConfig
|
||||||
|
|
||||||
|
AGENT_USAGE = "usage: `/agent`, `/agent set <engine>`, or `/agent clear`"
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_agent_permissions(
|
||||||
|
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||||
|
) -> bool:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
sender_id = msg.sender_id
|
||||||
|
if sender_id is None:
|
||||||
|
await reply(text="cannot verify sender for agent defaults.")
|
||||||
|
return False
|
||||||
|
is_private = msg.chat_type == "private"
|
||||||
|
if msg.chat_type is None:
|
||||||
|
is_private = msg.chat_id > 0
|
||||||
|
if is_private:
|
||||||
|
return True
|
||||||
|
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
|
||||||
|
if member is None:
|
||||||
|
await reply(text="failed to verify agent permissions.")
|
||||||
|
return False
|
||||||
|
if member.status in {"creator", "administrator"}:
|
||||||
|
return True
|
||||||
|
await reply(text="changing default agents is restricted to group admins.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_agent_command(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
args_text: str,
|
||||||
|
ambient_context: RunContext | None,
|
||||||
|
topic_store: TopicStateStore | None,
|
||||||
|
chat_prefs: ChatPrefsStore | None,
|
||||||
|
*,
|
||||||
|
resolved_scope: str | None = None,
|
||||||
|
scope_chat_ids: frozenset[int] | None = None,
|
||||||
|
) -> None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
tkey = (
|
||||||
|
_topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
|
||||||
|
if topic_store is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
tokens = split_command_args(args_text)
|
||||||
|
action = tokens[0].lower() if tokens else "show"
|
||||||
|
|
||||||
|
if action in {"show", ""}:
|
||||||
|
try:
|
||||||
|
resolved = cfg.runtime.resolve_message(
|
||||||
|
text="",
|
||||||
|
reply_text=msg.reply_to_text,
|
||||||
|
ambient_context=ambient_context,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
)
|
||||||
|
except DirectiveError as exc:
|
||||||
|
await reply(text=f"error:\n{exc}")
|
||||||
|
return
|
||||||
|
selection = await resolve_engine_for_message(
|
||||||
|
runtime=cfg.runtime,
|
||||||
|
context=resolved.context,
|
||||||
|
explicit_engine=None,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
topic_key=tkey,
|
||||||
|
topic_store=topic_store,
|
||||||
|
chat_prefs=chat_prefs,
|
||||||
|
)
|
||||||
|
source_labels = {
|
||||||
|
"directive": "directive",
|
||||||
|
"topic_default": "topic default",
|
||||||
|
"chat_default": "chat default",
|
||||||
|
"project_default": "project default",
|
||||||
|
"global_default": "global default",
|
||||||
|
}
|
||||||
|
agent_line = f"agent: {selection.engine} ({source_labels[selection.source]})"
|
||||||
|
topic_default = selection.topic_default or "none"
|
||||||
|
if tkey is None:
|
||||||
|
topic_default = "none"
|
||||||
|
if chat_prefs is None:
|
||||||
|
chat_default = "unavailable"
|
||||||
|
else:
|
||||||
|
chat_default = selection.chat_default or "none"
|
||||||
|
project_default = (
|
||||||
|
selection.project_default
|
||||||
|
if selection.project_default is not None
|
||||||
|
else "none"
|
||||||
|
)
|
||||||
|
defaults_line = (
|
||||||
|
"defaults: "
|
||||||
|
f"topic: {topic_default}, "
|
||||||
|
f"chat: {chat_default}, "
|
||||||
|
f"project: {project_default}, "
|
||||||
|
f"global: {cfg.runtime.default_engine}"
|
||||||
|
)
|
||||||
|
available = ", ".join(cfg.runtime.engine_ids)
|
||||||
|
available_line = f"available: {available}"
|
||||||
|
await reply(text="\n\n".join([agent_line, defaults_line, available_line]))
|
||||||
|
return
|
||||||
|
|
||||||
|
if action == "set":
|
||||||
|
if len(tokens) < 2:
|
||||||
|
await reply(text=AGENT_USAGE)
|
||||||
|
return
|
||||||
|
if not await _check_agent_permissions(cfg, msg):
|
||||||
|
return
|
||||||
|
engine = tokens[1].strip().lower()
|
||||||
|
if engine not in cfg.runtime.engine_ids:
|
||||||
|
available = ", ".join(cfg.runtime.engine_ids)
|
||||||
|
await reply(
|
||||||
|
text=f"unknown engine `{engine}`.\navailable agents: `{available}`",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if tkey is not None:
|
||||||
|
if topic_store is None:
|
||||||
|
await reply(text="topic defaults are unavailable.")
|
||||||
|
return
|
||||||
|
await topic_store.set_default_engine(tkey[0], tkey[1], engine)
|
||||||
|
await reply(text=f"topic default agent set to `{engine}`")
|
||||||
|
return
|
||||||
|
if chat_prefs is None:
|
||||||
|
await reply(text="chat defaults are unavailable (no config path).")
|
||||||
|
return
|
||||||
|
await chat_prefs.set_default_engine(msg.chat_id, engine)
|
||||||
|
await reply(text=f"chat default agent set to `{engine}`")
|
||||||
|
return
|
||||||
|
|
||||||
|
if action == "clear":
|
||||||
|
if not await _check_agent_permissions(cfg, msg):
|
||||||
|
return
|
||||||
|
if tkey is not None:
|
||||||
|
if topic_store is None:
|
||||||
|
await reply(text="topic defaults are unavailable.")
|
||||||
|
return
|
||||||
|
await topic_store.clear_default_engine(tkey[0], tkey[1])
|
||||||
|
await reply(text="topic default agent cleared.")
|
||||||
|
return
|
||||||
|
if chat_prefs is None:
|
||||||
|
await reply(text="chat defaults are unavailable (no config path).")
|
||||||
|
return
|
||||||
|
await chat_prefs.clear_default_engine(msg.chat_id)
|
||||||
|
await reply(text="chat default agent cleared.")
|
||||||
|
return
|
||||||
|
|
||||||
|
await reply(text=AGENT_USAGE)
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...logging import get_logger
|
||||||
|
from ...runner_bridge import RunningTasks
|
||||||
|
from ...transport import MessageRef
|
||||||
|
from ..types import TelegramCallbackQuery, TelegramIncomingMessage
|
||||||
|
from .reply import make_reply
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..bridge import TelegramBridgeConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_cancel(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
running_tasks: RunningTasks,
|
||||||
|
) -> None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
chat_id = msg.chat_id
|
||||||
|
reply_id = msg.reply_to_message_id
|
||||||
|
|
||||||
|
if reply_id is None:
|
||||||
|
if msg.reply_to_text:
|
||||||
|
await reply(text="nothing is currently running for that message.")
|
||||||
|
return
|
||||||
|
await reply(text="reply to the progress message to cancel.")
|
||||||
|
return
|
||||||
|
|
||||||
|
progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id)
|
||||||
|
running_task = running_tasks.get(progress_ref)
|
||||||
|
if running_task is None:
|
||||||
|
await reply(text="nothing is currently running for that message.")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"cancel.requested",
|
||||||
|
chat_id=chat_id,
|
||||||
|
progress_message_id=reply_id,
|
||||||
|
)
|
||||||
|
running_task.cancel_requested.set()
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_callback_cancel(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
query: TelegramCallbackQuery,
|
||||||
|
running_tasks: RunningTasks,
|
||||||
|
) -> None:
|
||||||
|
progress_ref = MessageRef(channel_id=query.chat_id, message_id=query.message_id)
|
||||||
|
running_task = running_tasks.get(progress_ref)
|
||||||
|
if running_task is None:
|
||||||
|
await cfg.bot.answer_callback_query(
|
||||||
|
callback_query_id=query.callback_query_id,
|
||||||
|
text="nothing is currently running for that message.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.info(
|
||||||
|
"cancel.requested",
|
||||||
|
chat_id=query.chat_id,
|
||||||
|
progress_message_id=query.message_id,
|
||||||
|
)
|
||||||
|
running_task.cancel_requested.set()
|
||||||
|
await cfg.bot.answer_callback_query(
|
||||||
|
callback_query_id=query.callback_query_id,
|
||||||
|
text="cancelling...",
|
||||||
|
)
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from ...commands import CommandContext, get_command
|
||||||
|
from ...config import ConfigError
|
||||||
|
from ...logging import get_logger
|
||||||
|
from ...model import EngineId, ResumeToken
|
||||||
|
from ...runner_bridge import RunningTasks
|
||||||
|
from ...scheduler import ThreadScheduler
|
||||||
|
from ...transport import MessageRef
|
||||||
|
from ..files import split_command_args
|
||||||
|
from ..types import TelegramIncomingMessage
|
||||||
|
from .executor import _TelegramCommandExecutor
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..bridge import TelegramBridgeConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch_command(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
text: str,
|
||||||
|
command_id: str,
|
||||||
|
args_text: str,
|
||||||
|
running_tasks: RunningTasks,
|
||||||
|
scheduler: ThreadScheduler,
|
||||||
|
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
|
||||||
|
stateful_mode: bool,
|
||||||
|
default_engine_override: EngineId | None,
|
||||||
|
) -> None:
|
||||||
|
allowlist = cfg.runtime.allowlist
|
||||||
|
chat_id = msg.chat_id
|
||||||
|
user_msg_id = msg.message_id
|
||||||
|
reply_ref = (
|
||||||
|
MessageRef(
|
||||||
|
channel_id=chat_id,
|
||||||
|
message_id=msg.reply_to_message_id,
|
||||||
|
thread_id=msg.thread_id,
|
||||||
|
)
|
||||||
|
if msg.reply_to_message_id is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
executor = _TelegramCommandExecutor(
|
||||||
|
exec_cfg=cfg.exec_cfg,
|
||||||
|
runtime=cfg.runtime,
|
||||||
|
running_tasks=running_tasks,
|
||||||
|
scheduler=scheduler,
|
||||||
|
on_thread_known=on_thread_known,
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_msg_id=user_msg_id,
|
||||||
|
thread_id=msg.thread_id,
|
||||||
|
show_resume_line=cfg.show_resume_line,
|
||||||
|
stateful_mode=stateful_mode,
|
||||||
|
default_engine_override=default_engine_override,
|
||||||
|
)
|
||||||
|
message_ref = MessageRef(
|
||||||
|
channel_id=chat_id,
|
||||||
|
message_id=user_msg_id,
|
||||||
|
thread_id=msg.thread_id,
|
||||||
|
sender_id=msg.sender_id,
|
||||||
|
raw=msg.raw,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
backend = get_command(command_id, allowlist=allowlist, required=False)
|
||||||
|
except ConfigError as exc:
|
||||||
|
await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True)
|
||||||
|
return
|
||||||
|
if backend is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
plugin_config = cfg.runtime.plugin_config(command_id)
|
||||||
|
except ConfigError as exc:
|
||||||
|
await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True)
|
||||||
|
return
|
||||||
|
ctx = CommandContext(
|
||||||
|
command=command_id,
|
||||||
|
text=text,
|
||||||
|
args_text=args_text,
|
||||||
|
args=split_command_args(args_text),
|
||||||
|
message=message_ref,
|
||||||
|
reply_to=reply_ref,
|
||||||
|
reply_text=msg.reply_to_text,
|
||||||
|
config_path=cfg.runtime.config_path,
|
||||||
|
plugin_config=plugin_config,
|
||||||
|
runtime=cfg.runtime,
|
||||||
|
executor=executor,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = await backend.handle(ctx)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception(
|
||||||
|
"command.failed",
|
||||||
|
command=command_id,
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
await executor.send(f"error:\n{exc}", reply_to=message_ref, notify=True)
|
||||||
|
return
|
||||||
|
if result is not None:
|
||||||
|
reply_to = message_ref if result.reply_to is None else result.reply_to
|
||||||
|
await executor.send(result.text, reply_to=reply_to, notify=result.notify)
|
||||||
@@ -0,0 +1,382 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from ...commands import CommandExecutor, RunMode, RunRequest, RunResult
|
||||||
|
from ...config import ConfigError
|
||||||
|
from ...context import RunContext
|
||||||
|
from ...logging import bind_run_context, clear_context, get_logger
|
||||||
|
from ...model import EngineId, ResumeToken, TakopiEvent
|
||||||
|
from ...progress import ProgressTracker
|
||||||
|
from ...router import RunnerUnavailableError
|
||||||
|
from ...runner import Runner
|
||||||
|
from ...runner_bridge import (
|
||||||
|
ExecBridgeConfig,
|
||||||
|
IncomingMessage as RunnerIncomingMessage,
|
||||||
|
RunningTasks,
|
||||||
|
handle_message,
|
||||||
|
)
|
||||||
|
from ...scheduler import ThreadScheduler
|
||||||
|
from ...transport import MessageRef, RenderedMessage, SendOptions
|
||||||
|
from ...transport_runtime import TransportRuntime
|
||||||
|
from ...utils.paths import reset_run_base_dir, set_run_base_dir
|
||||||
|
from ..bridge import send_plain
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _ResumeLineProxy:
|
||||||
|
runner: Runner
|
||||||
|
|
||||||
|
@property
|
||||||
|
def engine(self) -> str:
|
||||||
|
return self.runner.engine
|
||||||
|
|
||||||
|
def is_resume_line(self, line: str) -> bool:
|
||||||
|
return self.runner.is_resume_line(line)
|
||||||
|
|
||||||
|
def format_resume(self, _: ResumeToken) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def extract_resume(self, text: str | None) -> ResumeToken | None:
|
||||||
|
return self.runner.extract_resume(text)
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self, prompt: str, resume: ResumeToken | None
|
||||||
|
) -> AsyncIterator[TakopiEvent]:
|
||||||
|
return self.runner.run(prompt, resume)
|
||||||
|
|
||||||
|
|
||||||
|
def _should_show_resume_line(
|
||||||
|
*,
|
||||||
|
show_resume_line: bool,
|
||||||
|
stateful_mode: bool,
|
||||||
|
context: RunContext | None,
|
||||||
|
) -> bool:
|
||||||
|
if show_resume_line or not stateful_mode:
|
||||||
|
return True
|
||||||
|
return context is None or context.project is None
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_runner_unavailable(
|
||||||
|
exec_cfg: ExecBridgeConfig,
|
||||||
|
*,
|
||||||
|
chat_id: int,
|
||||||
|
user_msg_id: int,
|
||||||
|
resume_token: ResumeToken | None,
|
||||||
|
runner: Runner,
|
||||||
|
reason: str,
|
||||||
|
thread_id: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
tracker = ProgressTracker(engine=runner.engine)
|
||||||
|
tracker.set_resume(resume_token)
|
||||||
|
state = tracker.snapshot(resume_formatter=runner.format_resume)
|
||||||
|
message = exec_cfg.presenter.render_final(
|
||||||
|
state,
|
||||||
|
elapsed_s=0.0,
|
||||||
|
status="error",
|
||||||
|
answer=f"error:\n{reason}",
|
||||||
|
)
|
||||||
|
reply_to = MessageRef(channel_id=chat_id, message_id=user_msg_id)
|
||||||
|
await exec_cfg.transport.send(
|
||||||
|
channel_id=chat_id,
|
||||||
|
message=message,
|
||||||
|
options=SendOptions(reply_to=reply_to, notify=True, thread_id=thread_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_engine(
|
||||||
|
*,
|
||||||
|
exec_cfg: ExecBridgeConfig,
|
||||||
|
runtime: TransportRuntime,
|
||||||
|
running_tasks: RunningTasks | None,
|
||||||
|
chat_id: int,
|
||||||
|
user_msg_id: int,
|
||||||
|
text: str,
|
||||||
|
resume_token: ResumeToken | None,
|
||||||
|
context: RunContext | None,
|
||||||
|
reply_ref: MessageRef | None = None,
|
||||||
|
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
|
||||||
|
| None = None,
|
||||||
|
engine_override: EngineId | None = None,
|
||||||
|
thread_id: int | None = None,
|
||||||
|
show_resume_line: bool = True,
|
||||||
|
) -> None:
|
||||||
|
reply = partial(
|
||||||
|
send_plain,
|
||||||
|
exec_cfg.transport,
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_msg_id=user_msg_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
entry = runtime.resolve_runner(
|
||||||
|
resume_token=resume_token,
|
||||||
|
engine_override=engine_override,
|
||||||
|
)
|
||||||
|
except RunnerUnavailableError as exc:
|
||||||
|
await reply(text=f"error:\n{exc}")
|
||||||
|
return
|
||||||
|
runner: Runner = entry.runner
|
||||||
|
if not show_resume_line:
|
||||||
|
runner = cast(Runner, _ResumeLineProxy(runner))
|
||||||
|
if not entry.available:
|
||||||
|
reason = entry.issue or "engine unavailable"
|
||||||
|
await _send_runner_unavailable(
|
||||||
|
exec_cfg,
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_msg_id=user_msg_id,
|
||||||
|
resume_token=resume_token,
|
||||||
|
runner=runner,
|
||||||
|
reason=reason,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
cwd = runtime.resolve_run_cwd(context)
|
||||||
|
except ConfigError as exc:
|
||||||
|
await reply(text=f"error:\n{exc}")
|
||||||
|
return
|
||||||
|
run_base_token = set_run_base_dir(cwd)
|
||||||
|
try:
|
||||||
|
run_fields = {
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"user_msg_id": user_msg_id,
|
||||||
|
"engine": runner.engine,
|
||||||
|
"resume": resume_token.value if resume_token else None,
|
||||||
|
}
|
||||||
|
if context is not None:
|
||||||
|
run_fields["project"] = context.project
|
||||||
|
run_fields["branch"] = context.branch
|
||||||
|
if cwd is not None:
|
||||||
|
run_fields["cwd"] = str(cwd)
|
||||||
|
bind_run_context(**run_fields)
|
||||||
|
context_line = runtime.format_context_line(context)
|
||||||
|
incoming = RunnerIncomingMessage(
|
||||||
|
channel_id=chat_id,
|
||||||
|
message_id=user_msg_id,
|
||||||
|
text=text,
|
||||||
|
reply_to=reply_ref,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
await handle_message(
|
||||||
|
exec_cfg,
|
||||||
|
runner=runner,
|
||||||
|
incoming=incoming,
|
||||||
|
resume_token=resume_token,
|
||||||
|
context=context,
|
||||||
|
context_line=context_line,
|
||||||
|
strip_resume_line=runtime.is_resume_line,
|
||||||
|
running_tasks=running_tasks,
|
||||||
|
on_thread_known=on_thread_known,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
reset_run_base_dir(run_base_token)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception(
|
||||||
|
"handle.worker_failed",
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_context()
|
||||||
|
|
||||||
|
|
||||||
|
class _CaptureTransport:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._next_id = 1
|
||||||
|
self.last_message: RenderedMessage | None = None
|
||||||
|
|
||||||
|
async def send(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: int | str,
|
||||||
|
message: RenderedMessage,
|
||||||
|
options: SendOptions | None = None,
|
||||||
|
) -> MessageRef:
|
||||||
|
thread_id = options.thread_id if options is not None else None
|
||||||
|
ref = MessageRef(channel_id=channel_id, message_id=self._next_id)
|
||||||
|
self._next_id += 1
|
||||||
|
self.last_message = message
|
||||||
|
return MessageRef(
|
||||||
|
channel_id=ref.channel_id,
|
||||||
|
message_id=ref.message_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def edit(
|
||||||
|
self, *, ref: MessageRef, message: RenderedMessage, wait: bool = True
|
||||||
|
) -> MessageRef:
|
||||||
|
self.last_message = message
|
||||||
|
return ref
|
||||||
|
|
||||||
|
async def delete(self, *, ref: MessageRef) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _TelegramCommandExecutor(CommandExecutor):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
exec_cfg: ExecBridgeConfig,
|
||||||
|
runtime: TransportRuntime,
|
||||||
|
running_tasks: RunningTasks,
|
||||||
|
scheduler: ThreadScheduler,
|
||||||
|
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
|
||||||
|
chat_id: int,
|
||||||
|
user_msg_id: int,
|
||||||
|
thread_id: int | None,
|
||||||
|
show_resume_line: bool,
|
||||||
|
stateful_mode: bool,
|
||||||
|
default_engine_override: EngineId | None,
|
||||||
|
) -> None:
|
||||||
|
self._exec_cfg = exec_cfg
|
||||||
|
self._runtime = runtime
|
||||||
|
self._running_tasks = running_tasks
|
||||||
|
self._scheduler = scheduler
|
||||||
|
self._on_thread_known = on_thread_known
|
||||||
|
self._chat_id = chat_id
|
||||||
|
self._user_msg_id = user_msg_id
|
||||||
|
self._thread_id = thread_id
|
||||||
|
self._show_resume_line = show_resume_line
|
||||||
|
self._stateful_mode = stateful_mode
|
||||||
|
self._default_engine_override = default_engine_override
|
||||||
|
self._reply_ref = MessageRef(
|
||||||
|
channel_id=chat_id,
|
||||||
|
message_id=user_msg_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_default_context(self, request: RunRequest) -> RunRequest:
|
||||||
|
if request.context is not None:
|
||||||
|
return request
|
||||||
|
context = self._runtime.default_context_for_chat(self._chat_id)
|
||||||
|
if context is None:
|
||||||
|
return request
|
||||||
|
return RunRequest(
|
||||||
|
prompt=request.prompt,
|
||||||
|
engine=request.engine,
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_default_engine(self, request: RunRequest) -> RunRequest:
|
||||||
|
if request.engine is not None or self._default_engine_override is None:
|
||||||
|
return request
|
||||||
|
return RunRequest(
|
||||||
|
prompt=request.prompt,
|
||||||
|
engine=self._default_engine_override,
|
||||||
|
context=request.context,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send(
|
||||||
|
self,
|
||||||
|
message: RenderedMessage | str,
|
||||||
|
*,
|
||||||
|
reply_to: MessageRef | None = None,
|
||||||
|
notify: bool = True,
|
||||||
|
) -> MessageRef | None:
|
||||||
|
rendered = (
|
||||||
|
message
|
||||||
|
if isinstance(message, RenderedMessage)
|
||||||
|
else RenderedMessage(text=message)
|
||||||
|
)
|
||||||
|
reply_ref = self._reply_ref if reply_to is None else reply_to
|
||||||
|
return await self._exec_cfg.transport.send(
|
||||||
|
channel_id=self._chat_id,
|
||||||
|
message=rendered,
|
||||||
|
options=SendOptions(
|
||||||
|
reply_to=reply_ref,
|
||||||
|
notify=notify,
|
||||||
|
thread_id=self._thread_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run_one(
|
||||||
|
self, request: RunRequest, *, mode: RunMode = "emit"
|
||||||
|
) -> RunResult:
|
||||||
|
request = self._apply_default_context(request)
|
||||||
|
request = self._apply_default_engine(request)
|
||||||
|
effective_show_resume_line = _should_show_resume_line(
|
||||||
|
show_resume_line=self._show_resume_line,
|
||||||
|
stateful_mode=self._stateful_mode,
|
||||||
|
context=request.context,
|
||||||
|
)
|
||||||
|
engine = self._runtime.resolve_engine(
|
||||||
|
engine_override=request.engine,
|
||||||
|
context=request.context,
|
||||||
|
)
|
||||||
|
on_thread_known = (
|
||||||
|
self._scheduler.note_thread_known
|
||||||
|
if self._on_thread_known is None
|
||||||
|
else self._on_thread_known
|
||||||
|
)
|
||||||
|
if mode == "capture":
|
||||||
|
capture = _CaptureTransport()
|
||||||
|
exec_cfg = ExecBridgeConfig(
|
||||||
|
transport=capture,
|
||||||
|
presenter=self._exec_cfg.presenter,
|
||||||
|
final_notify=False,
|
||||||
|
)
|
||||||
|
await _run_engine(
|
||||||
|
exec_cfg=exec_cfg,
|
||||||
|
runtime=self._runtime,
|
||||||
|
running_tasks={},
|
||||||
|
chat_id=self._chat_id,
|
||||||
|
user_msg_id=self._user_msg_id,
|
||||||
|
text=request.prompt,
|
||||||
|
resume_token=None,
|
||||||
|
context=request.context,
|
||||||
|
reply_ref=self._reply_ref,
|
||||||
|
on_thread_known=on_thread_known,
|
||||||
|
engine_override=engine,
|
||||||
|
thread_id=self._thread_id,
|
||||||
|
show_resume_line=effective_show_resume_line,
|
||||||
|
)
|
||||||
|
return RunResult(engine=engine, message=capture.last_message)
|
||||||
|
await _run_engine(
|
||||||
|
exec_cfg=self._exec_cfg,
|
||||||
|
runtime=self._runtime,
|
||||||
|
running_tasks=self._running_tasks,
|
||||||
|
chat_id=self._chat_id,
|
||||||
|
user_msg_id=self._user_msg_id,
|
||||||
|
text=request.prompt,
|
||||||
|
resume_token=None,
|
||||||
|
context=request.context,
|
||||||
|
reply_ref=self._reply_ref,
|
||||||
|
on_thread_known=on_thread_known,
|
||||||
|
engine_override=engine,
|
||||||
|
thread_id=self._thread_id,
|
||||||
|
show_resume_line=effective_show_resume_line,
|
||||||
|
)
|
||||||
|
return RunResult(engine=engine, message=None)
|
||||||
|
|
||||||
|
async def run_many(
|
||||||
|
self,
|
||||||
|
requests: Sequence[RunRequest],
|
||||||
|
*,
|
||||||
|
mode: RunMode = "emit",
|
||||||
|
parallel: bool = False,
|
||||||
|
) -> list[RunResult]:
|
||||||
|
if not parallel:
|
||||||
|
return [await self.run_one(request, mode=mode) for request in requests]
|
||||||
|
results: list[RunResult | None] = [None] * len(requests)
|
||||||
|
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
|
||||||
|
async def run_idx(idx: int, request: RunRequest) -> None:
|
||||||
|
results[idx] = await self.run_one(request, mode=mode)
|
||||||
|
|
||||||
|
for idx, request in enumerate(requests):
|
||||||
|
tg.start_soon(run_idx, idx, request)
|
||||||
|
|
||||||
|
return [result for result in results if result is not None]
|
||||||
@@ -0,0 +1,589 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...config import ConfigError
|
||||||
|
from ...context import RunContext
|
||||||
|
from ...directives import DirectiveError
|
||||||
|
from ...transport_runtime import ResolvedMessage
|
||||||
|
from ..context import _format_context
|
||||||
|
from ..files import (
|
||||||
|
default_upload_name,
|
||||||
|
default_upload_path,
|
||||||
|
deny_reason,
|
||||||
|
format_bytes,
|
||||||
|
normalize_relative_path,
|
||||||
|
parse_file_command,
|
||||||
|
parse_file_prompt,
|
||||||
|
resolve_path_within_root,
|
||||||
|
write_bytes_atomic,
|
||||||
|
ZipTooLargeError,
|
||||||
|
zip_directory,
|
||||||
|
)
|
||||||
|
from ..topic_state import TopicStateStore
|
||||||
|
from ..topics import _maybe_update_topic_context, _topic_key
|
||||||
|
from ..types import TelegramDocument, TelegramIncomingMessage
|
||||||
|
from .reply import make_reply
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..bridge import TelegramBridgeConfig
|
||||||
|
|
||||||
|
FILE_PUT_USAGE = "usage: `/file put <path>`"
|
||||||
|
FILE_GET_USAGE = "usage: `/file get <path>`"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _FilePutPlan:
|
||||||
|
resolved: ResolvedMessage
|
||||||
|
run_root: Path
|
||||||
|
path_value: str | None
|
||||||
|
force: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _FilePutResult:
|
||||||
|
name: str
|
||||||
|
rel_path: Path | None
|
||||||
|
size: int | None
|
||||||
|
error: str | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _SavedFilePut:
|
||||||
|
context: RunContext | None
|
||||||
|
rel_path: Path
|
||||||
|
size: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _SavedFilePutGroup:
|
||||||
|
context: RunContext | None
|
||||||
|
base_dir: Path | None
|
||||||
|
saved: list[_FilePutResult]
|
||||||
|
failed: list[_FilePutResult]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_file_put_paths(
|
||||||
|
plan: _FilePutPlan,
|
||||||
|
*,
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
require_dir: bool,
|
||||||
|
) -> tuple[Path | None, Path | None, str | None]:
|
||||||
|
path_value = plan.path_value
|
||||||
|
if not path_value:
|
||||||
|
return None, None, None
|
||||||
|
if require_dir or path_value.endswith("/"):
|
||||||
|
base_dir = normalize_relative_path(path_value)
|
||||||
|
if base_dir is None:
|
||||||
|
return None, None, "invalid upload path."
|
||||||
|
deny_rule = deny_reason(base_dir, cfg.files.deny_globs)
|
||||||
|
if deny_rule is not None:
|
||||||
|
return None, None, f"path denied by rule: {deny_rule}"
|
||||||
|
base_target = resolve_path_within_root(plan.run_root, base_dir)
|
||||||
|
if base_target is None:
|
||||||
|
return None, None, "upload path escapes the repo root."
|
||||||
|
if base_target.exists() and not base_target.is_dir():
|
||||||
|
return None, None, "upload path is a file."
|
||||||
|
return base_dir, None, None
|
||||||
|
rel_path = normalize_relative_path(path_value)
|
||||||
|
if rel_path is None:
|
||||||
|
return None, None, "invalid upload path."
|
||||||
|
return None, rel_path, None
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_file_permissions(
|
||||||
|
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||||
|
) -> bool:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
sender_id = msg.sender_id
|
||||||
|
if sender_id is None:
|
||||||
|
await reply(text="cannot verify sender for file transfer.")
|
||||||
|
return False
|
||||||
|
if cfg.files.allowed_user_ids:
|
||||||
|
if sender_id not in cfg.files.allowed_user_ids:
|
||||||
|
await reply(text="file transfer is not allowed for this user.")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
is_private = msg.chat_type == "private"
|
||||||
|
if msg.chat_type is None:
|
||||||
|
is_private = msg.chat_id > 0
|
||||||
|
if is_private:
|
||||||
|
return True
|
||||||
|
member = await cfg.bot.get_chat_member(msg.chat_id, sender_id)
|
||||||
|
if member is None:
|
||||||
|
await reply(text="failed to verify file transfer permissions.")
|
||||||
|
return False
|
||||||
|
if member.status in {"creator", "administrator"}:
|
||||||
|
return True
|
||||||
|
await reply(text="file transfer is restricted to group admins.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _prepare_file_put_plan(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
args_text: str,
|
||||||
|
ambient_context: RunContext | None,
|
||||||
|
topic_store: TopicStateStore | None,
|
||||||
|
) -> _FilePutPlan | None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
if not await _check_file_permissions(cfg, msg):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
resolved = cfg.runtime.resolve_message(
|
||||||
|
text=args_text,
|
||||||
|
reply_text=msg.reply_to_text,
|
||||||
|
ambient_context=ambient_context,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
)
|
||||||
|
except DirectiveError as exc:
|
||||||
|
await reply(text=f"error:\n{exc}")
|
||||||
|
return None
|
||||||
|
topic_key = _topic_key(msg, cfg) if topic_store is not None else None
|
||||||
|
await _maybe_update_topic_context(
|
||||||
|
cfg=cfg,
|
||||||
|
topic_store=topic_store,
|
||||||
|
topic_key=topic_key,
|
||||||
|
context=resolved.context,
|
||||||
|
context_source=resolved.context_source,
|
||||||
|
)
|
||||||
|
if resolved.context is None or resolved.context.project is None:
|
||||||
|
await reply(text="no project context available for file upload.")
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
run_root = cfg.runtime.resolve_run_cwd(resolved.context)
|
||||||
|
except ConfigError as exc:
|
||||||
|
await reply(text=f"error:\n{exc}")
|
||||||
|
return None
|
||||||
|
if run_root is None:
|
||||||
|
await reply(text="no project context available for file upload.")
|
||||||
|
return None
|
||||||
|
path_value, force, error = parse_file_prompt(resolved.prompt, allow_empty=True)
|
||||||
|
if error is not None:
|
||||||
|
await reply(text=error)
|
||||||
|
return None
|
||||||
|
return _FilePutPlan(
|
||||||
|
resolved=resolved,
|
||||||
|
run_root=run_root,
|
||||||
|
path_value=path_value,
|
||||||
|
force=force,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_file_put_failures(failed: Sequence[_FilePutResult]) -> str | None:
|
||||||
|
if not failed:
|
||||||
|
return None
|
||||||
|
errors = ", ".join(
|
||||||
|
f"`{item.name}` ({item.error})" for item in failed if item.error is not None
|
||||||
|
)
|
||||||
|
if not errors:
|
||||||
|
return None
|
||||||
|
return f"failed: {errors}"
|
||||||
|
|
||||||
|
|
||||||
|
async def _save_document_payload(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
*,
|
||||||
|
document: TelegramDocument,
|
||||||
|
run_root: Path,
|
||||||
|
rel_path: Path | None,
|
||||||
|
base_dir: Path | None,
|
||||||
|
force: bool,
|
||||||
|
) -> _FilePutResult:
|
||||||
|
name = default_upload_name(document.file_name, None)
|
||||||
|
if (
|
||||||
|
document.file_size is not None
|
||||||
|
and document.file_size > cfg.files.max_upload_bytes
|
||||||
|
):
|
||||||
|
return _FilePutResult(
|
||||||
|
name=name,
|
||||||
|
rel_path=None,
|
||||||
|
size=None,
|
||||||
|
error="file is too large to upload.",
|
||||||
|
)
|
||||||
|
file_info = await cfg.bot.get_file(document.file_id)
|
||||||
|
if file_info is None:
|
||||||
|
return _FilePutResult(
|
||||||
|
name=name,
|
||||||
|
rel_path=None,
|
||||||
|
size=None,
|
||||||
|
error="failed to fetch file metadata.",
|
||||||
|
)
|
||||||
|
file_path = file_info.file_path
|
||||||
|
name = default_upload_name(document.file_name, file_path)
|
||||||
|
resolved_path = rel_path
|
||||||
|
if resolved_path is None:
|
||||||
|
if base_dir is None:
|
||||||
|
resolved_path = default_upload_path(
|
||||||
|
cfg.files.uploads_dir, document.file_name, file_path
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
resolved_path = base_dir / name
|
||||||
|
deny_rule = deny_reason(resolved_path, cfg.files.deny_globs)
|
||||||
|
if deny_rule is not None:
|
||||||
|
return _FilePutResult(
|
||||||
|
name=name,
|
||||||
|
rel_path=None,
|
||||||
|
size=None,
|
||||||
|
error=f"path denied by rule: {deny_rule}",
|
||||||
|
)
|
||||||
|
target = resolve_path_within_root(run_root, resolved_path)
|
||||||
|
if target is None:
|
||||||
|
return _FilePutResult(
|
||||||
|
name=name,
|
||||||
|
rel_path=None,
|
||||||
|
size=None,
|
||||||
|
error="upload path escapes the repo root.",
|
||||||
|
)
|
||||||
|
if target.exists():
|
||||||
|
if target.is_dir():
|
||||||
|
return _FilePutResult(
|
||||||
|
name=name,
|
||||||
|
rel_path=None,
|
||||||
|
size=None,
|
||||||
|
error="upload target is a directory.",
|
||||||
|
)
|
||||||
|
if not force:
|
||||||
|
return _FilePutResult(
|
||||||
|
name=name,
|
||||||
|
rel_path=None,
|
||||||
|
size=None,
|
||||||
|
error="file already exists; use --force to overwrite.",
|
||||||
|
)
|
||||||
|
payload = await cfg.bot.download_file(file_path)
|
||||||
|
if payload is None:
|
||||||
|
return _FilePutResult(
|
||||||
|
name=name,
|
||||||
|
rel_path=None,
|
||||||
|
size=None,
|
||||||
|
error="failed to download file.",
|
||||||
|
)
|
||||||
|
if len(payload) > cfg.files.max_upload_bytes:
|
||||||
|
return _FilePutResult(
|
||||||
|
name=name,
|
||||||
|
rel_path=None,
|
||||||
|
size=None,
|
||||||
|
error="file is too large to upload.",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
write_bytes_atomic(target, payload)
|
||||||
|
except OSError as exc:
|
||||||
|
return _FilePutResult(
|
||||||
|
name=name,
|
||||||
|
rel_path=None,
|
||||||
|
size=None,
|
||||||
|
error=f"failed to write file: {exc}",
|
||||||
|
)
|
||||||
|
return _FilePutResult(
|
||||||
|
name=name,
|
||||||
|
rel_path=resolved_path,
|
||||||
|
size=len(payload),
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_file_command(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
args_text: str,
|
||||||
|
ambient_context: RunContext | None,
|
||||||
|
topic_store: TopicStateStore | None,
|
||||||
|
) -> None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
command, rest, error = parse_file_command(args_text)
|
||||||
|
if error is not None:
|
||||||
|
await reply(text=error)
|
||||||
|
return
|
||||||
|
if command == "put":
|
||||||
|
await _handle_file_put(cfg, msg, rest, ambient_context, topic_store)
|
||||||
|
else:
|
||||||
|
await _handle_file_get(cfg, msg, rest, ambient_context, topic_store)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_file_put_default(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
ambient_context: RunContext | None,
|
||||||
|
topic_store: TopicStateStore | None,
|
||||||
|
) -> None:
|
||||||
|
await _handle_file_put(cfg, msg, "", ambient_context, topic_store)
|
||||||
|
|
||||||
|
|
||||||
|
async def _save_file_put(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
args_text: str,
|
||||||
|
ambient_context: RunContext | None,
|
||||||
|
topic_store: TopicStateStore | None,
|
||||||
|
) -> _SavedFilePut | None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
document = msg.document
|
||||||
|
if document is None:
|
||||||
|
await reply(text=FILE_PUT_USAGE)
|
||||||
|
return None
|
||||||
|
plan = await _prepare_file_put_plan(
|
||||||
|
cfg,
|
||||||
|
msg,
|
||||||
|
args_text,
|
||||||
|
ambient_context,
|
||||||
|
topic_store,
|
||||||
|
)
|
||||||
|
if plan is None:
|
||||||
|
return None
|
||||||
|
base_dir, rel_path, error = resolve_file_put_paths(
|
||||||
|
plan,
|
||||||
|
cfg=cfg,
|
||||||
|
require_dir=False,
|
||||||
|
)
|
||||||
|
if error is not None:
|
||||||
|
await reply(text=error)
|
||||||
|
return None
|
||||||
|
result = await _save_document_payload(
|
||||||
|
cfg,
|
||||||
|
document=document,
|
||||||
|
run_root=plan.run_root,
|
||||||
|
rel_path=rel_path,
|
||||||
|
base_dir=base_dir,
|
||||||
|
force=plan.force,
|
||||||
|
)
|
||||||
|
if result.error is not None:
|
||||||
|
await reply(text=result.error)
|
||||||
|
return None
|
||||||
|
if result.rel_path is None or result.size is None:
|
||||||
|
await reply(text="failed to save file.")
|
||||||
|
return None
|
||||||
|
return _SavedFilePut(
|
||||||
|
context=plan.resolved.context,
|
||||||
|
rel_path=result.rel_path,
|
||||||
|
size=result.size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_file_put(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
args_text: str,
|
||||||
|
ambient_context: RunContext | None,
|
||||||
|
topic_store: TopicStateStore | None,
|
||||||
|
) -> None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
saved = await _save_file_put(
|
||||||
|
cfg,
|
||||||
|
msg,
|
||||||
|
args_text,
|
||||||
|
ambient_context,
|
||||||
|
topic_store,
|
||||||
|
)
|
||||||
|
if saved is None:
|
||||||
|
return
|
||||||
|
context_label = _format_context(cfg.runtime, saved.context)
|
||||||
|
await reply(
|
||||||
|
text=(
|
||||||
|
f"saved `{saved.rel_path.as_posix()}` "
|
||||||
|
f"in `{context_label}` ({format_bytes(saved.size)})"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_file_put_group(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
args_text: str,
|
||||||
|
messages: Sequence[TelegramIncomingMessage],
|
||||||
|
ambient_context: RunContext | None,
|
||||||
|
topic_store: TopicStateStore | None,
|
||||||
|
) -> None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
saved_group = await _save_file_put_group(
|
||||||
|
cfg,
|
||||||
|
msg,
|
||||||
|
args_text,
|
||||||
|
messages,
|
||||||
|
ambient_context,
|
||||||
|
topic_store,
|
||||||
|
)
|
||||||
|
if saved_group is None:
|
||||||
|
return
|
||||||
|
context_label = _format_context(cfg.runtime, saved_group.context)
|
||||||
|
total_bytes = sum(item.size or 0 for item in saved_group.saved)
|
||||||
|
dir_label: Path | None = saved_group.base_dir
|
||||||
|
if dir_label is None and saved_group.saved:
|
||||||
|
first_path = saved_group.saved[0].rel_path
|
||||||
|
if first_path is not None:
|
||||||
|
dir_label = first_path.parent
|
||||||
|
if saved_group.saved:
|
||||||
|
saved_names = ", ".join(f"`{item.name}`" for item in saved_group.saved)
|
||||||
|
if dir_label is not None:
|
||||||
|
dir_text = dir_label.as_posix()
|
||||||
|
if not dir_text.endswith("/"):
|
||||||
|
dir_text = f"{dir_text}/"
|
||||||
|
text = (
|
||||||
|
f"saved {saved_names} to `{dir_text}` "
|
||||||
|
f"in `{context_label}` ({format_bytes(total_bytes)})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text = (
|
||||||
|
f"saved {saved_names} in `{context_label}` "
|
||||||
|
f"({format_bytes(total_bytes)})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text = "failed to upload files."
|
||||||
|
failure_text = _format_file_put_failures(saved_group.failed)
|
||||||
|
if failure_text is not None:
|
||||||
|
text = f"{text}\n\n{failure_text}"
|
||||||
|
await reply(text=text)
|
||||||
|
|
||||||
|
|
||||||
|
async def _save_file_put_group(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
args_text: str,
|
||||||
|
messages: Sequence[TelegramIncomingMessage],
|
||||||
|
ambient_context: RunContext | None,
|
||||||
|
topic_store: TopicStateStore | None,
|
||||||
|
) -> _SavedFilePutGroup | None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
documents = [item.document for item in messages if item.document is not None]
|
||||||
|
if not documents:
|
||||||
|
await reply(text=FILE_PUT_USAGE)
|
||||||
|
return None
|
||||||
|
plan = await _prepare_file_put_plan(
|
||||||
|
cfg,
|
||||||
|
msg,
|
||||||
|
args_text,
|
||||||
|
ambient_context,
|
||||||
|
topic_store,
|
||||||
|
)
|
||||||
|
if plan is None:
|
||||||
|
return None
|
||||||
|
base_dir, _, error = resolve_file_put_paths(
|
||||||
|
plan,
|
||||||
|
cfg=cfg,
|
||||||
|
require_dir=True,
|
||||||
|
)
|
||||||
|
if error is not None:
|
||||||
|
await reply(text=error)
|
||||||
|
return None
|
||||||
|
saved: list[_FilePutResult] = []
|
||||||
|
failed: list[_FilePutResult] = []
|
||||||
|
for document in documents:
|
||||||
|
result = await _save_document_payload(
|
||||||
|
cfg,
|
||||||
|
document=document,
|
||||||
|
run_root=plan.run_root,
|
||||||
|
rel_path=None,
|
||||||
|
base_dir=base_dir,
|
||||||
|
force=plan.force,
|
||||||
|
)
|
||||||
|
if result.error is None:
|
||||||
|
saved.append(result)
|
||||||
|
else:
|
||||||
|
failed.append(result)
|
||||||
|
return _SavedFilePutGroup(
|
||||||
|
context=plan.resolved.context,
|
||||||
|
base_dir=base_dir,
|
||||||
|
saved=saved,
|
||||||
|
failed=failed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_file_get(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
args_text: str,
|
||||||
|
ambient_context: RunContext | None,
|
||||||
|
topic_store: TopicStateStore | None,
|
||||||
|
) -> None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
if not await _check_file_permissions(cfg, msg):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
resolved = cfg.runtime.resolve_message(
|
||||||
|
text=args_text,
|
||||||
|
reply_text=msg.reply_to_text,
|
||||||
|
ambient_context=ambient_context,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
)
|
||||||
|
except DirectiveError as exc:
|
||||||
|
await reply(text=f"error:\n{exc}")
|
||||||
|
return
|
||||||
|
topic_key = _topic_key(msg, cfg) if topic_store is not None else None
|
||||||
|
await _maybe_update_topic_context(
|
||||||
|
cfg=cfg,
|
||||||
|
topic_store=topic_store,
|
||||||
|
topic_key=topic_key,
|
||||||
|
context=resolved.context,
|
||||||
|
context_source=resolved.context_source,
|
||||||
|
)
|
||||||
|
if resolved.context is None or resolved.context.project is None:
|
||||||
|
await reply(text="no project context available for file download.")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
run_root = cfg.runtime.resolve_run_cwd(resolved.context)
|
||||||
|
except ConfigError as exc:
|
||||||
|
await reply(text=f"error:\n{exc}")
|
||||||
|
return
|
||||||
|
if run_root is None:
|
||||||
|
await reply(text="no project context available for file download.")
|
||||||
|
return
|
||||||
|
path_value = resolved.prompt
|
||||||
|
if not path_value.strip():
|
||||||
|
await reply(text=FILE_GET_USAGE)
|
||||||
|
return
|
||||||
|
rel_path = normalize_relative_path(path_value)
|
||||||
|
if rel_path is None:
|
||||||
|
await reply(text="invalid download path.")
|
||||||
|
return
|
||||||
|
deny_rule = deny_reason(rel_path, cfg.files.deny_globs)
|
||||||
|
if deny_rule is not None:
|
||||||
|
await reply(text=f"path denied by rule: {deny_rule}")
|
||||||
|
return
|
||||||
|
target = resolve_path_within_root(run_root, rel_path)
|
||||||
|
if target is None:
|
||||||
|
await reply(text="download path escapes the repo root.")
|
||||||
|
return
|
||||||
|
if not target.exists():
|
||||||
|
await reply(text="file does not exist.")
|
||||||
|
return
|
||||||
|
if target.is_dir():
|
||||||
|
try:
|
||||||
|
payload = zip_directory(
|
||||||
|
run_root,
|
||||||
|
rel_path,
|
||||||
|
cfg.files.deny_globs,
|
||||||
|
max_bytes=cfg.files.max_download_bytes,
|
||||||
|
)
|
||||||
|
except ZipTooLargeError:
|
||||||
|
await reply(text="file is too large to send.")
|
||||||
|
return
|
||||||
|
except OSError as exc:
|
||||||
|
await reply(text=f"failed to read directory: {exc}")
|
||||||
|
return
|
||||||
|
filename = f"{rel_path.name or 'archive'}.zip"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
size = target.stat().st_size
|
||||||
|
if size > cfg.files.max_download_bytes:
|
||||||
|
await reply(text="file is too large to send.")
|
||||||
|
return
|
||||||
|
payload = target.read_bytes()
|
||||||
|
except OSError as exc:
|
||||||
|
await reply(text=f"failed to read file: {exc}")
|
||||||
|
return
|
||||||
|
filename = target.name
|
||||||
|
if len(payload) > cfg.files.max_download_bytes:
|
||||||
|
await reply(text="file is too large to send.")
|
||||||
|
return
|
||||||
|
sent = await cfg.bot.send_document(
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
filename=filename,
|
||||||
|
content=payload,
|
||||||
|
reply_to_message_id=msg.message_id,
|
||||||
|
message_thread_id=msg.thread_id,
|
||||||
|
)
|
||||||
|
if sent is None:
|
||||||
|
await reply(text="failed to send file.")
|
||||||
|
return
|
||||||
@@ -0,0 +1,143 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable, Sequence
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...context import RunContext
|
||||||
|
from ...directives import DirectiveError
|
||||||
|
from ...transport_runtime import ResolvedMessage
|
||||||
|
from ..context import _merge_topic_context
|
||||||
|
from ..files import parse_file_command
|
||||||
|
from ..topic_state import TopicStateStore
|
||||||
|
from ..topics import _topic_key, _topics_chat_project
|
||||||
|
from ..types import TelegramIncomingMessage
|
||||||
|
from .file_transfer import (
|
||||||
|
FILE_PUT_USAGE,
|
||||||
|
_format_file_put_failures,
|
||||||
|
_handle_file_put_group,
|
||||||
|
_save_file_put_group,
|
||||||
|
)
|
||||||
|
from .parse import _parse_slash_command
|
||||||
|
from .reply import make_reply
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..bridge import TelegramBridgeConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_media_group(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
messages: Sequence[TelegramIncomingMessage],
|
||||||
|
topic_store: TopicStateStore | None,
|
||||||
|
run_prompt: Callable[
|
||||||
|
[TelegramIncomingMessage, str, ResolvedMessage], Awaitable[None]
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
|
resolve_prompt: Callable[
|
||||||
|
[TelegramIncomingMessage, str, RunContext | None],
|
||||||
|
Awaitable[ResolvedMessage | None],
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
|
) -> None:
|
||||||
|
if not messages:
|
||||||
|
return
|
||||||
|
ordered = sorted(messages, key=lambda item: item.message_id)
|
||||||
|
command_msg = next(
|
||||||
|
(item for item in ordered if item.text.strip()),
|
||||||
|
ordered[0],
|
||||||
|
)
|
||||||
|
reply = make_reply(cfg, command_msg)
|
||||||
|
topic_key = _topic_key(command_msg, cfg) if topic_store is not None else None
|
||||||
|
chat_project = _topics_chat_project(cfg, command_msg.chat_id)
|
||||||
|
bound_context = (
|
||||||
|
await topic_store.get_context(*topic_key)
|
||||||
|
if topic_store is not None and topic_key is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
ambient_context = _merge_topic_context(
|
||||||
|
chat_project=chat_project, bound=bound_context
|
||||||
|
)
|
||||||
|
command_id, args_text = _parse_slash_command(command_msg.text)
|
||||||
|
if command_id == "file":
|
||||||
|
command, rest, error = parse_file_command(args_text)
|
||||||
|
if error is not None:
|
||||||
|
await reply(text=error)
|
||||||
|
return
|
||||||
|
if command == "put":
|
||||||
|
await _handle_file_put_group(
|
||||||
|
cfg,
|
||||||
|
command_msg,
|
||||||
|
rest,
|
||||||
|
ordered,
|
||||||
|
ambient_context,
|
||||||
|
topic_store,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if cfg.files.enabled and cfg.files.auto_put:
|
||||||
|
caption_text = command_msg.text.strip()
|
||||||
|
if cfg.files.auto_put_mode == "prompt" and caption_text:
|
||||||
|
if resolve_prompt is None:
|
||||||
|
try:
|
||||||
|
resolved = cfg.runtime.resolve_message(
|
||||||
|
text=caption_text,
|
||||||
|
reply_text=command_msg.reply_to_text,
|
||||||
|
ambient_context=ambient_context,
|
||||||
|
chat_id=command_msg.chat_id,
|
||||||
|
)
|
||||||
|
except DirectiveError as exc:
|
||||||
|
await reply(text=f"error:\n{exc}")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
resolved = await resolve_prompt(
|
||||||
|
command_msg, caption_text, ambient_context
|
||||||
|
)
|
||||||
|
if resolved is None:
|
||||||
|
return
|
||||||
|
saved_group = await _save_file_put_group(
|
||||||
|
cfg,
|
||||||
|
command_msg,
|
||||||
|
"",
|
||||||
|
ordered,
|
||||||
|
resolved.context,
|
||||||
|
topic_store,
|
||||||
|
)
|
||||||
|
if saved_group is None:
|
||||||
|
return
|
||||||
|
if not saved_group.saved:
|
||||||
|
failure_text = _format_file_put_failures(saved_group.failed)
|
||||||
|
text = "failed to upload files."
|
||||||
|
if failure_text is not None:
|
||||||
|
text = f"{text}\n\n{failure_text}"
|
||||||
|
await reply(text=text)
|
||||||
|
return
|
||||||
|
if saved_group.failed:
|
||||||
|
failure_text = _format_file_put_failures(saved_group.failed)
|
||||||
|
if failure_text is not None:
|
||||||
|
await reply(text=f"some files failed to upload.\n\n{failure_text}")
|
||||||
|
if run_prompt is None:
|
||||||
|
await reply(text=FILE_PUT_USAGE)
|
||||||
|
return
|
||||||
|
paths = [
|
||||||
|
item.rel_path.as_posix()
|
||||||
|
for item in saved_group.saved
|
||||||
|
if item.rel_path is not None
|
||||||
|
]
|
||||||
|
files_text = "\n".join(f"- {path}" for path in paths)
|
||||||
|
prompt_base = resolved.prompt
|
||||||
|
annotation = f"[uploaded files]\n{files_text}"
|
||||||
|
if prompt_base and prompt_base.strip():
|
||||||
|
prompt = f"{prompt_base}\n\n{annotation}"
|
||||||
|
else:
|
||||||
|
prompt = annotation
|
||||||
|
await run_prompt(command_msg, prompt, resolved)
|
||||||
|
return
|
||||||
|
if not caption_text:
|
||||||
|
await _handle_file_put_group(
|
||||||
|
cfg,
|
||||||
|
command_msg,
|
||||||
|
"",
|
||||||
|
ordered,
|
||||||
|
ambient_context,
|
||||||
|
topic_store,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
await reply(text=FILE_PUT_USAGE)
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...commands import get_command
|
||||||
|
from ...config import ConfigError
|
||||||
|
from ...ids import RESERVED_COMMAND_IDS, is_valid_id
|
||||||
|
from ...logging import get_logger
|
||||||
|
from ...plugins import COMMAND_GROUP, list_entrypoints
|
||||||
|
from ...transport_runtime import TransportRuntime
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..bridge import TelegramBridgeConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
_MAX_BOT_COMMANDS = 100
|
||||||
|
|
||||||
|
|
||||||
|
def build_bot_commands(
|
||||||
|
runtime: TransportRuntime, *, include_file: bool = True
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
commands: list[dict[str, str]] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for engine_id in runtime.available_engine_ids():
|
||||||
|
cmd = engine_id.lower()
|
||||||
|
if cmd in seen:
|
||||||
|
continue
|
||||||
|
commands.append({"command": cmd, "description": f"use agent: {cmd}"})
|
||||||
|
seen.add(cmd)
|
||||||
|
for alias in runtime.project_aliases():
|
||||||
|
cmd = alias.lower()
|
||||||
|
if cmd in seen:
|
||||||
|
continue
|
||||||
|
if not is_valid_id(cmd):
|
||||||
|
logger.debug(
|
||||||
|
"startup.command_menu.skip_project",
|
||||||
|
alias=alias,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
commands.append({"command": cmd, "description": f"work on: {cmd}"})
|
||||||
|
seen.add(cmd)
|
||||||
|
allowlist = runtime.allowlist
|
||||||
|
for ep in list_entrypoints(
|
||||||
|
COMMAND_GROUP,
|
||||||
|
allowlist=allowlist,
|
||||||
|
reserved_ids=RESERVED_COMMAND_IDS,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
backend = get_command(ep.name, allowlist=allowlist)
|
||||||
|
except ConfigError as exc:
|
||||||
|
logger.info(
|
||||||
|
"startup.command_menu.skip_command",
|
||||||
|
command=ep.name,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
cmd = backend.id.lower()
|
||||||
|
if cmd in seen:
|
||||||
|
continue
|
||||||
|
if not is_valid_id(cmd):
|
||||||
|
logger.debug(
|
||||||
|
"startup.command_menu.skip_command_id",
|
||||||
|
command=cmd,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
description = backend.description or f"command: {cmd}"
|
||||||
|
commands.append({"command": cmd, "description": description})
|
||||||
|
seen.add(cmd)
|
||||||
|
if include_file and "file" not in seen:
|
||||||
|
commands.append({"command": "file", "description": "upload or fetch files"})
|
||||||
|
seen.add("file")
|
||||||
|
if "cancel" not in seen:
|
||||||
|
commands.append({"command": "cancel", "description": "cancel run"})
|
||||||
|
if len(commands) > _MAX_BOT_COMMANDS:
|
||||||
|
logger.warning(
|
||||||
|
"startup.command_menu.too_many",
|
||||||
|
count=len(commands),
|
||||||
|
limit=_MAX_BOT_COMMANDS,
|
||||||
|
)
|
||||||
|
commands = commands[:_MAX_BOT_COMMANDS]
|
||||||
|
if not any(cmd["command"] == "cancel" for cmd in commands):
|
||||||
|
commands[-1] = {"command": "cancel", "description": "cancel run"}
|
||||||
|
return commands
|
||||||
|
|
||||||
|
|
||||||
|
def _reserved_commands(runtime: TransportRuntime) -> set[str]:
|
||||||
|
return {
|
||||||
|
*{engine.lower() for engine in runtime.engine_ids},
|
||||||
|
*{alias.lower() for alias in runtime.project_aliases()},
|
||||||
|
*RESERVED_COMMAND_IDS,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _set_command_menu(cfg: TelegramBridgeConfig) -> None:
|
||||||
|
commands = build_bot_commands(cfg.runtime, include_file=cfg.files.enabled)
|
||||||
|
if not commands:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
ok = await cfg.bot.set_my_commands(commands)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.info(
|
||||||
|
"startup.command_menu.failed",
|
||||||
|
error=str(exc),
|
||||||
|
error_type=exc.__class__.__name__,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if not ok:
|
||||||
|
logger.info("startup.command_menu.rejected")
|
||||||
|
return
|
||||||
|
logger.info(
|
||||||
|
"startup.command_menu.updated",
|
||||||
|
commands=[cmd["command"] for cmd in commands],
|
||||||
|
)
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
def is_cancel_command(text: str) -> bool:
|
||||||
|
stripped = text.strip()
|
||||||
|
if not stripped:
|
||||||
|
return False
|
||||||
|
command = stripped.split(maxsplit=1)[0]
|
||||||
|
return command == "/cancel" or command.startswith("/cancel@")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_slash_command(text: str) -> tuple[str | None, str]:
|
||||||
|
stripped = text.lstrip()
|
||||||
|
if not stripped.startswith("/"):
|
||||||
|
return None, text
|
||||||
|
lines = stripped.splitlines()
|
||||||
|
if not lines:
|
||||||
|
return None, text
|
||||||
|
first_line = lines[0]
|
||||||
|
token, _, rest = first_line.partition(" ")
|
||||||
|
command = token[1:]
|
||||||
|
if not command:
|
||||||
|
return None, text
|
||||||
|
if "@" in command:
|
||||||
|
command = command.split("@", 1)[0]
|
||||||
|
args_text = rest
|
||||||
|
if len(lines) > 1:
|
||||||
|
tail = "\n".join(lines[1:])
|
||||||
|
args_text = f"{args_text}\n{tail}" if args_text else tail
|
||||||
|
return command.lower(), args_text
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from functools import partial
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ..bridge import send_plain
|
||||||
|
from ..types import TelegramIncomingMessage
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..bridge import TelegramBridgeConfig
|
||||||
|
|
||||||
|
|
||||||
|
def make_reply(
|
||||||
|
cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage
|
||||||
|
) -> Callable[..., Awaitable[None]]:
|
||||||
|
return partial(
|
||||||
|
send_plain,
|
||||||
|
cfg.exec_cfg.transport,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
user_msg_id=msg.message_id,
|
||||||
|
thread_id=msg.thread_id,
|
||||||
|
)
|
||||||
@@ -0,0 +1,220 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...markdown import MarkdownParts
|
||||||
|
from ...transport import RenderedMessage, SendOptions
|
||||||
|
from ..chat_sessions import ChatSessionStore
|
||||||
|
from ..context import (
|
||||||
|
_format_context,
|
||||||
|
_format_ctx_status,
|
||||||
|
_merge_topic_context,
|
||||||
|
_parse_project_branch_args,
|
||||||
|
_usage_ctx_set,
|
||||||
|
_usage_topic,
|
||||||
|
)
|
||||||
|
from ..files import split_command_args
|
||||||
|
from ..render import prepare_telegram
|
||||||
|
from ..topic_state import TopicStateStore
|
||||||
|
from ..topics import (
|
||||||
|
_maybe_rename_topic,
|
||||||
|
_topic_key,
|
||||||
|
_topic_title,
|
||||||
|
_topics_chat_project,
|
||||||
|
_topics_command_error,
|
||||||
|
)
|
||||||
|
from ..types import TelegramIncomingMessage
|
||||||
|
from .reply import make_reply
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..bridge import TelegramBridgeConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_ctx_command(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
args_text: str,
|
||||||
|
store: TopicStateStore,
|
||||||
|
*,
|
||||||
|
resolved_scope: str | None = None,
|
||||||
|
scope_chat_ids: frozenset[int] | None = None,
|
||||||
|
) -> None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
error = _topics_command_error(
|
||||||
|
cfg,
|
||||||
|
msg.chat_id,
|
||||||
|
resolved_scope=resolved_scope,
|
||||||
|
scope_chat_ids=scope_chat_ids,
|
||||||
|
)
|
||||||
|
if error is not None:
|
||||||
|
await reply(text=error)
|
||||||
|
return
|
||||||
|
chat_project = _topics_chat_project(cfg, msg.chat_id)
|
||||||
|
tkey = _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
|
||||||
|
if tkey is None:
|
||||||
|
await reply(text="this command only works inside a topic.")
|
||||||
|
return
|
||||||
|
tokens = split_command_args(args_text)
|
||||||
|
action = tokens[0].lower() if tokens else "show"
|
||||||
|
if action in {"show", ""}:
|
||||||
|
snapshot = await store.get_thread(*tkey)
|
||||||
|
bound = snapshot.context if snapshot is not None else None
|
||||||
|
ambient = _merge_topic_context(chat_project=chat_project, bound=bound)
|
||||||
|
resolved = cfg.runtime.resolve_message(
|
||||||
|
text="",
|
||||||
|
reply_text=msg.reply_to_text,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
ambient_context=ambient,
|
||||||
|
)
|
||||||
|
text = _format_ctx_status(
|
||||||
|
cfg=cfg,
|
||||||
|
runtime=cfg.runtime,
|
||||||
|
bound=bound,
|
||||||
|
resolved=resolved.context,
|
||||||
|
context_source=resolved.context_source,
|
||||||
|
snapshot=snapshot,
|
||||||
|
chat_project=chat_project,
|
||||||
|
)
|
||||||
|
await reply(text=text)
|
||||||
|
return
|
||||||
|
if action == "set":
|
||||||
|
rest = " ".join(tokens[1:])
|
||||||
|
context, error = _parse_project_branch_args(
|
||||||
|
rest,
|
||||||
|
runtime=cfg.runtime,
|
||||||
|
require_branch=False,
|
||||||
|
chat_project=chat_project,
|
||||||
|
)
|
||||||
|
if error is not None:
|
||||||
|
await reply(
|
||||||
|
text=f"error:\n{error}\n{_usage_ctx_set(chat_project=chat_project)}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if context is None:
|
||||||
|
await reply(
|
||||||
|
text=f"error:\n{_usage_ctx_set(chat_project=chat_project)}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
await store.set_context(*tkey, context)
|
||||||
|
await _maybe_rename_topic(
|
||||||
|
cfg,
|
||||||
|
store,
|
||||||
|
chat_id=tkey[0],
|
||||||
|
thread_id=tkey[1],
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
await reply(
|
||||||
|
text=f"topic bound to `{_format_context(cfg.runtime, context)}`",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if action == "clear":
|
||||||
|
await store.clear_context(*tkey)
|
||||||
|
await reply(text="topic binding cleared.")
|
||||||
|
return
|
||||||
|
await reply(
|
||||||
|
text="unknown `/ctx` command. use `/ctx`, `/ctx set`, or `/ctx clear`.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_new_command(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
store: TopicStateStore,
|
||||||
|
*,
|
||||||
|
resolved_scope: str | None = None,
|
||||||
|
scope_chat_ids: frozenset[int] | None = None,
|
||||||
|
) -> None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
error = _topics_command_error(
|
||||||
|
cfg,
|
||||||
|
msg.chat_id,
|
||||||
|
resolved_scope=resolved_scope,
|
||||||
|
scope_chat_ids=scope_chat_ids,
|
||||||
|
)
|
||||||
|
if error is not None:
|
||||||
|
await reply(text=error)
|
||||||
|
return
|
||||||
|
tkey = _topic_key(msg, cfg, scope_chat_ids=scope_chat_ids)
|
||||||
|
if tkey is None:
|
||||||
|
await reply(text="this command only works inside a topic.")
|
||||||
|
return
|
||||||
|
await store.clear_sessions(*tkey)
|
||||||
|
await reply(text="cleared stored sessions for this topic.")
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_chat_new_command(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
store: ChatSessionStore,
|
||||||
|
session_key: tuple[int, int | None] | None,
|
||||||
|
) -> None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
if session_key is None:
|
||||||
|
await reply(text="no stored sessions to clear for this chat.")
|
||||||
|
return
|
||||||
|
await store.clear_sessions(session_key[0], session_key[1])
|
||||||
|
if msg.chat_type == "private":
|
||||||
|
text = "cleared stored sessions for this chat."
|
||||||
|
else:
|
||||||
|
text = "cleared stored sessions for you in this chat."
|
||||||
|
await reply(text=text)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_topic_command(
|
||||||
|
cfg: TelegramBridgeConfig,
|
||||||
|
msg: TelegramIncomingMessage,
|
||||||
|
args_text: str,
|
||||||
|
store: TopicStateStore,
|
||||||
|
*,
|
||||||
|
resolved_scope: str | None = None,
|
||||||
|
scope_chat_ids: frozenset[int] | None = None,
|
||||||
|
) -> None:
|
||||||
|
reply = make_reply(cfg, msg)
|
||||||
|
error = _topics_command_error(
|
||||||
|
cfg,
|
||||||
|
msg.chat_id,
|
||||||
|
resolved_scope=resolved_scope,
|
||||||
|
scope_chat_ids=scope_chat_ids,
|
||||||
|
)
|
||||||
|
if error is not None:
|
||||||
|
await reply(text=error)
|
||||||
|
return
|
||||||
|
chat_project = _topics_chat_project(cfg, msg.chat_id)
|
||||||
|
context, error = _parse_project_branch_args(
|
||||||
|
args_text,
|
||||||
|
runtime=cfg.runtime,
|
||||||
|
require_branch=True,
|
||||||
|
chat_project=chat_project,
|
||||||
|
)
|
||||||
|
if error is not None or context is None:
|
||||||
|
usage = _usage_topic(chat_project=chat_project)
|
||||||
|
text = f"error:\n{error}\n{usage}" if error else usage
|
||||||
|
await reply(text=text)
|
||||||
|
return
|
||||||
|
existing = await store.find_thread_for_context(msg.chat_id, context)
|
||||||
|
if existing is not None:
|
||||||
|
await reply(
|
||||||
|
text=f"topic already exists for {_format_context(cfg.runtime, context)} "
|
||||||
|
"in this chat.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
title = _topic_title(runtime=cfg.runtime, context=context)
|
||||||
|
created = await cfg.bot.create_forum_topic(msg.chat_id, title)
|
||||||
|
if created is None:
|
||||||
|
await reply(text="failed to create topic.")
|
||||||
|
return
|
||||||
|
thread_id = created.message_thread_id
|
||||||
|
await store.set_context(
|
||||||
|
msg.chat_id,
|
||||||
|
thread_id,
|
||||||
|
context,
|
||||||
|
topic_title=title,
|
||||||
|
)
|
||||||
|
await reply(text=f"created topic `{title}`.")
|
||||||
|
bound_text = f"topic bound to `{_format_context(cfg.runtime, context)}`"
|
||||||
|
rendered_text, entities = prepare_telegram(MarkdownParts(header=bound_text))
|
||||||
|
await cfg.exec_cfg.transport.send(
|
||||||
|
channel_id=msg.chat_id,
|
||||||
|
message=RenderedMessage(text=rendered_text, extra={"entities": entities}),
|
||||||
|
options=SendOptions(thread_id=thread_id),
|
||||||
|
)
|
||||||
+16
-29
@@ -20,26 +20,25 @@ from ..transport import MessageRef
|
|||||||
from ..transport_runtime import ResolvedMessage
|
from ..transport_runtime import ResolvedMessage
|
||||||
from ..context import RunContext
|
from ..context import RunContext
|
||||||
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
|
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
|
||||||
from .commands import (
|
from .commands.agent import _handle_agent_command
|
||||||
|
from .commands.cancel import handle_callback_cancel, handle_cancel
|
||||||
|
from .commands.dispatch import _dispatch_command
|
||||||
|
from .commands.executor import _run_engine, _should_show_resume_line
|
||||||
|
from .commands.file_transfer import (
|
||||||
FILE_PUT_USAGE,
|
FILE_PUT_USAGE,
|
||||||
_dispatch_command,
|
|
||||||
_handle_agent_command,
|
|
||||||
_handle_chat_new_command,
|
|
||||||
_handle_ctx_command,
|
|
||||||
_handle_file_command,
|
_handle_file_command,
|
||||||
_handle_file_put_default,
|
_handle_file_put_default,
|
||||||
_handle_media_group,
|
_save_file_put,
|
||||||
|
)
|
||||||
|
from .commands.media import _handle_media_group
|
||||||
|
from .commands.menu import _reserved_commands, _set_command_menu
|
||||||
|
from .commands.parse import _parse_slash_command, is_cancel_command
|
||||||
|
from .commands.reply import make_reply
|
||||||
|
from .commands.topics import (
|
||||||
|
_handle_chat_new_command,
|
||||||
|
_handle_ctx_command,
|
||||||
_handle_new_command,
|
_handle_new_command,
|
||||||
_handle_topic_command,
|
_handle_topic_command,
|
||||||
_parse_slash_command,
|
|
||||||
_reserved_commands,
|
|
||||||
_save_file_put,
|
|
||||||
_should_show_resume_line,
|
|
||||||
_run_engine,
|
|
||||||
_set_command_menu,
|
|
||||||
handle_callback_cancel,
|
|
||||||
handle_cancel,
|
|
||||||
is_cancel_command,
|
|
||||||
)
|
)
|
||||||
from .context import _merge_topic_context, _usage_ctx_set, _usage_topic
|
from .context import _merge_topic_context, _usage_ctx_set, _usage_topic
|
||||||
from .topics import (
|
from .topics import (
|
||||||
@@ -519,13 +518,7 @@ async def run_main_loop(
|
|||||||
text: str,
|
text: str,
|
||||||
ambient_context: RunContext | None,
|
ambient_context: RunContext | None,
|
||||||
) -> ResolvedMessage | None:
|
) -> ResolvedMessage | None:
|
||||||
reply = partial(
|
reply = make_reply(cfg, msg)
|
||||||
send_plain,
|
|
||||||
cfg.exec_cfg.transport,
|
|
||||||
chat_id=msg.chat_id,
|
|
||||||
user_msg_id=msg.message_id,
|
|
||||||
thread_id=msg.thread_id,
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
resolved = cfg.runtime.resolve_message(
|
resolved = cfg.runtime.resolve_message(
|
||||||
text=text,
|
text=text,
|
||||||
@@ -757,13 +750,7 @@ async def run_main_loop(
|
|||||||
if reply_id is not None
|
if reply_id is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
reply = partial(
|
reply = make_reply(cfg, msg)
|
||||||
send_plain,
|
|
||||||
cfg.exec_cfg.transport,
|
|
||||||
chat_id=chat_id,
|
|
||||||
user_msg_id=user_msg_id,
|
|
||||||
thread_id=msg.thread_id,
|
|
||||||
)
|
|
||||||
text = msg.text
|
text = msg.text
|
||||||
if msg.voice is not None:
|
if msg.voice is not None:
|
||||||
text = await transcribe_voice(
|
text = await transcribe_voice(
|
||||||
|
|||||||
@@ -286,8 +286,8 @@ def _confirm(message: str, *, default: bool = True) -> bool | None:
|
|||||||
exit_with_result(event)
|
exit_with_result(event)
|
||||||
|
|
||||||
@bindings.add(Keys.Any)
|
@bindings.add(Keys.Any)
|
||||||
def other(event):
|
def other(_event):
|
||||||
_ = event
|
return None
|
||||||
|
|
||||||
question = Question(
|
question = Question(
|
||||||
PromptSession(get_prompt_tokens, key_bindings=bindings, style=merged_style).app
|
PromptSession(get_prompt_tokens, key_bindings=bindings, style=merged_style).app
|
||||||
|
|||||||
@@ -0,0 +1,177 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
from collections.abc import Awaitable, Callable, Hashable
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from .client_api import RetryAfter
|
||||||
|
|
||||||
|
SEND_PRIORITY = 0
|
||||||
|
DELETE_PRIORITY = 1
|
||||||
|
EDIT_PRIORITY = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class OutboxOp:
|
||||||
|
execute: Callable[[], Awaitable[Any]]
|
||||||
|
priority: int
|
||||||
|
queued_at: float
|
||||||
|
chat_id: int | None
|
||||||
|
label: str | None = None
|
||||||
|
done: anyio.Event = field(default_factory=anyio.Event)
|
||||||
|
result: Any = None
|
||||||
|
|
||||||
|
def set_result(self, result: Any) -> None:
|
||||||
|
if self.done.is_set():
|
||||||
|
return
|
||||||
|
self.result = result
|
||||||
|
self.done.set()
|
||||||
|
|
||||||
|
|
||||||
|
class TelegramOutbox:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
interval_for_chat: Callable[[int | None], float],
|
||||||
|
clock: Callable[[], float] = time.monotonic,
|
||||||
|
sleep: Callable[[float], Awaitable[None]] = anyio.sleep,
|
||||||
|
on_error: Callable[[OutboxOp, Exception], None] | None = None,
|
||||||
|
on_outbox_error: Callable[[Exception], None] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._interval_for_chat = interval_for_chat
|
||||||
|
self._clock = clock
|
||||||
|
self._sleep = sleep
|
||||||
|
self._on_error = on_error
|
||||||
|
self._on_outbox_error = on_outbox_error
|
||||||
|
self._pending: dict[Hashable, OutboxOp] = {}
|
||||||
|
self._cond = anyio.Condition()
|
||||||
|
self._start_lock = anyio.Lock()
|
||||||
|
self._closed = False
|
||||||
|
self._tg: TaskGroup | None = None
|
||||||
|
self.next_at = 0.0
|
||||||
|
self.retry_at = 0.0
|
||||||
|
|
||||||
|
async def ensure_worker(self) -> None:
|
||||||
|
async with self._start_lock:
|
||||||
|
if self._tg is not None or self._closed:
|
||||||
|
return
|
||||||
|
self._tg = await anyio.create_task_group().__aenter__()
|
||||||
|
self._tg.start_soon(self.run)
|
||||||
|
|
||||||
|
async def enqueue(self, *, key: Hashable, op: OutboxOp, wait: bool = True) -> Any:
|
||||||
|
await self.ensure_worker()
|
||||||
|
async with self._cond:
|
||||||
|
if self._closed:
|
||||||
|
op.set_result(None)
|
||||||
|
return op.result
|
||||||
|
previous = self._pending.get(key)
|
||||||
|
if previous is not None:
|
||||||
|
op.queued_at = previous.queued_at
|
||||||
|
previous.set_result(None)
|
||||||
|
self._pending[key] = op
|
||||||
|
self._cond.notify()
|
||||||
|
if not wait:
|
||||||
|
return None
|
||||||
|
await op.done.wait()
|
||||||
|
return op.result
|
||||||
|
|
||||||
|
async def drop_pending(self, *, key: Hashable) -> None:
|
||||||
|
async with self._cond:
|
||||||
|
pending = self._pending.pop(key, None)
|
||||||
|
if pending is not None:
|
||||||
|
pending.set_result(None)
|
||||||
|
self._cond.notify()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
async with self._cond:
|
||||||
|
self._closed = True
|
||||||
|
self.fail_pending()
|
||||||
|
self._cond.notify_all()
|
||||||
|
if self._tg is not None:
|
||||||
|
await self._tg.__aexit__(None, None, None)
|
||||||
|
self._tg = None
|
||||||
|
|
||||||
|
def fail_pending(self) -> None:
|
||||||
|
for pending in list(self._pending.values()):
|
||||||
|
pending.set_result(None)
|
||||||
|
self._pending.clear()
|
||||||
|
|
||||||
|
def pick_locked(self) -> tuple[Hashable, OutboxOp] | None:
|
||||||
|
if not self._pending:
|
||||||
|
return None
|
||||||
|
return min(
|
||||||
|
self._pending.items(),
|
||||||
|
key=lambda item: (item[1].priority, item[1].queued_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute_op(self, op: OutboxOp) -> Any:
|
||||||
|
try:
|
||||||
|
return await op.execute()
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
if isinstance(exc, RetryAfter):
|
||||||
|
raise
|
||||||
|
if self._on_error is not None:
|
||||||
|
self._on_error(op, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def sleep_until(self, deadline: float) -> None:
|
||||||
|
delay = deadline - self._clock()
|
||||||
|
if delay > 0:
|
||||||
|
await self._sleep(delay)
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
cancel_exc = anyio.get_cancelled_exc_class()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
async with self._cond:
|
||||||
|
while not self._pending and not self._closed:
|
||||||
|
await self._cond.wait()
|
||||||
|
if self._closed and not self._pending:
|
||||||
|
return
|
||||||
|
blocked_until = max(self.next_at, self.retry_at)
|
||||||
|
if self._clock() < blocked_until:
|
||||||
|
await self.sleep_until(blocked_until)
|
||||||
|
continue
|
||||||
|
async with self._cond:
|
||||||
|
if self._closed and not self._pending:
|
||||||
|
return
|
||||||
|
picked = self.pick_locked()
|
||||||
|
if picked is None:
|
||||||
|
continue
|
||||||
|
key, op = picked
|
||||||
|
self._pending.pop(key, None)
|
||||||
|
started_at = self._clock()
|
||||||
|
try:
|
||||||
|
result = await self.execute_op(op)
|
||||||
|
except RetryAfter as exc:
|
||||||
|
self.retry_at = max(self.retry_at, self._clock() + exc.retry_after)
|
||||||
|
async with self._cond:
|
||||||
|
if self._closed:
|
||||||
|
op.set_result(None)
|
||||||
|
elif key not in self._pending:
|
||||||
|
self._pending[key] = op
|
||||||
|
self._cond.notify()
|
||||||
|
else:
|
||||||
|
op.set_result(None)
|
||||||
|
continue
|
||||||
|
self.next_at = started_at + self._interval_for_chat(op.chat_id)
|
||||||
|
op.set_result(result)
|
||||||
|
except cancel_exc:
|
||||||
|
return
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
async with self._cond:
|
||||||
|
self._closed = True
|
||||||
|
self.fail_pending()
|
||||||
|
self._cond.notify_all()
|
||||||
|
if self._on_outbox_error is not None:
|
||||||
|
self._on_outbox_error(exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from anyio.abc import TaskGroup
|
||||||
|
else:
|
||||||
|
TaskGroup = object
|
||||||
@@ -0,0 +1,283 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from collections.abc import AsyncIterator, Callable, Iterable
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from ..logging import get_logger
|
||||||
|
from .api_models import Update
|
||||||
|
from .client_api import BotClient
|
||||||
|
from .types import (
|
||||||
|
TelegramCallbackQuery,
|
||||||
|
TelegramDocument,
|
||||||
|
TelegramIncomingMessage,
|
||||||
|
TelegramIncomingUpdate,
|
||||||
|
TelegramVoice,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_incoming_update(
|
||||||
|
update: Update | dict[str, Any],
|
||||||
|
*,
|
||||||
|
chat_id: int | None = None,
|
||||||
|
chat_ids: set[int] | None = None,
|
||||||
|
) -> TelegramIncomingUpdate | None:
|
||||||
|
if isinstance(update, Update):
|
||||||
|
msg = update.message
|
||||||
|
callback_query = update.callback_query
|
||||||
|
else:
|
||||||
|
msg = update.get("message")
|
||||||
|
callback_query = update.get("callback_query")
|
||||||
|
|
||||||
|
if isinstance(msg, dict):
|
||||||
|
return _parse_incoming_message(msg, chat_id=chat_id, chat_ids=chat_ids)
|
||||||
|
if isinstance(callback_query, dict):
|
||||||
|
return _parse_callback_query(
|
||||||
|
callback_query,
|
||||||
|
chat_id=chat_id,
|
||||||
|
chat_ids=chat_ids,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_incoming_message(
|
||||||
|
msg: dict[str, Any],
|
||||||
|
*,
|
||||||
|
chat_id: int | None = None,
|
||||||
|
chat_ids: set[int] | None = None,
|
||||||
|
) -> TelegramIncomingMessage | None:
|
||||||
|
def _parse_document_payload(payload: dict[str, Any]) -> TelegramDocument | None:
|
||||||
|
file_id = payload.get("file_id")
|
||||||
|
if not isinstance(file_id, str) or not file_id:
|
||||||
|
return None
|
||||||
|
return TelegramDocument(
|
||||||
|
file_id=file_id,
|
||||||
|
file_name=payload.get("file_name")
|
||||||
|
if isinstance(payload.get("file_name"), str)
|
||||||
|
else None,
|
||||||
|
mime_type=payload.get("mime_type")
|
||||||
|
if isinstance(payload.get("mime_type"), str)
|
||||||
|
else None,
|
||||||
|
file_size=payload.get("file_size")
|
||||||
|
if isinstance(payload.get("file_size"), int)
|
||||||
|
and not isinstance(payload.get("file_size"), bool)
|
||||||
|
else None,
|
||||||
|
raw=payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_text = msg.get("text")
|
||||||
|
text = raw_text if isinstance(raw_text, str) else None
|
||||||
|
caption = msg.get("caption")
|
||||||
|
if text is None and isinstance(caption, str):
|
||||||
|
text = caption
|
||||||
|
if text is None:
|
||||||
|
text = ""
|
||||||
|
file_command = False
|
||||||
|
if isinstance(text, str):
|
||||||
|
stripped = text.lstrip()
|
||||||
|
if stripped.startswith("/"):
|
||||||
|
token = stripped.split(maxsplit=1)[0]
|
||||||
|
file_command = token.startswith("/file")
|
||||||
|
voice_payload: TelegramVoice | None = None
|
||||||
|
voice = msg.get("voice")
|
||||||
|
if isinstance(voice, dict):
|
||||||
|
file_id = voice.get("file_id")
|
||||||
|
if not isinstance(file_id, str) or not file_id:
|
||||||
|
file_id = None
|
||||||
|
if file_id is not None:
|
||||||
|
voice_payload = TelegramVoice(
|
||||||
|
file_id=file_id,
|
||||||
|
mime_type=voice.get("mime_type")
|
||||||
|
if isinstance(voice.get("mime_type"), str)
|
||||||
|
else None,
|
||||||
|
file_size=voice.get("file_size")
|
||||||
|
if isinstance(voice.get("file_size"), int)
|
||||||
|
and not isinstance(voice.get("file_size"), bool)
|
||||||
|
else None,
|
||||||
|
duration=voice.get("duration")
|
||||||
|
if isinstance(voice.get("duration"), int)
|
||||||
|
and not isinstance(voice.get("duration"), bool)
|
||||||
|
else None,
|
||||||
|
raw=voice,
|
||||||
|
)
|
||||||
|
if not isinstance(raw_text, str) and not isinstance(caption, str):
|
||||||
|
text = ""
|
||||||
|
document_payload: TelegramDocument | None = None
|
||||||
|
document = msg.get("document")
|
||||||
|
if isinstance(document, dict):
|
||||||
|
document_payload = _parse_document_payload(document)
|
||||||
|
if document_payload is None:
|
||||||
|
video = msg.get("video")
|
||||||
|
if isinstance(video, dict):
|
||||||
|
document_payload = _parse_document_payload(video)
|
||||||
|
if document_payload is None:
|
||||||
|
photo = msg.get("photo")
|
||||||
|
if isinstance(photo, list):
|
||||||
|
best: dict[str, Any] | None = None
|
||||||
|
best_score = -1
|
||||||
|
for item in photo:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
file_id = item.get("file_id")
|
||||||
|
if not isinstance(file_id, str) or not file_id:
|
||||||
|
continue
|
||||||
|
size = item.get("file_size")
|
||||||
|
if isinstance(size, int) and not isinstance(size, bool):
|
||||||
|
score = size
|
||||||
|
else:
|
||||||
|
width = item.get("width")
|
||||||
|
height = item.get("height")
|
||||||
|
if isinstance(width, int) and isinstance(height, int):
|
||||||
|
score = width * height
|
||||||
|
else:
|
||||||
|
score = 0
|
||||||
|
if score > best_score:
|
||||||
|
best_score = score
|
||||||
|
best = item
|
||||||
|
if best is not None:
|
||||||
|
document_payload = _parse_document_payload(best)
|
||||||
|
if document_payload is None and file_command:
|
||||||
|
sticker = msg.get("sticker")
|
||||||
|
if isinstance(sticker, dict):
|
||||||
|
document_payload = _parse_document_payload(sticker)
|
||||||
|
has_text = isinstance(raw_text, str) or isinstance(caption, str)
|
||||||
|
if not has_text and voice_payload is None and document_payload is None:
|
||||||
|
return None
|
||||||
|
chat = msg.get("chat")
|
||||||
|
if not isinstance(chat, dict):
|
||||||
|
return None
|
||||||
|
msg_chat_id = chat.get("id")
|
||||||
|
if not isinstance(msg_chat_id, int):
|
||||||
|
return None
|
||||||
|
chat_type = chat.get("type") if isinstance(chat.get("type"), str) else None
|
||||||
|
is_forum = chat.get("is_forum")
|
||||||
|
if not isinstance(is_forum, bool):
|
||||||
|
is_forum = None
|
||||||
|
allowed = chat_ids
|
||||||
|
if allowed is None and chat_id is not None:
|
||||||
|
allowed = {chat_id}
|
||||||
|
if allowed is not None and msg_chat_id not in allowed:
|
||||||
|
return None
|
||||||
|
message_id = msg.get("message_id")
|
||||||
|
if not isinstance(message_id, int):
|
||||||
|
return None
|
||||||
|
reply = msg.get("reply_to_message")
|
||||||
|
reply_to_message_id = None
|
||||||
|
reply_to_text = None
|
||||||
|
if isinstance(reply, dict):
|
||||||
|
reply_to_message_id = (
|
||||||
|
reply.get("message_id")
|
||||||
|
if isinstance(reply.get("message_id"), int)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
reply_to_text = (
|
||||||
|
reply.get("text") if isinstance(reply.get("text"), str) else None
|
||||||
|
)
|
||||||
|
sender = msg.get("from")
|
||||||
|
sender_id = (
|
||||||
|
sender.get("id")
|
||||||
|
if isinstance(sender, dict) and isinstance(sender.get("id"), int)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
media_group_id = msg.get("media_group_id")
|
||||||
|
if not isinstance(media_group_id, str):
|
||||||
|
media_group_id = None
|
||||||
|
thread_id = msg.get("message_thread_id")
|
||||||
|
if isinstance(thread_id, bool) or not isinstance(thread_id, int):
|
||||||
|
thread_id = None
|
||||||
|
is_topic_message = msg.get("is_topic_message")
|
||||||
|
if not isinstance(is_topic_message, bool):
|
||||||
|
is_topic_message = None
|
||||||
|
return TelegramIncomingMessage(
|
||||||
|
transport="telegram",
|
||||||
|
chat_id=msg_chat_id,
|
||||||
|
message_id=message_id,
|
||||||
|
text=text,
|
||||||
|
reply_to_message_id=reply_to_message_id,
|
||||||
|
reply_to_text=reply_to_text,
|
||||||
|
sender_id=sender_id,
|
||||||
|
media_group_id=media_group_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
is_topic_message=is_topic_message,
|
||||||
|
chat_type=chat_type,
|
||||||
|
is_forum=is_forum,
|
||||||
|
voice=voice_payload,
|
||||||
|
document=document_payload,
|
||||||
|
raw=msg,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_callback_query(
|
||||||
|
query: dict[str, Any],
|
||||||
|
*,
|
||||||
|
chat_id: int | None = None,
|
||||||
|
chat_ids: set[int] | None = None,
|
||||||
|
) -> TelegramCallbackQuery | None:
|
||||||
|
callback_id = query.get("id")
|
||||||
|
if not isinstance(callback_id, str) or not callback_id:
|
||||||
|
return None
|
||||||
|
msg = query.get("message")
|
||||||
|
if not isinstance(msg, dict):
|
||||||
|
return None
|
||||||
|
chat = msg.get("chat")
|
||||||
|
if not isinstance(chat, dict):
|
||||||
|
return None
|
||||||
|
msg_chat_id = chat.get("id")
|
||||||
|
if not isinstance(msg_chat_id, int):
|
||||||
|
return None
|
||||||
|
allowed = chat_ids
|
||||||
|
if allowed is None and chat_id is not None:
|
||||||
|
allowed = {chat_id}
|
||||||
|
if allowed is not None and msg_chat_id not in allowed:
|
||||||
|
return None
|
||||||
|
message_id = msg.get("message_id")
|
||||||
|
if not isinstance(message_id, int):
|
||||||
|
return None
|
||||||
|
data = query.get("data") if isinstance(query.get("data"), str) else None
|
||||||
|
sender = query.get("from")
|
||||||
|
sender_id = (
|
||||||
|
sender.get("id")
|
||||||
|
if isinstance(sender, dict) and isinstance(sender.get("id"), int)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return TelegramCallbackQuery(
|
||||||
|
transport="telegram",
|
||||||
|
chat_id=msg_chat_id,
|
||||||
|
message_id=message_id,
|
||||||
|
callback_query_id=callback_id,
|
||||||
|
data=data,
|
||||||
|
sender_id=sender_id,
|
||||||
|
raw=query,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def poll_incoming(
|
||||||
|
bot: BotClient,
|
||||||
|
*,
|
||||||
|
chat_id: int | None = None,
|
||||||
|
chat_ids: Iterable[int] | Callable[[], Iterable[int]] | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> AsyncIterator[TelegramIncomingUpdate]:
|
||||||
|
while True:
|
||||||
|
updates = await bot.get_updates(
|
||||||
|
offset=offset,
|
||||||
|
timeout_s=50,
|
||||||
|
allowed_updates=["message", "callback_query"],
|
||||||
|
)
|
||||||
|
if updates is None:
|
||||||
|
logger.info("loop.get_updates.failed")
|
||||||
|
await anyio.sleep(2)
|
||||||
|
continue
|
||||||
|
logger.debug("loop.updates", updates=updates)
|
||||||
|
resolved_chat_ids = chat_ids() if callable(chat_ids) else chat_ids
|
||||||
|
allowed = set(resolved_chat_ids) if resolved_chat_ids is not None else None
|
||||||
|
if allowed is None and chat_id is not None:
|
||||||
|
allowed = {chat_id}
|
||||||
|
for upd in updates:
|
||||||
|
offset = upd.update_id + 1
|
||||||
|
msg = parse_incoming_update(upd, chat_ids=allowed)
|
||||||
|
if msg is not None:
|
||||||
|
yield msg
|
||||||
@@ -1,14 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
from collections.abc import Callable
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Generic, Protocol, TypeVar
|
from typing import Any, Protocol
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import msgspec
|
import msgspec
|
||||||
|
|
||||||
T = TypeVar("T", bound="_VersionedState")
|
from ..utils.json_state import atomic_write_json
|
||||||
|
|
||||||
|
|
||||||
class _Logger(Protocol):
|
class _Logger(Protocol):
|
||||||
@@ -19,7 +18,7 @@ class _VersionedState(Protocol):
|
|||||||
version: int
|
version: int
|
||||||
|
|
||||||
|
|
||||||
class JsonStateStore(Generic[T]):
|
class JsonStateStore[T: _VersionedState]:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
path: Path,
|
path: Path,
|
||||||
@@ -84,11 +83,6 @@ class JsonStateStore(Generic[T]):
|
|||||||
self._state = payload
|
self._state = payload
|
||||||
|
|
||||||
def _save_locked(self) -> None:
|
def _save_locked(self) -> None:
|
||||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
payload = msgspec.to_builtins(self._state)
|
payload = msgspec.to_builtins(self._state)
|
||||||
tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp")
|
atomic_write_json(self._path, payload)
|
||||||
with open(tmp_path, "w", encoding="utf-8") as handle:
|
|
||||||
json.dump(payload, handle, indent=2, sort_keys=True)
|
|
||||||
handle.write("\n")
|
|
||||||
os.replace(tmp_path, self._path)
|
|
||||||
self._mtime_ns = self._stat_mtime_ns()
|
self._mtime_ns = self._stat_mtime_ns()
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Protocol, TypeAlias
|
from typing import Any, Protocol
|
||||||
|
|
||||||
ChannelId: TypeAlias = int | str
|
type ChannelId = int | str
|
||||||
MessageId: TypeAlias = int | str
|
type MessageId = int | str
|
||||||
ThreadId: TypeAlias = int | str
|
type ThreadId = int | str
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, TypeAlias
|
from typing import Any, Literal
|
||||||
|
|
||||||
from .config import ConfigError, ProjectsConfig
|
from .config import ConfigError, ProjectsConfig
|
||||||
from .context import RunContext
|
from .context import RunContext
|
||||||
@@ -19,7 +19,7 @@ from .router import AutoRouter, EngineStatus
|
|||||||
from .runner import Runner
|
from .runner import Runner
|
||||||
from .worktrees import WorktreeError, resolve_run_cwd
|
from .worktrees import WorktreeError, resolve_run_cwd
|
||||||
|
|
||||||
ContextSource: TypeAlias = Literal[
|
type ContextSource = Literal[
|
||||||
"reply_ctx",
|
"reply_ctx",
|
||||||
"directives",
|
"directives",
|
||||||
"ambient",
|
"ambient",
|
||||||
@@ -234,13 +234,13 @@ class TransportRuntime:
|
|||||||
project_key = ambient_context.project
|
project_key = ambient_context.project
|
||||||
else:
|
else:
|
||||||
project_key = default_project
|
project_key = default_project
|
||||||
if branch is None:
|
if (
|
||||||
if (
|
branch is None
|
||||||
ambient_context is not None
|
and ambient_context is not None
|
||||||
and ambient_context.branch is not None
|
and ambient_context.branch is not None
|
||||||
and project_key == ambient_context.project
|
and project_key == ambient_context.project
|
||||||
):
|
):
|
||||||
branch = ambient_context.branch
|
branch = ambient_context.branch
|
||||||
context: RunContext | None = None
|
context: RunContext | None = None
|
||||||
if project_key is not None or branch is not None:
|
if project_key is not None or branch is not None:
|
||||||
context = RunContext(project=project_key, branch=branch)
|
context = RunContext(project=project_key, branch=branch)
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from .backends import EngineBackend, SetupIssue
|
from .backends import EngineBackend, SetupIssue
|
||||||
from .plugins import TRANSPORT_GROUP, list_ids, load_plugin_backend
|
from .plugins import TRANSPORT_GROUP, list_ids, load_plugin_backend
|
||||||
@@ -34,7 +35,7 @@ class TransportBackend(Protocol):
|
|||||||
def interactive_setup(self, *, force: bool) -> bool: ...
|
def interactive_setup(self, *, force: bool) -> bool: ...
|
||||||
|
|
||||||
def lock_token(
|
def lock_token(
|
||||||
self, *, transport_config: object, config_path: Path
|
self, *, transport_config: object, _config_path: Path
|
||||||
) -> str | None: ...
|
) -> str | None: ...
|
||||||
|
|
||||||
def build_and_run(
|
def build_and_run(
|
||||||
|
|||||||
@@ -0,0 +1,21 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def atomic_write_json(
|
||||||
|
path: Path,
|
||||||
|
payload: Any,
|
||||||
|
*,
|
||||||
|
indent: int = 2,
|
||||||
|
sort_keys: bool = True,
|
||||||
|
) -> None:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp_path = path.with_suffix(f"{path.suffix}.tmp")
|
||||||
|
with open(tmp_path, "w", encoding="utf-8") as handle:
|
||||||
|
json.dump(payload, handle, indent=indent, sort_keys=sort_keys)
|
||||||
|
handle.write("\n")
|
||||||
|
os.replace(tmp_path, path)
|
||||||
+3
-3
@@ -14,7 +14,7 @@ from takopi.model import (
|
|||||||
|
|
||||||
|
|
||||||
def session_started(engine: str, value: str, title: str = "Codex") -> TakopiEvent:
|
def session_started(engine: str, value: str, title: str = "Codex") -> TakopiEvent:
|
||||||
engine_id = EngineId(engine)
|
engine_id: EngineId = engine
|
||||||
return StartedEvent(
|
return StartedEvent(
|
||||||
engine=engine_id,
|
engine=engine_id,
|
||||||
resume=ResumeToken(engine=engine_id, value=value),
|
resume=ResumeToken(engine=engine_id, value=value),
|
||||||
@@ -29,7 +29,7 @@ def action_started(
|
|||||||
detail: dict[str, Any] | None = None,
|
detail: dict[str, Any] | None = None,
|
||||||
engine: str = "codex",
|
engine: str = "codex",
|
||||||
) -> TakopiEvent:
|
) -> TakopiEvent:
|
||||||
engine_id = EngineId(engine)
|
engine_id: EngineId = engine
|
||||||
return ActionEvent(
|
return ActionEvent(
|
||||||
engine=engine_id,
|
engine=engine_id,
|
||||||
action=Action(
|
action=Action(
|
||||||
@@ -50,7 +50,7 @@ def action_completed(
|
|||||||
detail: dict[str, Any] | None = None,
|
detail: dict[str, Any] | None = None,
|
||||||
engine: str = "codex",
|
engine: str = "codex",
|
||||||
) -> TakopiEvent:
|
) -> TakopiEvent:
|
||||||
engine_id = EngineId(engine)
|
engine_id: EngineId = engine
|
||||||
return ActionEvent(
|
return ActionEvent(
|
||||||
engine=engine_id,
|
engine=engine_id,
|
||||||
action=Action(
|
action=Action(
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Iterable
|
from typing import Any
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import pytest
|
|||||||
|
|
||||||
from takopi.runner_bridge import ExecBridgeConfig, IncomingMessage, handle_message
|
from takopi.runner_bridge import ExecBridgeConfig, IncomingMessage, handle_message
|
||||||
from takopi.markdown import MarkdownParts, MarkdownPresenter
|
from takopi.markdown import MarkdownParts, MarkdownPresenter
|
||||||
from takopi.model import EngineId, ResumeToken, TakopiEvent
|
from takopi.model import ResumeToken, TakopiEvent
|
||||||
from takopi.telegram.render import prepare_telegram
|
from takopi.telegram.render import prepare_telegram
|
||||||
from takopi.runners.codex import CodexRunner
|
from takopi.runners.codex import CodexRunner
|
||||||
from takopi.runners.mock import Advance, Emit, Raise, Return, ScriptRunner, Wait
|
from takopi.runners.mock import Advance, Emit, Raise, Return, ScriptRunner, Wait
|
||||||
@@ -13,7 +13,7 @@ from takopi.settings import load_settings, require_telegram
|
|||||||
from takopi.transport import MessageRef, RenderedMessage, SendOptions
|
from takopi.transport import MessageRef, RenderedMessage, SendOptions
|
||||||
from tests.factories import action_completed, action_started
|
from tests.factories import action_completed, action_started
|
||||||
|
|
||||||
CODEX_ENGINE = EngineId("codex")
|
CODEX_ENGINE = "codex"
|
||||||
|
|
||||||
|
|
||||||
class _FakeTransport:
|
class _FakeTransport:
|
||||||
|
|||||||
@@ -7,14 +7,13 @@ from collections.abc import AsyncIterator
|
|||||||
from takopi.model import (
|
from takopi.model import (
|
||||||
ActionEvent,
|
ActionEvent,
|
||||||
CompletedEvent,
|
CompletedEvent,
|
||||||
EngineId,
|
|
||||||
ResumeToken,
|
ResumeToken,
|
||||||
StartedEvent,
|
StartedEvent,
|
||||||
TakopiEvent,
|
TakopiEvent,
|
||||||
)
|
)
|
||||||
from takopi.runners.codex import CodexRunner, find_exec_only_flag
|
from takopi.runners.codex import CodexRunner, find_exec_only_flag
|
||||||
|
|
||||||
CODEX_ENGINE = EngineId("codex")
|
CODEX_ENGINE = "codex"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
|||||||
@@ -45,11 +45,10 @@ def test_resolve_default_base_prefers_master_over_main(monkeypatch) -> None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _fake_ok(args, **kwargs):
|
def _fake_ok(args, **kwargs):
|
||||||
if args == ["show-ref", "--verify", "--quiet", "refs/heads/master"]:
|
return args in (
|
||||||
return True
|
["show-ref", "--verify", "--quiet", "refs/heads/master"],
|
||||||
if args == ["show-ref", "--verify", "--quiet", "refs/heads/main"]:
|
["show-ref", "--verify", "--quiet", "refs/heads/main"],
|
||||||
return True
|
)
|
||||||
return False
|
|
||||||
|
|
||||||
monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout)
|
monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout)
|
||||||
monkeypatch.setattr("takopi.utils.git.git_ok", _fake_ok)
|
monkeypatch.setattr("takopi.utils.git.git_ok", _fake_ok)
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from takopi.model import (
|
|||||||
Action,
|
Action,
|
||||||
ActionEvent,
|
ActionEvent,
|
||||||
CompletedEvent,
|
CompletedEvent,
|
||||||
EngineId,
|
|
||||||
ResumeToken,
|
ResumeToken,
|
||||||
StartedEvent,
|
StartedEvent,
|
||||||
TakopiEvent,
|
TakopiEvent,
|
||||||
@@ -15,7 +14,7 @@ from takopi.model import (
|
|||||||
from takopi.runners.mock import Emit, Return, ScriptRunner, Wait
|
from takopi.runners.mock import Emit, Return, ScriptRunner, Wait
|
||||||
from tests.factories import action_started
|
from tests.factories import action_started
|
||||||
|
|
||||||
CODEX_ENGINE = EngineId("codex")
|
CODEX_ENGINE = "codex"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -84,7 +83,7 @@ async def test_runner_releases_lock_when_consumer_closes() -> None:
|
|||||||
gate = anyio.Event()
|
gate = anyio.Event()
|
||||||
runner = ScriptRunner([Wait(gate)], engine=CODEX_ENGINE, resume_value="sid")
|
runner = ScriptRunner([Wait(gate)], engine=CODEX_ENGINE, resume_value="sid")
|
||||||
|
|
||||||
gen = cast(AsyncGenerator[TakopiEvent, None], runner.run("hello", None))
|
gen = cast(AsyncGenerator[TakopiEvent], runner.run("hello", None))
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
evt = await anext(gen)
|
evt = await anext(gen)
|
||||||
@@ -94,7 +93,7 @@ async def test_runner_releases_lock_when_consumer_closes() -> None:
|
|||||||
await gen.aclose()
|
await gen.aclose()
|
||||||
|
|
||||||
gen2 = cast(
|
gen2 = cast(
|
||||||
AsyncGenerator[TakopiEvent, None],
|
AsyncGenerator[TakopiEvent],
|
||||||
runner.run("again", ResumeToken(engine=CODEX_ENGINE, value="sid")),
|
runner.run("again", ResumeToken(engine=CODEX_ENGINE, value="sid")),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import takopi.runner as runner_module
|
|||||||
from takopi.model import (
|
from takopi.model import (
|
||||||
ActionEvent,
|
ActionEvent,
|
||||||
CompletedEvent,
|
CompletedEvent,
|
||||||
EngineId,
|
|
||||||
ResumeToken,
|
ResumeToken,
|
||||||
StartedEvent,
|
StartedEvent,
|
||||||
TakopiEvent,
|
TakopiEvent,
|
||||||
@@ -22,7 +21,7 @@ from takopi.runner import (
|
|||||||
|
|
||||||
|
|
||||||
class _DummyRunner(ResumeTokenMixin, BaseRunner):
|
class _DummyRunner(ResumeTokenMixin, BaseRunner):
|
||||||
engine = EngineId("dummy")
|
engine = "dummy"
|
||||||
resume_re = re.compile(r"(?im)^`?dummy resume (?P<token>[^`\s]+)`?$")
|
resume_re = re.compile(r"(?im)^`?dummy resume (?P<token>[^`\s]+)`?$")
|
||||||
|
|
||||||
async def run_impl(
|
async def run_impl(
|
||||||
@@ -39,7 +38,7 @@ class _DummyRunner(ResumeTokenMixin, BaseRunner):
|
|||||||
|
|
||||||
|
|
||||||
class _DummyJsonlRunner(JsonlSubprocessRunner):
|
class _DummyJsonlRunner(JsonlSubprocessRunner):
|
||||||
engine = EngineId("dummy-jsonl")
|
engine = "dummy-jsonl"
|
||||||
|
|
||||||
def command(self) -> str:
|
def command(self) -> str:
|
||||||
return "dummy"
|
return "dummy"
|
||||||
@@ -67,7 +66,7 @@ class _DummyJsonlRunner(JsonlSubprocessRunner):
|
|||||||
|
|
||||||
|
|
||||||
class _BareJsonlRunner(JsonlSubprocessRunner):
|
class _BareJsonlRunner(JsonlSubprocessRunner):
|
||||||
engine = EngineId("bare-jsonl")
|
engine = "bare-jsonl"
|
||||||
|
|
||||||
|
|
||||||
class _RunJsonlRunner(_DummyJsonlRunner):
|
class _RunJsonlRunner(_DummyJsonlRunner):
|
||||||
@@ -177,7 +176,7 @@ async def test_base_runner_run_locked_handles_resume() -> None:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_base_runner_rejects_wrong_resume_engine() -> None:
|
async def test_base_runner_rejects_wrong_resume_engine() -> None:
|
||||||
runner = _DummyRunner()
|
runner = _DummyRunner()
|
||||||
bad_resume = ResumeToken(engine=EngineId("other"), value="oops")
|
bad_resume = ResumeToken(engine="other", value="oops")
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
_ = [evt async for evt in runner.run("hello", bad_resume)]
|
_ = [evt async for evt in runner.run("hello", bad_resume)]
|
||||||
|
|
||||||
@@ -185,7 +184,7 @@ async def test_base_runner_rejects_wrong_resume_engine() -> None:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_base_runner_run_impl_not_implemented() -> None:
|
async def test_base_runner_run_impl_not_implemented() -> None:
|
||||||
class _BareRunner(BaseRunner):
|
class _BareRunner(BaseRunner):
|
||||||
engine = EngineId("bare")
|
engine = "bare"
|
||||||
|
|
||||||
runner = _BareRunner()
|
runner = _BareRunner()
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
@@ -204,7 +203,7 @@ def test_resume_token_format_and_extract() -> None:
|
|||||||
assert runner.extract_resume(None) is None
|
assert runner.extract_resume(None) is None
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
runner.format_resume(ResumeToken(engine=EngineId("other"), value="bad"))
|
runner.format_resume(ResumeToken(engine="other", value="bad"))
|
||||||
|
|
||||||
|
|
||||||
def test_session_lock_reuse() -> None:
|
def test_session_lock_reuse() -> None:
|
||||||
@@ -294,7 +293,7 @@ def test_jsonl_helpers() -> None:
|
|||||||
assert found == resume
|
assert found == resume
|
||||||
assert emit is False
|
assert emit is False
|
||||||
|
|
||||||
mismatch = StartedEvent(engine=EngineId("other"), resume=resume, title="t")
|
mismatch = StartedEvent(engine="other", resume=resume, title="t")
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
runner.handle_started_event(mismatch, expected_session=None, found_session=None)
|
runner.handle_started_event(mismatch, expected_session=None, found_session=None)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import Any
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from takopi.config import ProjectsConfig
|
from takopi.config import ProjectsConfig
|
||||||
from takopi.model import EngineId
|
|
||||||
from takopi.router import AutoRouter, RunnerEntry
|
from takopi.router import AutoRouter, RunnerEntry
|
||||||
from takopi.runners.mock import Return, ScriptRunner
|
from takopi.runners.mock import Return, ScriptRunner
|
||||||
from takopi.settings import (
|
from takopi.settings import (
|
||||||
@@ -19,8 +18,8 @@ from takopi.transport_runtime import TransportRuntime
|
|||||||
|
|
||||||
|
|
||||||
def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
|
def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
|
||||||
codex = EngineId("codex")
|
codex = "codex"
|
||||||
pi = EngineId("pi")
|
pi = "pi"
|
||||||
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
||||||
missing = ScriptRunner([Return(answer="ok")], engine=pi)
|
missing = ScriptRunner([Return(answer="ok")], engine=pi)
|
||||||
router = AutoRouter(
|
router = AutoRouter(
|
||||||
@@ -53,9 +52,9 @@ def test_build_startup_message_includes_missing_engines(tmp_path: Path) -> None:
|
|||||||
def test_build_startup_message_surfaces_unavailable_engine_reasons(
|
def test_build_startup_message_surfaces_unavailable_engine_reasons(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
codex = EngineId("codex")
|
codex = "codex"
|
||||||
pi = EngineId("pi")
|
pi = "pi"
|
||||||
claude = EngineId("claude")
|
claude = "claude"
|
||||||
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
||||||
bad_cfg = ScriptRunner([Return(answer="ok")], engine=pi)
|
bad_cfg = ScriptRunner([Return(answer="ok")], engine=pi)
|
||||||
load_err = ScriptRunner([Return(answer="ok")], engine=claude)
|
load_err = ScriptRunner([Return(answer="ok")], engine=claude)
|
||||||
@@ -100,7 +99,7 @@ def test_telegram_backend_build_and_run_wires_config(
|
|||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
|
|
||||||
codex = EngineId("codex")
|
codex = "codex"
|
||||||
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
runner = ScriptRunner([Return(answer="ok")], engine=codex)
|
||||||
router = AutoRouter(
|
router = AutoRouter(
|
||||||
entries=[RunnerEntry(engine=codex, runner=runner)],
|
entries=[RunnerEntry(engine=codex, runner=runner)],
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ import anyio
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from takopi import commands, plugins
|
from takopi import commands, plugins
|
||||||
|
from takopi.telegram.commands.executor import _CaptureTransport, _run_engine
|
||||||
|
from takopi.telegram.commands.file_transfer import _handle_file_get, _handle_file_put
|
||||||
import takopi.telegram.loop as telegram_loop
|
import takopi.telegram.loop as telegram_loop
|
||||||
import takopi.telegram.commands as telegram_commands
|
|
||||||
import takopi.telegram.topics as telegram_topics
|
import takopi.telegram.topics as telegram_topics
|
||||||
from takopi.directives import parse_directives
|
from takopi.directives import parse_directives
|
||||||
from takopi.telegram.api_models import (
|
from takopi.telegram.api_models import (
|
||||||
@@ -39,7 +40,7 @@ from takopi.context import RunContext
|
|||||||
from takopi.config import ProjectConfig, ProjectsConfig
|
from takopi.config import ProjectConfig, ProjectsConfig
|
||||||
from takopi.runner_bridge import ExecBridgeConfig, RunningTask
|
from takopi.runner_bridge import ExecBridgeConfig, RunningTask
|
||||||
from takopi.markdown import MarkdownPresenter
|
from takopi.markdown import MarkdownPresenter
|
||||||
from takopi.model import EngineId, ResumeToken
|
from takopi.model import ResumeToken
|
||||||
from takopi.progress import ProgressTracker
|
from takopi.progress import ProgressTracker
|
||||||
from takopi.router import AutoRouter, RunnerEntry
|
from takopi.router import AutoRouter, RunnerEntry
|
||||||
from takopi.transport_runtime import TransportRuntime
|
from takopi.transport_runtime import TransportRuntime
|
||||||
@@ -52,7 +53,7 @@ from takopi.telegram.types import (
|
|||||||
from takopi.transport import MessageRef, RenderedMessage, SendOptions
|
from takopi.transport import MessageRef, RenderedMessage, SendOptions
|
||||||
from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints
|
from tests.plugin_fixtures import FakeEntryPoint, install_entrypoints
|
||||||
|
|
||||||
CODEX_ENGINE = EngineId("codex")
|
CODEX_ENGINE = "codex"
|
||||||
|
|
||||||
|
|
||||||
def _empty_projects() -> ProjectsConfig:
|
def _empty_projects() -> ProjectsConfig:
|
||||||
@@ -905,9 +906,7 @@ async def test_handle_file_put_writes_file(tmp_path: Path) -> None:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
await telegram_commands._handle_file_put(
|
await _handle_file_put(cfg, msg, "/proj uploads/hello.txt", None, None)
|
||||||
cfg, msg, "/proj uploads/hello.txt", None, None
|
|
||||||
)
|
|
||||||
|
|
||||||
target = tmp_path / "uploads" / "hello.txt"
|
target = tmp_path / "uploads" / "hello.txt"
|
||||||
assert target.read_bytes() == payload
|
assert target.read_bytes() == payload
|
||||||
@@ -966,7 +965,7 @@ async def test_handle_file_get_sends_document_for_allowed_user(
|
|||||||
chat_type="supergroup",
|
chat_type="supergroup",
|
||||||
)
|
)
|
||||||
|
|
||||||
await telegram_commands._handle_file_get(cfg, msg, "/proj hello.txt", None, None)
|
await _handle_file_get(cfg, msg, "/proj hello.txt", None, None)
|
||||||
|
|
||||||
assert bot.document_calls
|
assert bot.document_calls
|
||||||
assert bot.document_calls[0]["filename"] == "hello.txt"
|
assert bot.document_calls[0]["filename"] == "hello.txt"
|
||||||
@@ -1263,7 +1262,7 @@ async def test_send_with_resume_reports_when_missing() -> None:
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_run_engine_hides_resume_line_in_topics() -> None:
|
async def test_run_engine_hides_resume_line_in_topics() -> None:
|
||||||
transport = telegram_commands._CaptureTransport()
|
transport = _CaptureTransport()
|
||||||
runner = ScriptRunner(
|
runner = ScriptRunner(
|
||||||
[Return(answer="ok")],
|
[Return(answer="ok")],
|
||||||
engine=CODEX_ENGINE,
|
engine=CODEX_ENGINE,
|
||||||
@@ -1279,7 +1278,7 @@ async def test_run_engine_hides_resume_line_in_topics() -> None:
|
|||||||
projects=_empty_projects(),
|
projects=_empty_projects(),
|
||||||
)
|
)
|
||||||
|
|
||||||
await telegram_commands._run_engine(
|
await _run_engine(
|
||||||
exec_cfg=exec_cfg,
|
exec_cfg=exec_cfg,
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
running_tasks={},
|
running_tasks={},
|
||||||
@@ -1456,14 +1455,14 @@ async def test_run_main_loop_auto_resumes_topic_default_engine(
|
|||||||
123, 77, ResumeToken(engine=CODEX_ENGINE, value="resume-codex")
|
123, 77, ResumeToken(engine=CODEX_ENGINE, value="resume-codex")
|
||||||
)
|
)
|
||||||
await store.set_session_resume(
|
await store.set_session_resume(
|
||||||
123, 77, ResumeToken(engine=EngineId("claude"), value="resume-claude")
|
123, 77, ResumeToken(engine="claude", value="resume-claude")
|
||||||
)
|
)
|
||||||
await store.set_default_engine(123, 77, "claude")
|
await store.set_default_engine(123, 77, "claude")
|
||||||
|
|
||||||
transport = _FakeTransport()
|
transport = _FakeTransport()
|
||||||
bot = _FakeBot()
|
bot = _FakeBot()
|
||||||
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
|
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
|
||||||
claude_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("claude"))
|
claude_runner = ScriptRunner([Return(answer="ok")], engine="claude")
|
||||||
router = AutoRouter(
|
router = AutoRouter(
|
||||||
entries=[
|
entries=[
|
||||||
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
|
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
|
||||||
@@ -1521,7 +1520,7 @@ async def test_run_main_loop_auto_resumes_topic_default_engine(
|
|||||||
assert codex_runner.calls == []
|
assert codex_runner.calls == []
|
||||||
assert len(claude_runner.calls) == 1
|
assert len(claude_runner.calls) == 1
|
||||||
assert claude_runner.calls[0][1] == ResumeToken(
|
assert claude_runner.calls[0][1] == ResumeToken(
|
||||||
engine=EngineId("claude"), value="resume-claude"
|
engine="claude", value="resume-claude"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -2443,7 +2442,7 @@ async def test_run_main_loop_command_uses_project_default_engine(
|
|||||||
transport = _FakeTransport()
|
transport = _FakeTransport()
|
||||||
bot = _FakeBot()
|
bot = _FakeBot()
|
||||||
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
|
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
|
||||||
pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
|
pi_runner = ScriptRunner([Return(answer="ok")], engine="pi")
|
||||||
router = AutoRouter(
|
router = AutoRouter(
|
||||||
entries=[
|
entries=[
|
||||||
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
|
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
|
||||||
@@ -2525,7 +2524,7 @@ async def test_run_main_loop_command_defaults_to_chat_project(
|
|||||||
transport = _FakeTransport()
|
transport = _FakeTransport()
|
||||||
bot = _FakeBot()
|
bot = _FakeBot()
|
||||||
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
|
codex_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
|
||||||
pi_runner = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
|
pi_runner = ScriptRunner([Return(answer="ok")], engine="pi")
|
||||||
router = AutoRouter(
|
router = AutoRouter(
|
||||||
entries=[
|
entries=[
|
||||||
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
|
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import pytest
|
|||||||
|
|
||||||
from takopi.logging import setup_logging
|
from takopi.logging import setup_logging
|
||||||
from takopi.telegram.client import TelegramClient, TelegramRetryAfter
|
from takopi.telegram.client import TelegramClient, TelegramRetryAfter
|
||||||
|
from takopi.telegram.client_api import HttpBotClient
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -25,9 +26,9 @@ async def test_telegram_429_no_retry() -> None:
|
|||||||
|
|
||||||
client = httpx.AsyncClient(transport=transport)
|
client = httpx.AsyncClient(transport=transport)
|
||||||
try:
|
try:
|
||||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
|
||||||
with pytest.raises(TelegramRetryAfter) as exc:
|
with pytest.raises(TelegramRetryAfter) as exc:
|
||||||
await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
await api._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||||
finally:
|
finally:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
@@ -49,8 +50,8 @@ async def test_no_token_in_logs_on_http_error(
|
|||||||
|
|
||||||
client = httpx.AsyncClient(transport=transport)
|
client = httpx.AsyncClient(transport=transport)
|
||||||
try:
|
try:
|
||||||
tg = TelegramClient(token, http_client=client)
|
api = HttpBotClient(token, http_client=client)
|
||||||
await tg._post("getUpdates", {"timeout": 1})
|
await api._post("getUpdates", {"timeout": 1})
|
||||||
finally:
|
finally:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
@@ -79,9 +80,9 @@ async def test_telegram_429_no_retry_post_form() -> None:
|
|||||||
|
|
||||||
client = httpx.AsyncClient(transport=transport)
|
client = httpx.AsyncClient(transport=transport)
|
||||||
try:
|
try:
|
||||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
|
||||||
with pytest.raises(TelegramRetryAfter) as exc:
|
with pytest.raises(TelegramRetryAfter) as exc:
|
||||||
await tg._post_form(
|
await api._post_form(
|
||||||
"sendDocument",
|
"sendDocument",
|
||||||
{"chat_id": 1},
|
{"chat_id": 1},
|
||||||
files={"document": ("note.txt", b"hi")},
|
files={"document": ("note.txt", b"hi")},
|
||||||
@@ -102,9 +103,9 @@ async def test_telegram_429_defaults_retry_after_on_bad_body() -> None:
|
|||||||
|
|
||||||
client = httpx.AsyncClient(transport=transport)
|
client = httpx.AsyncClient(transport=transport)
|
||||||
try:
|
try:
|
||||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
|
||||||
with pytest.raises(TelegramRetryAfter) as exc:
|
with pytest.raises(TelegramRetryAfter) as exc:
|
||||||
await tg._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
await api._post("sendMessage", {"chat_id": 1, "text": "hi"})
|
||||||
finally:
|
finally:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
@@ -124,8 +125,8 @@ async def test_telegram_ok_false_returns_none() -> None:
|
|||||||
|
|
||||||
client = httpx.AsyncClient(transport=transport)
|
client = httpx.AsyncClient(transport=transport)
|
||||||
try:
|
try:
|
||||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
|
||||||
result = await tg._post("getUpdates", {"timeout": 1})
|
result = await api._post("getUpdates", {"timeout": 1})
|
||||||
finally:
|
finally:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
@@ -141,8 +142,8 @@ async def test_telegram_invalid_payload_returns_none() -> None:
|
|||||||
|
|
||||||
client = httpx.AsyncClient(transport=transport)
|
client = httpx.AsyncClient(transport=transport)
|
||||||
try:
|
try:
|
||||||
tg = TelegramClient("123:abcDEF_ghij", http_client=client)
|
api = HttpBotClient("123:abcDEF_ghij", http_client=client)
|
||||||
result = await tg._post("getUpdates", {"timeout": 1})
|
result = await api._post("getUpdates", {"timeout": 1})
|
||||||
finally:
|
finally:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import pytest
|
|||||||
|
|
||||||
from takopi.config import ProjectConfig, ProjectsConfig
|
from takopi.config import ProjectConfig, ProjectsConfig
|
||||||
from takopi.context import RunContext
|
from takopi.context import RunContext
|
||||||
from takopi.model import EngineId
|
|
||||||
from takopi.router import AutoRouter, RunnerEntry
|
from takopi.router import AutoRouter, RunnerEntry
|
||||||
from takopi.runners.mock import Return, ScriptRunner
|
from takopi.runners.mock import Return, ScriptRunner
|
||||||
from takopi.telegram.chat_prefs import ChatPrefsStore
|
from takopi.telegram.chat_prefs import ChatPrefsStore
|
||||||
@@ -15,8 +14,8 @@ from takopi.transport_runtime import TransportRuntime
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_resolve_engine_for_message_sources(tmp_path) -> None:
|
async def test_resolve_engine_for_message_sources(tmp_path) -> None:
|
||||||
codex = ScriptRunner([Return(answer="ok")], engine=EngineId("codex"))
|
codex = ScriptRunner([Return(answer="ok")], engine="codex")
|
||||||
pi = ScriptRunner([Return(answer="ok")], engine=EngineId("pi"))
|
pi = ScriptRunner([Return(answer="ok")], engine="pi")
|
||||||
router = AutoRouter(
|
router = AutoRouter(
|
||||||
entries=[
|
entries=[
|
||||||
RunnerEntry(engine=codex.engine, runner=codex),
|
RunnerEntry(engine=codex.engine, runner=codex),
|
||||||
@@ -42,7 +41,7 @@ async def test_resolve_engine_for_message_sources(tmp_path) -> None:
|
|||||||
resolved = await resolve_engine_for_message(
|
resolved = await resolve_engine_for_message(
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
context=RunContext(project="proj"),
|
context=RunContext(project="proj"),
|
||||||
explicit_engine=EngineId("codex"),
|
explicit_engine="codex",
|
||||||
chat_id=1,
|
chat_id=1,
|
||||||
topic_key=(1, 10),
|
topic_key=(1, 10),
|
||||||
topic_store=topic_store,
|
topic_store=topic_store,
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ class DummyTransport:
|
|||||||
def interactive_setup(self, *, force: bool) -> bool:
|
def interactive_setup(self, *, force: bool) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def lock_token(self, *, transport_config: object, config_path):
|
def lock_token(self, *, transport_config: object, _config_path):
|
||||||
_ = transport_config, config_path
|
_ = transport_config, _config_path
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def build_and_run(
|
def build_and_run(
|
||||||
|
|||||||
Reference in New Issue
Block a user