fix: code review (#16)

This commit is contained in:
banteg
2026-01-01 23:14:17 +04:00
committed by GitHub
parent 035441c889
commit 12dfaded26
14 changed files with 78 additions and 58 deletions
+2 -1
View File
@@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any
if TYPE_CHECKING: if TYPE_CHECKING:
from .runner import Runner from .runner import Runner
+4 -13
View File
@@ -4,7 +4,6 @@ from __future__ import annotations
import logging import logging
import time import time
import inspect
from collections import deque from collections import deque
from collections.abc import AsyncIterator, Awaitable, Callable from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -28,12 +27,6 @@ def _resolve_resume(
return runner.extract_resume(text) or runner.extract_resume(reply_text) return runner.extract_resume(text) or runner.extract_resume(reply_text)
def _summarize_error(error: str | None) -> str:
if not error:
return "error"
return error
def _log_runner_event(evt: TakopiEvent) -> None: def _log_runner_event(evt: TakopiEvent) -> None:
for line in render_event_cli(evt): for line in render_event_cli(evt):
logger.info("[runner] %s", line) logger.info("[runner] %s", line)
@@ -41,7 +34,7 @@ def _log_runner_event(evt: TakopiEvent) -> None:
if evt.ok: if evt.ok:
logger.info("[runner] done") logger.info("[runner] done")
else: else:
logger.info("[runner] error: %s", _summarize_error(evt.error)) logger.info("[runner] error: %s", evt.error or "error")
def _is_cancel_command(text: str) -> bool: def _is_cancel_command(text: str) -> bool:
@@ -516,7 +509,7 @@ async def handle_message(
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id) await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
async def poll_updates(cfg: BridgeConfig): async def poll_updates(cfg: BridgeConfig) -> AsyncIterator[dict[str, Any]]:
offset: int | None = None offset: int | None = None
offset = await _drain_backlog(cfg, offset) offset = await _drain_backlog(cfg, offset)
await _send_startup(cfg) await _send_startup(cfg)
@@ -605,7 +598,7 @@ async def _wait_for_resume(running_task: RunningTask) -> ResumeToken | None:
async def _send_with_resume( async def _send_with_resume(
bot: BotClient, bot: BotClient,
enqueue: Callable[[int, int, str, ResumeToken], Awaitable[None] | None], enqueue: Callable[[int, int, str, ResumeToken], Awaitable[None]],
running_task: RunningTask, running_task: RunningTask,
chat_id: int, chat_id: int,
user_msg_id: int, user_msg_id: int,
@@ -620,9 +613,7 @@ async def _send_with_resume(
disable_notification=True, disable_notification=True,
) )
return return
result = enqueue(chat_id, user_msg_id, text, resume) await enqueue(chat_id, user_msg_id, text, resume)
if inspect.isawaitable(result):
await result
async def _run_main_loop( async def _run_main_loop(
+1 -1
View File
@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
from typing import Callable from collections.abc import Callable
import anyio import anyio
import typer import typer
+13 -9
View File
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import shutil
import tomllib import tomllib
from pathlib import Path from pathlib import Path
@@ -38,8 +39,7 @@ def _maybe_migrate_legacy(legacy_path: Path, target_path: Path) -> None:
return return
try: try:
target_path.parent.mkdir(parents=True, exist_ok=True) target_path.parent.mkdir(parents=True, exist_ok=True)
raw = legacy_path.read_text(encoding="utf-8") shutil.move(legacy_path, target_path)
target_path.write_text(raw, encoding="utf-8")
except OSError as e: except OSError as e:
raise ConfigError( raise ConfigError(
f"Failed to migrate legacy config {legacy_path} to {target_path}: {e}" f"Failed to migrate legacy config {legacy_path} to {target_path}: {e}"
@@ -64,19 +64,23 @@ def load_telegram_config(path: str | Path | None = None) -> tuple[dict, Path]:
cfg_path = Path(path).expanduser() cfg_path = Path(path).expanduser()
return _read_config(cfg_path), cfg_path return _read_config(cfg_path), cfg_path
for legacy, target in zip(_legacy_candidates(), _config_candidates(), strict=True): config_candidates = _config_candidates()
legacy_candidates = _legacy_candidates()
for legacy, target in zip(legacy_candidates, config_candidates, strict=True):
_maybe_migrate_legacy(legacy, target) _maybe_migrate_legacy(legacy, target)
candidates = _config_candidates() for candidate in config_candidates:
for candidate in candidates:
if candidate.is_file(): if candidate.is_file():
return _read_config(candidate), candidate return _read_config(candidate), candidate
legacy_candidates = _legacy_candidates()
for candidate in legacy_candidates: for candidate in legacy_candidates:
if candidate.is_file(): if candidate.is_file():
return _read_config(candidate), candidate return _read_config(candidate), candidate
if len(candidates) == 1: checked: list[Path] = []
raise ConfigError("Missing takopi config.") for candidate in [*config_candidates, *legacy_candidates]:
raise ConfigError("Missing takopi config.") if candidate in checked:
continue
checked.append(candidate)
checked_display = ", ".join(str(candidate) for candidate in checked)
raise ConfigError(f"Missing takopi config. Checked: {checked_display}")
+13 -16
View File
@@ -2,14 +2,15 @@ from __future__ import annotations
import importlib import importlib
import pkgutil import pkgutil
from collections.abc import Mapping
from functools import cache
from pathlib import Path from pathlib import Path
from types import MappingProxyType
from typing import Any from typing import Any
from .backends import EngineBackend, EngineConfig from .backends import EngineBackend, EngineConfig
from .config import ConfigError from .config import ConfigError
_BACKENDS: dict[str, EngineBackend] | None = None
def _discover_backends() -> dict[str, EngineBackend]: def _discover_backends() -> dict[str, EngineBackend]:
import takopi.runners as runners_pkg import takopi.runners as runners_pkg
@@ -33,34 +34,30 @@ def _discover_backends() -> dict[str, EngineBackend]:
return backends return backends
def _ensure_loaded() -> None: @cache
global _BACKENDS def _backends() -> Mapping[str, EngineBackend]:
if _BACKENDS is None: backends = _discover_backends()
_BACKENDS = _discover_backends() return MappingProxyType(backends)
def get_backend(engine_id: str) -> EngineBackend: def get_backend(engine_id: str) -> EngineBackend:
_ensure_loaded() backends = _backends()
assert _BACKENDS is not None
try: try:
return _BACKENDS[engine_id] return backends[engine_id]
except KeyError as exc: except KeyError as exc:
available = ", ".join(sorted(_BACKENDS)) available = ", ".join(sorted(backends))
raise ConfigError( raise ConfigError(
f"Unknown engine {engine_id!r}. Available: {available}." f"Unknown engine {engine_id!r}. Available: {available}."
) from exc ) from exc
def list_backends() -> list[EngineBackend]: def list_backends() -> list[EngineBackend]:
_ensure_loaded() backends = _backends()
assert _BACKENDS is not None return [backends[key] for key in sorted(backends)]
return [_BACKENDS[key] for key in sorted(_BACKENDS)]
def list_backend_ids() -> list[str]: def list_backend_ids() -> list[str]:
_ensure_loaded() return sorted(_backends())
assert _BACKENDS is not None
return sorted(_BACKENDS)
def get_engine_config( def get_engine_config(
+2 -1
View File
@@ -3,7 +3,8 @@
from __future__ import annotations from __future__ import annotations
import re import re
from typing import Any, Callable from collections.abc import Callable
from typing import Any
from markdown_it import MarkdownIt from markdown_it import MarkdownIt
from sulguk import transform_html from sulguk import transform_html
+1 -1
View File
@@ -4,8 +4,8 @@ from __future__ import annotations
import textwrap import textwrap
from collections import deque from collections import deque
from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Callable
from .model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent from .model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent
from .utils.paths import relativize_path from .utils.paths import relativize_path
+4 -4
View File
@@ -126,7 +126,7 @@ class BaseRunner(SessionLockMixin):
raise NotImplementedError raise NotImplementedError
@dataclass @dataclass(slots=True)
class JsonlRunState: class JsonlRunState:
note_seq: int = 0 note_seq: int = 0
@@ -302,12 +302,12 @@ class JsonlSubprocessRunner(BaseRunner):
tag = self.tag() tag = self.tag()
logger = self.get_logger() logger = self.get_logger()
args = [self.command(), *self.build_args(prompt, resume, state=state)] cmd = [self.command(), *self.build_args(prompt, resume, state=state)]
payload = self.stdin_payload(prompt, resume, state=state) payload = self.stdin_payload(prompt, resume, state=state)
env = self.env(state=state) env = self.env(state=state)
async with manage_subprocess( async with manage_subprocess(
*args, cmd,
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
@@ -318,7 +318,7 @@ class JsonlSubprocessRunner(BaseRunner):
if payload is not None and proc.stdin is None: if payload is not None and proc.stdin is None:
raise RuntimeError(self.pipes_error_message()) raise RuntimeError(self.pipes_error_message())
logger.debug("[%s] spawn pid=%s args=%r", tag, proc.pid, args) logger.debug("[%s] spawn pid=%s args=%r", tag, proc.pid, cmd)
if payload is not None: if payload is not None:
assert proc.stdin is not None assert proc.stdin is not None
+1 -1
View File
@@ -31,7 +31,7 @@ _RESUME_RE = re.compile(
) )
@dataclass @dataclass(slots=True)
class ClaudeStreamState: class ClaudeStreamState:
pending_actions: dict[str, Action] = field(default_factory=dict) pending_actions: dict[str, Action] = field(default_factory=dict)
last_assistant_text: str | None = None last_assistant_text: str | None = None
+1 -1
View File
@@ -381,7 +381,7 @@ def translate_codex_event(event: dict[str, Any], *, title: str) -> list[TakopiEv
return [] return []
@dataclass @dataclass(slots=True)
class CodexRunState: class CodexRunState:
note_seq: int = 0 note_seq: int = 0
final_answer: str | None = None final_answer: str | None = None
+19 -3
View File
@@ -72,14 +72,30 @@ class TelegramClient:
return None return None
try: try:
payload = resp.json() resp.raise_for_status()
except Exception as e: except httpx.HTTPStatusError as e:
body = resp.text
logger.error( logger.error(
"[telegram] bad response method=%s status=%s url=%s: %s", "[telegram] http error method=%s status=%s url=%s: %s body=%r",
method, method,
resp.status_code, resp.status_code,
resp.request.url, resp.request.url,
e, e,
body,
)
return None
try:
payload = resp.json()
except Exception as e:
body = resp.text
logger.error(
"[telegram] bad response method=%s status=%s url=%s: %s body=%r",
method,
resp.status_code,
resp.request.url,
e,
body,
) )
return None return None
+6 -2
View File
@@ -3,7 +3,9 @@ from __future__ import annotations
import logging import logging
import os import os
import signal import signal
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any
import anyio import anyio
from anyio.abc import Process from anyio.abc import Process
@@ -52,11 +54,13 @@ def kill_process(proc: Process) -> None:
@asynccontextmanager @asynccontextmanager
async def manage_subprocess(*args, **kwargs): async def manage_subprocess(
cmd: Sequence[str], **kwargs: Any
) -> AsyncIterator[Process]:
"""Ensure subprocesses receive SIGTERM, then SIGKILL after a 2s timeout.""" """Ensure subprocesses receive SIGTERM, then SIGKILL after a 2s timeout."""
if os.name == "posix": if os.name == "posix":
kwargs.setdefault("start_new_session", True) kwargs.setdefault("start_new_session", True)
proc = await anyio.open_process(args, **kwargs) proc = await anyio.open_process(cmd, **kwargs)
try: try:
yield proc yield proc
finally: finally:
+6 -2
View File
@@ -715,7 +715,9 @@ async def test_send_with_resume_waits_for_token() -> None:
bot = _FakeBot() bot = _FakeBot()
sent: list[tuple[int, int, str, ResumeToken | None]] = [] sent: list[tuple[int, int, str, ResumeToken | None]] = []
def enqueue(chat_id: int, user_msg_id: int, text: str, resume: ResumeToken) -> None: async def enqueue(
chat_id: int, user_msg_id: int, text: str, resume: ResumeToken
) -> None:
sent.append((chat_id, user_msg_id, text, resume)) sent.append((chat_id, user_msg_id, text, resume))
running_task = RunningTask() running_task = RunningTask()
@@ -748,7 +750,9 @@ async def test_send_with_resume_reports_when_missing() -> None:
bot = _FakeBot() bot = _FakeBot()
sent: list[tuple[int, int, str, ResumeToken | None]] = [] sent: list[tuple[int, int, str, ResumeToken | None]] = []
def enqueue(chat_id: int, user_msg_id: int, text: str, resume: ResumeToken) -> None: async def enqueue(
chat_id: int, user_msg_id: int, text: str, resume: ResumeToken
) -> None:
sent.append((chat_id, user_msg_id, text, resume)) sent.append((chat_id, user_msg_id, text, resume))
running_task = RunningTask() running_task = RunningTask()
+5 -3
View File
@@ -16,9 +16,11 @@ async def test_manage_subprocess_kills_when_terminate_times_out(
monkeypatch.setattr(subprocess_utils, "wait_for_process", fake_wait_for_process) monkeypatch.setattr(subprocess_utils, "wait_for_process", fake_wait_for_process)
async with subprocess_utils.manage_subprocess( async with subprocess_utils.manage_subprocess(
sys.executable, [
"-c", sys.executable,
"import signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(10)", "-c",
"import signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(10)",
]
) as proc: ) as proc:
assert proc.returncode is None assert proc.returncode is None