feat: projects and worktree management (#62)
This commit is contained in:
@@ -10,6 +10,19 @@
|
||||
|
||||
- add transport/presenter protocols plus transport-agnostic `exec_bridge`
|
||||
- move Telegram polling + wiring into `takopi.bridges.telegram` with transport/presenter adapters
|
||||
- add project configuration, directive parsing (`/project`, `@branch`), and `ctx:`-aware routing for runs
|
||||
- add `takopi init` to register project aliases from the main checkout (with worktree defaults)
|
||||
- resolve git worktrees on demand and run engine subprocesses in the project/worktree cwd
|
||||
- list configured projects in the startup banner
|
||||
- add a shared incoming message shape plus Telegram parsing helpers
|
||||
|
||||
### fixes
|
||||
|
||||
- render `ctx:` footer lines consistently (backticked + hard breaks) and include them in final messages
|
||||
|
||||
### docs
|
||||
|
||||
- add a projects/worktrees guide and document `takopi init` behavior in the README
|
||||
|
||||
## v0.8.0 (2026-01-05)
|
||||
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
# Projects and Worktrees
|
||||
|
||||
This doc covers project aliases, worktree behavior, and how Takopi resolves run
|
||||
context from messages.
|
||||
|
||||
## Overview
|
||||
|
||||
Projects let you give a repo an alias (used as `/alias` in messages) and opt into
|
||||
worktree-based runs via `@branch`.
|
||||
|
||||
- If no projects are configured, Takopi runs in the startup working directory.
|
||||
- If a project is configured, `@branch` resolves/creates a git worktree and runs
|
||||
the task in that worktree.
|
||||
- Progress/final messages include a `ctx:` footer when project context is active.
|
||||
|
||||
## Config schema
|
||||
|
||||
All config lives in `~/.takopi/takopi.toml`.
|
||||
|
||||
```toml
|
||||
default_engine = "codex" # optional
|
||||
default_project = "z80" # optional
|
||||
bot_token = "..." # required
|
||||
chat_id = 123 # required
|
||||
|
||||
[projects.z80]
|
||||
path = "~/dev/z80" # required (repo root)
|
||||
worktrees_dir = ".worktrees" # optional, default ".worktrees"
|
||||
default_engine = "codex" # optional, per-project override
|
||||
worktree_base = "master" # optional, base for new branches
|
||||
```
|
||||
|
||||
Validation rules:
|
||||
|
||||
- `projects` is optional.
|
||||
- Each project entry must include `path` (string, non-empty).
|
||||
- `default_project` must match a configured project alias.
|
||||
- Project aliases cannot collide with engine ids or reserved commands (`/cancel`).
|
||||
- `default_engine` and per-project `default_engine` must be valid engine ids.
|
||||
|
||||
## `takopi init`
|
||||
|
||||
`takopi init <alias>` registers the current repo as a project alias.
|
||||
|
||||
Important behavior:
|
||||
|
||||
- The stored `path` is the **main checkout** of the repo, even if you run
|
||||
`takopi init` inside a worktree. Takopi resolves the repo root via the git
|
||||
common dir and writes that path to `[projects.<alias>].path`.
|
||||
- `worktree_base` is set from the current repo using this resolution order:
|
||||
`origin/HEAD` → current branch → `master` → `main`.
|
||||
|
||||
## Directives and context resolution
|
||||
|
||||
Takopi parses the first non-empty line of a message for a directive prefix.
|
||||
|
||||
Supported directives:
|
||||
|
||||
- `/engine` or `/engine@bot`: chooses the engine
|
||||
- `/project`: chooses a project alias
|
||||
- `@branch`: chooses a git branch/worktree
|
||||
|
||||
Rules:
|
||||
|
||||
- Directives must be a contiguous prefix of the line; parsing stops at the first
|
||||
non-directive token.
|
||||
- At most one engine directive, one project directive, and one `@branch` are
|
||||
allowed (duplicates are errors).
|
||||
- If a reply contains a `ctx:` line, Takopi **ignores new directives** and uses
|
||||
the reply context.
|
||||
|
||||
## Context footer (`ctx:`)
|
||||
|
||||
When a run has project context, Takopi appends a footer line rendered as inline
|
||||
code (backticked):
|
||||
|
||||
- With branch: `` `ctx: <project> @ <branch>` ``
|
||||
- Without branch: `` `ctx: <project>` ``
|
||||
|
||||
The `ctx:` line is parsed from replies and takes precedence over new directives.
|
||||
|
||||
## Worktree resolution
|
||||
|
||||
When `@branch` is present:
|
||||
|
||||
```
|
||||
worktrees_root = <project.path> / <worktrees_dir>
|
||||
worktree_path = worktrees_root / <branch>
|
||||
```
|
||||
|
||||
Branch validation:
|
||||
|
||||
- Must be non-empty
|
||||
- Must not start with `/`
|
||||
- Must not contain `..` path segments
|
||||
- May include `/` (nested directories)
|
||||
- The resolved worktree path must stay within `worktrees_root`
|
||||
|
||||
Worktree creation rules:
|
||||
|
||||
1) If `worktree_path` exists:
|
||||
- It must be a git worktree or Takopi errors.
|
||||
2) If it does not exist:
|
||||
- If local branch exists: `git worktree add <path> <branch>`
|
||||
- Else if remote `origin/<branch>` exists:
|
||||
`git worktree add -b <branch> <path> origin/<branch>`
|
||||
- Else:
|
||||
`git worktree add -b <branch> <path> <base>`
|
||||
|
||||
Base branch selection:
|
||||
|
||||
1) `projects.<alias>.worktree_base` (if set)
|
||||
2) `origin/HEAD` (if present)
|
||||
3) current checked out branch
|
||||
4) `master` if it exists
|
||||
5) `main` if it exists
|
||||
6) otherwise error
|
||||
|
||||
When `@branch` is omitted:
|
||||
|
||||
- Takopi runs in `<project.path>` (the main checkout).
|
||||
|
||||
## Examples
|
||||
|
||||
Start a new thread in a worktree:
|
||||
|
||||
```
|
||||
/z80 @feat/streaming fix flaky test
|
||||
```
|
||||
|
||||
Reply to a progress message to continue in the same context:
|
||||
|
||||
```
|
||||
ctx: z80 @ feat/streaming
|
||||
```
|
||||
@@ -76,6 +76,28 @@ provider = "openai"
|
||||
extra_args = ["--no-color"]
|
||||
```
|
||||
|
||||
## projects (optional)
|
||||
|
||||
register the current repo as a project alias:
|
||||
|
||||
```sh
|
||||
takopi init z80
|
||||
```
|
||||
|
||||
`takopi init` writes the repo root to `[projects.<alias>].path`. if you run it inside a git worktree, it resolves the main checkout and records that path instead of the worktree.
|
||||
|
||||
example:
|
||||
|
||||
```toml
|
||||
default_project = "z80"
|
||||
|
||||
[projects.z80]
|
||||
path = "~/dev/z80"
|
||||
worktrees_dir = ".worktrees"
|
||||
default_engine = "codex"
|
||||
worktree_base = "master"
|
||||
```
|
||||
|
||||
## usage
|
||||
|
||||
start takopi in the repo you want to work on:
|
||||
|
||||
+126
-3
@@ -11,10 +11,17 @@ import typer
|
||||
|
||||
from . import __version__
|
||||
from .backends import EngineBackend
|
||||
from .config import ConfigError
|
||||
from .config import (
|
||||
ConfigError,
|
||||
load_or_init_config,
|
||||
parse_projects_config,
|
||||
write_config,
|
||||
)
|
||||
from .engines import get_backend, get_engine_config, list_backends
|
||||
from .lockfile import LockError, LockHandle, acquire_lock, token_fingerprint
|
||||
from .logging import get_logger, setup_logging
|
||||
from .router import AutoRouter, RunnerEntry
|
||||
from .runner_bridge import ExecBridgeConfig
|
||||
from .telegram.bridge import (
|
||||
TelegramBridgeConfig,
|
||||
TelegramPresenter,
|
||||
@@ -24,8 +31,7 @@ from .telegram.bridge import (
|
||||
from .telegram.client import TelegramClient
|
||||
from .telegram.config import load_telegram_config
|
||||
from .telegram.onboarding import SetupResult, check_setup, interactive_setup
|
||||
from .router import AutoRouter, RunnerEntry
|
||||
from .runner_bridge import ExecBridgeConfig
|
||||
from .utils.git import resolve_default_base, resolve_main_worktree_root
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -195,6 +201,12 @@ def _parse_bridge_config(
|
||||
startup_pwd = os.getcwd()
|
||||
|
||||
backends = list_backends()
|
||||
projects = parse_projects_config(
|
||||
config,
|
||||
config_path=config_path,
|
||||
engine_ids=[backend.id for backend in backends],
|
||||
reserved=("cancel",),
|
||||
)
|
||||
default_engine = _resolve_default_engine(
|
||||
override=default_engine_override,
|
||||
config=config,
|
||||
@@ -212,10 +224,16 @@ def _parse_bridge_config(
|
||||
engine_list = ", ".join(available_engines) if available_engines else "none"
|
||||
if missing_engines:
|
||||
engine_list = f"{engine_list} (not installed: {', '.join(missing_engines)})"
|
||||
project_aliases = sorted(
|
||||
{project.alias for project in projects.projects.values()},
|
||||
key=str.lower,
|
||||
)
|
||||
project_list = ", ".join(project_aliases) if project_aliases else "none"
|
||||
startup_msg = (
|
||||
f"\N{OCTOPUS} **takopi is ready**\n\n"
|
||||
f"default: `{router.default_engine}` \n"
|
||||
f"agents: `{engine_list}` \n"
|
||||
f"projects: `{project_list}` \n"
|
||||
f"working in: `{startup_pwd}`"
|
||||
)
|
||||
|
||||
@@ -234,6 +252,7 @@ def _parse_bridge_config(
|
||||
chat_id=chat_id,
|
||||
startup_msg=startup_msg,
|
||||
exec_cfg=exec_cfg,
|
||||
projects=projects,
|
||||
)
|
||||
|
||||
|
||||
@@ -320,6 +339,107 @@ def _run_auto_router(
|
||||
lock_handle.release()
|
||||
|
||||
|
||||
def _prompt_alias(value: str | None, *, default_alias: str | None = None) -> str:
|
||||
if value is not None:
|
||||
alias = value
|
||||
elif default_alias:
|
||||
alias = typer.prompt("project alias", default=default_alias)
|
||||
else:
|
||||
alias = typer.prompt("project alias")
|
||||
alias = alias.strip()
|
||||
if not alias:
|
||||
typer.echo("error: project alias cannot be empty", err=True)
|
||||
raise typer.Exit(code=1)
|
||||
return alias
|
||||
|
||||
|
||||
def _default_alias_from_path(path: Path) -> str | None:
|
||||
name = path.name
|
||||
if not name:
|
||||
return None
|
||||
if name.endswith(".git"):
|
||||
name = name[: -len(".git")]
|
||||
return name or None
|
||||
|
||||
|
||||
def _ensure_projects_table(config: dict, config_path: Path) -> dict:
|
||||
projects = config.get("projects")
|
||||
if projects is None:
|
||||
projects = {}
|
||||
config["projects"] = projects
|
||||
if not isinstance(projects, dict):
|
||||
raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.")
|
||||
return projects
|
||||
|
||||
|
||||
def init(
|
||||
alias: str | None = typer.Argument(
|
||||
None, help="Project alias (used as /alias in messages)."
|
||||
),
|
||||
default: bool = typer.Option(
|
||||
False,
|
||||
"--default",
|
||||
help="Set this project as the default_project.",
|
||||
),
|
||||
) -> None:
|
||||
"""Register the current repo as a Takopi project."""
|
||||
config, config_path = load_or_init_config()
|
||||
|
||||
cwd = Path.cwd()
|
||||
project_path = resolve_main_worktree_root(cwd) or cwd
|
||||
default_alias = _default_alias_from_path(project_path)
|
||||
alias = _prompt_alias(alias, default_alias=default_alias)
|
||||
|
||||
engine_ids = [backend.id for backend in list_backends()]
|
||||
projects_cfg = parse_projects_config(
|
||||
config,
|
||||
config_path=config_path,
|
||||
engine_ids=engine_ids,
|
||||
reserved=("cancel",),
|
||||
)
|
||||
|
||||
alias_key = alias.lower()
|
||||
if alias_key in {engine.lower() for engine in engine_ids}:
|
||||
raise ConfigError(
|
||||
f"Invalid project alias {alias!r}; aliases must not match engine ids."
|
||||
)
|
||||
if alias_key == "cancel":
|
||||
raise ConfigError(
|
||||
f"Invalid project alias {alias!r}; aliases must not match reserved commands."
|
||||
)
|
||||
|
||||
existing = projects_cfg.projects.get(alias_key)
|
||||
if existing is not None:
|
||||
overwrite = typer.confirm(
|
||||
f"project {existing.alias!r} already exists, overwrite?",
|
||||
default=False,
|
||||
)
|
||||
if not overwrite:
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
projects = _ensure_projects_table(config, config_path)
|
||||
if existing is not None and existing.alias in projects:
|
||||
projects.pop(existing.alias, None)
|
||||
|
||||
default_engine = _default_engine_for_setup(None)
|
||||
worktree_base = resolve_default_base(project_path)
|
||||
|
||||
entry: dict[str, object] = {
|
||||
"path": str(project_path),
|
||||
"worktrees_dir": ".worktrees",
|
||||
"default_engine": default_engine,
|
||||
}
|
||||
if worktree_base:
|
||||
entry["worktree_base"] = worktree_base
|
||||
|
||||
projects[alias] = entry
|
||||
if default:
|
||||
config["default_project"] = alias
|
||||
|
||||
write_config(config, config_path)
|
||||
typer.echo(f"saved project {alias!r} to {_config_path_display(config_path)}")
|
||||
|
||||
|
||||
app = typer.Typer(
|
||||
add_completion=False,
|
||||
invoke_without_command=True,
|
||||
@@ -327,6 +447,9 @@ app = typer.Typer(
|
||||
)
|
||||
|
||||
|
||||
app.command(name="init")(init)
|
||||
|
||||
|
||||
@app.callback()
|
||||
def app_main(
|
||||
ctx: typer.Context,
|
||||
|
||||
@@ -1,5 +1,260 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import tomllib
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
HOME_CONFIG_PATH = Path.home() / ".takopi" / "takopi.toml"
|
||||
|
||||
|
||||
class ConfigError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _read_config(cfg_path: Path) -> dict:
|
||||
try:
|
||||
raw = cfg_path.read_text(encoding="utf-8")
|
||||
except FileNotFoundError:
|
||||
raise ConfigError(f"Missing config file {cfg_path}.") from None
|
||||
except OSError as e:
|
||||
raise ConfigError(f"Failed to read config file {cfg_path}: {e}") from e
|
||||
try:
|
||||
return tomllib.loads(raw)
|
||||
except tomllib.TOMLDecodeError as e:
|
||||
raise ConfigError(f"Malformed TOML in {cfg_path}: {e}") from None
|
||||
|
||||
|
||||
def load_or_init_config(path: str | Path | None = None) -> tuple[dict, Path]:
|
||||
cfg_path = Path(path).expanduser() if path else HOME_CONFIG_PATH
|
||||
if cfg_path.exists() and not cfg_path.is_file():
|
||||
raise ConfigError(f"Config path {cfg_path} exists but is not a file.") from None
|
||||
if not cfg_path.exists():
|
||||
return {}, cfg_path
|
||||
return _read_config(cfg_path), cfg_path
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ProjectConfig:
|
||||
alias: str
|
||||
path: Path
|
||||
worktrees_dir: Path
|
||||
default_engine: str | None = None
|
||||
worktree_base: str | None = None
|
||||
|
||||
@property
|
||||
def worktrees_root(self) -> Path:
|
||||
if self.worktrees_dir.is_absolute():
|
||||
return self.worktrees_dir
|
||||
return self.path / self.worktrees_dir
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ProjectsConfig:
|
||||
projects: dict[str, ProjectConfig]
|
||||
default_project: str | None = None
|
||||
|
||||
def resolve(self, alias: str | None) -> ProjectConfig | None:
|
||||
if alias is None:
|
||||
if self.default_project is None:
|
||||
return None
|
||||
return self.projects.get(self.default_project)
|
||||
return self.projects.get(alias.lower())
|
||||
|
||||
|
||||
def empty_projects_config() -> ProjectsConfig:
|
||||
return ProjectsConfig(projects={}, default_project=None)
|
||||
|
||||
|
||||
def _normalize_engine_id(
|
||||
value: str,
|
||||
*,
|
||||
engine_ids: Iterable[str],
|
||||
config_path: Path,
|
||||
label: str,
|
||||
) -> str:
|
||||
engine_map = {engine.lower(): engine for engine in engine_ids}
|
||||
cleaned = value.strip()
|
||||
if not cleaned:
|
||||
raise ConfigError(f"Invalid `{label}` in {config_path}; expected a string.")
|
||||
engine = engine_map.get(cleaned.lower())
|
||||
if engine is None:
|
||||
available = ", ".join(sorted(engine_map.values()))
|
||||
raise ConfigError(
|
||||
f"Unknown `{label}` {cleaned!r} in {config_path}. Available: {available}."
|
||||
)
|
||||
return engine
|
||||
|
||||
|
||||
def _normalize_project_path(value: str, *, config_path: Path) -> Path:
|
||||
path = Path(value).expanduser()
|
||||
if not path.is_absolute():
|
||||
path = config_path.parent / path
|
||||
return path
|
||||
|
||||
|
||||
def parse_projects_config(
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
config_path: Path,
|
||||
engine_ids: Iterable[str],
|
||||
reserved: Iterable[str] = ("cancel",),
|
||||
) -> ProjectsConfig:
|
||||
default_project_raw = config.get("default_project")
|
||||
default_project = None
|
||||
if default_project_raw is not None:
|
||||
if not isinstance(default_project_raw, str) or not default_project_raw.strip():
|
||||
raise ConfigError(
|
||||
f"Invalid `default_project` in {config_path}; expected a non-empty string."
|
||||
)
|
||||
default_project = default_project_raw.strip()
|
||||
|
||||
projects_raw = config.get("projects") or {}
|
||||
if not isinstance(projects_raw, dict):
|
||||
raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.")
|
||||
|
||||
reserved_lower = {value.lower() for value in reserved}
|
||||
engine_lower = {value.lower() for value in engine_ids}
|
||||
projects: dict[str, ProjectConfig] = {}
|
||||
|
||||
for raw_alias, raw_entry in projects_raw.items():
|
||||
if not isinstance(raw_alias, str) or not raw_alias.strip():
|
||||
raise ConfigError(
|
||||
f"Invalid project alias in {config_path}; expected a non-empty string."
|
||||
)
|
||||
alias = raw_alias.strip()
|
||||
alias_key = alias.lower()
|
||||
if alias_key in engine_lower or alias_key in reserved_lower:
|
||||
raise ConfigError(
|
||||
f"Invalid project alias {alias!r} in {config_path}; "
|
||||
"aliases must not match engine ids or reserved commands."
|
||||
)
|
||||
if alias_key in projects:
|
||||
raise ConfigError(f"Duplicate project alias {alias!r} in {config_path}.")
|
||||
if not isinstance(raw_entry, dict):
|
||||
raise ConfigError(
|
||||
f"Invalid project entry for {alias!r} in {config_path}; expected a table."
|
||||
)
|
||||
|
||||
path_value = raw_entry.get("path")
|
||||
if not isinstance(path_value, str) or not path_value.strip():
|
||||
raise ConfigError(f"Missing `path` for project {alias!r} in {config_path}.")
|
||||
path = _normalize_project_path(path_value.strip(), config_path=config_path)
|
||||
|
||||
worktrees_dir_raw = raw_entry.get("worktrees_dir", ".worktrees")
|
||||
if not isinstance(worktrees_dir_raw, str) or not worktrees_dir_raw.strip():
|
||||
raise ConfigError(
|
||||
f"Invalid `worktrees_dir` for project {alias!r} in {config_path}."
|
||||
)
|
||||
worktrees_dir = Path(worktrees_dir_raw.strip())
|
||||
|
||||
default_engine_raw = raw_entry.get("default_engine")
|
||||
default_engine = None
|
||||
if default_engine_raw is not None:
|
||||
if not isinstance(default_engine_raw, str):
|
||||
raise ConfigError(
|
||||
f"Invalid `projects.{alias}.default_engine` in {config_path}; "
|
||||
"expected a string."
|
||||
)
|
||||
default_engine = _normalize_engine_id(
|
||||
default_engine_raw,
|
||||
engine_ids=engine_ids,
|
||||
config_path=config_path,
|
||||
label=f"projects.{alias}.default_engine",
|
||||
)
|
||||
|
||||
worktree_base_raw = raw_entry.get("worktree_base")
|
||||
worktree_base = None
|
||||
if worktree_base_raw is not None:
|
||||
if not isinstance(worktree_base_raw, str) or not worktree_base_raw.strip():
|
||||
raise ConfigError(
|
||||
f"Invalid `projects.{alias}.worktree_base` in {config_path}; "
|
||||
"expected a string."
|
||||
)
|
||||
worktree_base = worktree_base_raw.strip()
|
||||
|
||||
projects[alias_key] = ProjectConfig(
|
||||
alias=alias,
|
||||
path=path,
|
||||
worktrees_dir=worktrees_dir,
|
||||
default_engine=default_engine,
|
||||
worktree_base=worktree_base,
|
||||
)
|
||||
|
||||
if default_project is not None:
|
||||
default_key = default_project.lower()
|
||||
if default_key not in projects:
|
||||
raise ConfigError(
|
||||
f"Invalid `default_project` {default_project!r} in {config_path}; "
|
||||
"no matching project alias found."
|
||||
)
|
||||
default_project = default_key
|
||||
|
||||
return ProjectsConfig(projects=projects, default_project=default_project)
|
||||
|
||||
|
||||
def _toml_escape(value: str) -> str:
|
||||
return value.replace("\\", "\\\\").replace('"', '\\"')
|
||||
|
||||
|
||||
def _format_toml_value(value: Any) -> str:
|
||||
if isinstance(value, bool):
|
||||
return "true" if value else "false"
|
||||
if isinstance(value, int):
|
||||
return str(value)
|
||||
if isinstance(value, float):
|
||||
return repr(value)
|
||||
if isinstance(value, str):
|
||||
return f'"{_toml_escape(value)}"'
|
||||
if isinstance(value, (list, tuple)):
|
||||
inner = ", ".join(_format_toml_value(item) for item in value)
|
||||
return f"[{inner}]"
|
||||
raise ConfigError(f"Unsupported config value {value!r}")
|
||||
|
||||
|
||||
def _table_has_scalars(table: dict[str, Any]) -> bool:
|
||||
return any(not isinstance(value, dict) for value in table.values())
|
||||
|
||||
|
||||
def dump_toml(config: dict[str, Any]) -> str:
|
||||
lines: list[str] = []
|
||||
|
||||
def write_kv(key: str, value: Any) -> None:
|
||||
lines.append(f"{key} = {_format_toml_value(value)}")
|
||||
|
||||
def write_table(name: str, table: dict[str, Any]) -> None:
|
||||
if lines and lines[-1] != "":
|
||||
lines.append("")
|
||||
lines.append(f"[{name}]")
|
||||
for key, value in table.items():
|
||||
if isinstance(value, dict):
|
||||
continue
|
||||
write_kv(key, value)
|
||||
for key, value in table.items():
|
||||
if isinstance(value, dict):
|
||||
write_table(f"{name}.{key}", value)
|
||||
|
||||
for key, value in config.items():
|
||||
if isinstance(value, dict):
|
||||
continue
|
||||
write_kv(key, value)
|
||||
|
||||
for key, value in config.items():
|
||||
if not isinstance(value, dict):
|
||||
continue
|
||||
if _table_has_scalars(value):
|
||||
write_table(key, value)
|
||||
continue
|
||||
for subkey, subvalue in value.items():
|
||||
if isinstance(subvalue, dict):
|
||||
write_table(f"{key}.{subkey}", subvalue)
|
||||
else:
|
||||
write_table(key, value)
|
||||
break
|
||||
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def write_config(config: dict[str, Any], path: Path) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(dump_toml(config), encoding="utf-8")
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RunContext:
|
||||
project: str | None = None
|
||||
branch: str | None = None
|
||||
+34
-7
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -94,12 +95,16 @@ def format_file_change_title(action: Action, *, command_width: int | None) -> st
|
||||
if isinstance(changes, list) and changes:
|
||||
rendered: list[str] = []
|
||||
for raw in changes:
|
||||
if not isinstance(raw, dict):
|
||||
continue
|
||||
path = raw.get("path")
|
||||
path: str | None
|
||||
kind: str | None
|
||||
if isinstance(raw, dict):
|
||||
path = raw.get("path")
|
||||
kind = raw.get("kind")
|
||||
else:
|
||||
path = getattr(raw, "path", None)
|
||||
kind = getattr(raw, "kind", None)
|
||||
if not isinstance(path, str) or not path:
|
||||
continue
|
||||
kind = raw.get("kind")
|
||||
verb = kind if isinstance(kind, str) and kind else "update"
|
||||
rendered.append(f"{verb} {format_changed_file_path(path)}")
|
||||
|
||||
@@ -110,7 +115,15 @@ def format_file_change_title(action: Action, *, command_width: int | None) -> st
|
||||
inline = shorten(", ".join(rendered), command_width)
|
||||
return f"files: {inline}"
|
||||
|
||||
return f"files: {shorten(title, command_width)}"
|
||||
fallback = title
|
||||
relativized = relativize_path(fallback)
|
||||
was_relativized = relativized != fallback
|
||||
if was_relativized:
|
||||
fallback = relativized
|
||||
if fallback and not (fallback.startswith("`") and fallback.endswith("`")):
|
||||
if was_relativized or os.sep in fallback or "/" in fallback:
|
||||
fallback = f"`{fallback}`"
|
||||
return f"files: {shorten(fallback, command_width)}"
|
||||
|
||||
|
||||
def format_action_title(action: Action, *, command_width: int | None) -> str:
|
||||
@@ -197,7 +210,9 @@ class MarkdownFormatter:
|
||||
engine=state.engine,
|
||||
)
|
||||
body = self._assemble_body(self._format_actions(state))
|
||||
return MarkdownParts(header=header, body=body, footer=state.resume_line)
|
||||
return MarkdownParts(
|
||||
header=header, body=body, footer=self._format_footer(state)
|
||||
)
|
||||
|
||||
def render_final_parts(
|
||||
self,
|
||||
@@ -216,7 +231,19 @@ class MarkdownFormatter:
|
||||
)
|
||||
answer = (answer or "").strip()
|
||||
body = answer if answer else None
|
||||
return MarkdownParts(header=header, body=body, footer=state.resume_line)
|
||||
return MarkdownParts(
|
||||
header=header, body=body, footer=self._format_footer(state)
|
||||
)
|
||||
|
||||
def _format_footer(self, state: ProgressState) -> str | None:
|
||||
lines: list[str] = []
|
||||
if state.context_line:
|
||||
lines.append(state.context_line)
|
||||
if state.resume_line:
|
||||
lines.append(state.resume_line)
|
||||
if not lines:
|
||||
return None
|
||||
return HARD_BREAK.join(lines)
|
||||
|
||||
def _format_actions(self, state: ProgressState) -> list[str]:
|
||||
actions = list(state.actions)
|
||||
|
||||
@@ -24,6 +24,7 @@ class ProgressState:
|
||||
actions: tuple[ActionState, ...]
|
||||
resume: ResumeToken | None
|
||||
resume_line: str | None
|
||||
context_line: str | None
|
||||
|
||||
|
||||
class ProgressTracker:
|
||||
@@ -80,6 +81,7 @@ class ProgressTracker:
|
||||
self,
|
||||
*,
|
||||
resume_formatter: Callable[[ResumeToken], str] | None = None,
|
||||
context_line: str | None = None,
|
||||
) -> ProgressState:
|
||||
resume_line: str | None = None
|
||||
if self.resume is not None and resume_formatter is not None:
|
||||
@@ -93,4 +95,5 @@ class ProgressTracker:
|
||||
actions=actions,
|
||||
resume=self.resume,
|
||||
resume_line=resume_line,
|
||||
context_line=context_line,
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ from .model import (
|
||||
StartedEvent,
|
||||
TakopiEvent,
|
||||
)
|
||||
from .utils.paths import get_run_base_dir
|
||||
from .utils.streams import drain_stderr, iter_bytes_lines
|
||||
from .utils.subprocess import manage_subprocess
|
||||
|
||||
@@ -358,12 +359,15 @@ class JsonlSubprocessRunner(BaseRunner):
|
||||
prompt_len=len(prompt),
|
||||
)
|
||||
|
||||
cwd = get_run_base_dir()
|
||||
|
||||
async with manage_subprocess(
|
||||
cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
env=env,
|
||||
cwd=cwd,
|
||||
) as proc:
|
||||
if proc.stdout is None or proc.stderr is None:
|
||||
raise RuntimeError(self.pipes_error_message())
|
||||
|
||||
@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
import anyio
|
||||
|
||||
from .context import RunContext
|
||||
from .logging import bind_run_context, get_logger
|
||||
from .model import CompletedEvent, ResumeToken, StartedEvent, TakopiEvent
|
||||
from .presenter import Presenter
|
||||
@@ -93,6 +94,7 @@ class RunningTask:
|
||||
resume_ready: anyio.Event = field(default_factory=anyio.Event)
|
||||
cancel_requested: anyio.Event = field(default_factory=anyio.Event)
|
||||
done: anyio.Event = field(default_factory=anyio.Event)
|
||||
context: RunContext | None = None
|
||||
|
||||
|
||||
RunningTasks = dict[MessageRef, RunningTask]
|
||||
@@ -152,6 +154,7 @@ class ProgressEdits:
|
||||
last_rendered: RenderedMessage | None,
|
||||
resume_formatter: Callable[[ResumeToken], str] | None = None,
|
||||
label: str = "working",
|
||||
context_line: str | None = None,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.presenter = presenter
|
||||
@@ -163,6 +166,7 @@ class ProgressEdits:
|
||||
self.last_rendered = last_rendered
|
||||
self.resume_formatter = resume_formatter
|
||||
self.label = label
|
||||
self.context_line = context_line
|
||||
self.event_seq = 0
|
||||
self.rendered_seq = 0
|
||||
self.signal_send, self.signal_recv = anyio.create_memory_object_stream(1)
|
||||
@@ -179,7 +183,10 @@ class ProgressEdits:
|
||||
|
||||
seq_at_render = self.event_seq
|
||||
now = self.clock()
|
||||
state = self.tracker.snapshot(resume_formatter=self.resume_formatter)
|
||||
state = self.tracker.snapshot(
|
||||
resume_formatter=self.resume_formatter,
|
||||
context_line=self.context_line,
|
||||
)
|
||||
rendered = self.presenter.render_progress(
|
||||
state, elapsed_s=now - self.started_at, label=self.label
|
||||
)
|
||||
@@ -228,11 +235,15 @@ async def send_initial_progress(
|
||||
label: str,
|
||||
tracker: ProgressTracker,
|
||||
resume_formatter: Callable[[ResumeToken], str] | None = None,
|
||||
context_line: str | None = None,
|
||||
) -> ProgressMessageState:
|
||||
progress_ref: MessageRef | None = None
|
||||
last_rendered: RenderedMessage | None = None
|
||||
|
||||
state = tracker.snapshot(resume_formatter=resume_formatter)
|
||||
state = tracker.snapshot(
|
||||
resume_formatter=resume_formatter,
|
||||
context_line=context_line,
|
||||
)
|
||||
initial_rendered = cfg.presenter.render_progress(
|
||||
state,
|
||||
elapsed_s=0.0,
|
||||
@@ -364,6 +375,8 @@ async def handle_message(
|
||||
runner: Runner,
|
||||
incoming: IncomingMessage,
|
||||
resume_token: ResumeToken | None,
|
||||
context: RunContext | None = None,
|
||||
context_line: str | None = None,
|
||||
strip_resume_line: Callable[[str], bool] | None = None,
|
||||
running_tasks: RunningTasks | None = None,
|
||||
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
|
||||
@@ -395,6 +408,7 @@ async def handle_message(
|
||||
label="starting",
|
||||
tracker=progress_tracker,
|
||||
resume_formatter=runner.format_resume,
|
||||
context_line=context_line,
|
||||
)
|
||||
progress_ref = progress_state.ref
|
||||
|
||||
@@ -408,11 +422,12 @@ async def handle_message(
|
||||
clock=clock,
|
||||
last_rendered=progress_state.last_rendered,
|
||||
resume_formatter=runner.format_resume,
|
||||
context_line=context_line,
|
||||
)
|
||||
|
||||
running_task: RunningTask | None = None
|
||||
if running_tasks is not None and progress_ref is not None:
|
||||
running_task = RunningTask()
|
||||
running_task = RunningTask(context=context)
|
||||
running_tasks[progress_ref] = running_task
|
||||
|
||||
cancel_exc_type = anyio.get_cancelled_exc_class()
|
||||
@@ -464,7 +479,10 @@ async def handle_message(
|
||||
if error is not None:
|
||||
sync_resume_token(progress_tracker, outcome.resume)
|
||||
err_body = _format_error(error)
|
||||
state = progress_tracker.snapshot(resume_formatter=runner.format_resume)
|
||||
state = progress_tracker.snapshot(
|
||||
resume_formatter=runner.format_resume,
|
||||
context_line=context_line,
|
||||
)
|
||||
final_rendered = cfg.presenter.render_final(
|
||||
state,
|
||||
elapsed_s=elapsed,
|
||||
@@ -496,7 +514,10 @@ async def handle_message(
|
||||
resume=resume.value if resume else None,
|
||||
elapsed_s=elapsed,
|
||||
)
|
||||
state = progress_tracker.snapshot(resume_formatter=runner.format_resume)
|
||||
state = progress_tracker.snapshot(
|
||||
resume_formatter=runner.format_resume,
|
||||
context_line=context_line,
|
||||
)
|
||||
final_rendered = cfg.presenter.render_progress(
|
||||
state,
|
||||
elapsed_s=elapsed,
|
||||
@@ -546,7 +567,10 @@ async def handle_message(
|
||||
resume=resume_value,
|
||||
)
|
||||
sync_resume_token(progress_tracker, completed.resume or outcome.resume)
|
||||
state = progress_tracker.snapshot(resume_formatter=runner.format_resume)
|
||||
state = progress_tracker.snapshot(
|
||||
resume_formatter=runner.format_resume,
|
||||
context_line=context_line,
|
||||
)
|
||||
final_rendered = cfg.presenter.render_final(
|
||||
state,
|
||||
elapsed_s=elapsed,
|
||||
|
||||
@@ -76,6 +76,26 @@ def _summarize_tool_result(result: Any) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_change_list(changes: list[Any]) -> list[dict[str, str]]:
|
||||
normalized: list[dict[str, str]] = []
|
||||
for change in changes:
|
||||
path: str | None = None
|
||||
kind: str | None = None
|
||||
if isinstance(change, codex_schema.FileUpdateChange):
|
||||
path = change.path
|
||||
kind = change.kind
|
||||
elif isinstance(change, dict):
|
||||
path = change.get("path")
|
||||
kind = change.get("kind")
|
||||
if not isinstance(path, str) or not path:
|
||||
continue
|
||||
entry = {"path": path}
|
||||
if isinstance(kind, str) and kind:
|
||||
entry["kind"] = kind
|
||||
normalized.append(entry)
|
||||
return normalized
|
||||
|
||||
|
||||
def _format_change_summary(changes: list[Any]) -> str:
|
||||
paths: list[str] = []
|
||||
for change in changes:
|
||||
@@ -260,8 +280,9 @@ def _translate_item_event(
|
||||
if phase != "completed":
|
||||
return []
|
||||
title = _format_change_summary(changes)
|
||||
normalized_changes = _normalize_change_list(changes)
|
||||
detail = {
|
||||
"changes": changes,
|
||||
"changes": normalized_changes,
|
||||
"status": status,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Awaitable, Callable, Protocol
|
||||
|
||||
import anyio
|
||||
|
||||
from .context import RunContext
|
||||
from .model import ResumeToken
|
||||
|
||||
|
||||
@@ -15,6 +16,7 @@ class ThreadJob:
|
||||
user_msg_id: int
|
||||
text: str
|
||||
resume_token: ResumeToken
|
||||
context: RunContext | None = None
|
||||
|
||||
|
||||
RunJob = Callable[[ThreadJob], Awaitable[None]]
|
||||
@@ -66,6 +68,7 @@ class ThreadScheduler:
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
resume_token: ResumeToken,
|
||||
context: RunContext | None = None,
|
||||
) -> None:
|
||||
await self.enqueue(
|
||||
ThreadJob(
|
||||
@@ -73,6 +76,7 @@ class ThreadScheduler:
|
||||
user_msg_id=user_msg_id,
|
||||
text=text,
|
||||
resume_token=resume_token,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
"""Telegram-specific clients and adapters."""
|
||||
|
||||
from .client import parse_incoming_update, poll_incoming
|
||||
|
||||
__all__ = ["parse_incoming_update", "poll_incoming"]
|
||||
|
||||
+322
-81
@@ -1,14 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
import anyio
|
||||
|
||||
from ..config import ProjectsConfig, empty_projects_config
|
||||
from ..context import RunContext
|
||||
from ..runner_bridge import (
|
||||
ExecBridgeConfig,
|
||||
IncomingMessage,
|
||||
IncomingMessage as RunnerIncomingMessage,
|
||||
RunningTask,
|
||||
RunningTasks,
|
||||
handle_message,
|
||||
@@ -20,8 +22,16 @@ from ..progress import ProgressState, ProgressTracker
|
||||
from ..router import AutoRouter, RunnerUnavailableError
|
||||
from ..runner import Runner
|
||||
from ..scheduler import ThreadJob, ThreadScheduler
|
||||
from ..transport import MessageRef, RenderedMessage, SendOptions, Transport
|
||||
from .client import BotClient
|
||||
from ..transport import (
|
||||
IncomingMessage as TransportIncomingMessage,
|
||||
MessageRef,
|
||||
RenderedMessage,
|
||||
SendOptions,
|
||||
Transport,
|
||||
)
|
||||
from ..utils.paths import reset_run_base_dir, set_run_base_dir
|
||||
from ..worktrees import WorktreeError, resolve_run_cwd
|
||||
from .client import BotClient, poll_incoming
|
||||
from .render import prepare_telegram
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -70,6 +80,206 @@ def _strip_engine_command(
|
||||
return "\n".join(lines).strip(), engine
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ParsedDirectives:
|
||||
prompt: str
|
||||
engine: EngineId | None
|
||||
project: str | None
|
||||
branch: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ResolvedMessage:
|
||||
prompt: str
|
||||
resume_token: ResumeToken | None
|
||||
engine_override: EngineId | None
|
||||
context: RunContext | None
|
||||
|
||||
|
||||
class DirectiveError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _parse_directives(
|
||||
text: str,
|
||||
*,
|
||||
engine_ids: tuple[EngineId, ...],
|
||||
projects: ProjectsConfig,
|
||||
) -> ParsedDirectives:
|
||||
if not text:
|
||||
return ParsedDirectives(prompt="", engine=None, project=None, branch=None)
|
||||
|
||||
lines = text.splitlines()
|
||||
idx = next((i for i, line in enumerate(lines) if line.strip()), None)
|
||||
if idx is None:
|
||||
return ParsedDirectives(prompt=text, engine=None, project=None, branch=None)
|
||||
|
||||
line = lines[idx].lstrip()
|
||||
tokens = line.split()
|
||||
if not tokens:
|
||||
return ParsedDirectives(prompt=text, engine=None, project=None, branch=None)
|
||||
|
||||
engine_map = {engine.lower(): engine for engine in engine_ids}
|
||||
project_map = {alias.lower(): alias for alias in projects.projects}
|
||||
|
||||
engine: EngineId | None = None
|
||||
project: str | None = None
|
||||
branch: str | None = None
|
||||
consumed = 0
|
||||
|
||||
for token in tokens:
|
||||
if token.startswith("/"):
|
||||
name = token[1:]
|
||||
if "@" in name:
|
||||
name = name.split("@", 1)[0]
|
||||
if not name:
|
||||
break
|
||||
key = name.lower()
|
||||
engine_candidate = engine_map.get(key)
|
||||
project_candidate = project_map.get(key)
|
||||
if engine_candidate is not None:
|
||||
if engine is not None:
|
||||
raise DirectiveError("multiple engine directives")
|
||||
engine = engine_candidate
|
||||
consumed += 1
|
||||
continue
|
||||
if project_candidate is not None:
|
||||
if project is not None:
|
||||
raise DirectiveError("multiple project directives")
|
||||
project = project_candidate
|
||||
consumed += 1
|
||||
continue
|
||||
break
|
||||
if token.startswith("@"):
|
||||
value = token[1:]
|
||||
if not value:
|
||||
break
|
||||
if branch is not None:
|
||||
raise DirectiveError("multiple @branch directives")
|
||||
branch = value
|
||||
consumed += 1
|
||||
continue
|
||||
break
|
||||
|
||||
if consumed == 0:
|
||||
return ParsedDirectives(prompt=text, engine=None, project=None, branch=None)
|
||||
|
||||
if consumed < len(tokens):
|
||||
remainder = " ".join(tokens[consumed:])
|
||||
lines[idx] = remainder
|
||||
else:
|
||||
lines.pop(idx)
|
||||
|
||||
prompt = "\n".join(lines).strip()
|
||||
return ParsedDirectives(
|
||||
prompt=prompt, engine=engine, project=project, branch=branch
|
||||
)
|
||||
|
||||
|
||||
def _parse_ctx_line(text: str | None, *, projects: ProjectsConfig) -> RunContext | None:
|
||||
if not text:
|
||||
return None
|
||||
ctx: RunContext | None = None
|
||||
for line in text.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("`") and stripped.endswith("`") and len(stripped) > 1:
|
||||
stripped = stripped[1:-1].strip()
|
||||
elif stripped.startswith("`"):
|
||||
stripped = stripped[1:].strip()
|
||||
elif stripped.endswith("`"):
|
||||
stripped = stripped[:-1].strip()
|
||||
if not stripped.lower().startswith("ctx:"):
|
||||
continue
|
||||
content = stripped.split(":", 1)[1].strip()
|
||||
if not content:
|
||||
continue
|
||||
tokens = content.split()
|
||||
if not tokens:
|
||||
continue
|
||||
project = tokens[0]
|
||||
branch = None
|
||||
if len(tokens) >= 2:
|
||||
if tokens[1] == "@" and len(tokens) >= 3:
|
||||
branch = tokens[2]
|
||||
elif tokens[1].startswith("@"):
|
||||
branch = tokens[1][1:]
|
||||
project_key = project.lower()
|
||||
if project_key not in projects.projects:
|
||||
raise DirectiveError(f"unknown project {project!r} in ctx line")
|
||||
ctx = RunContext(project=project_key, branch=branch)
|
||||
return ctx
|
||||
|
||||
|
||||
def _format_context_line(
|
||||
context: RunContext | None, *, projects: ProjectsConfig
|
||||
) -> str | None:
|
||||
if context is None or context.project is None:
|
||||
return None
|
||||
project_cfg = projects.projects.get(context.project)
|
||||
alias = project_cfg.alias if project_cfg is not None else context.project
|
||||
if context.branch:
|
||||
return f"`ctx: {alias} @ {context.branch}`"
|
||||
return f"`ctx: {alias}`"
|
||||
|
||||
|
||||
def _resolve_message(
|
||||
*,
|
||||
text: str,
|
||||
reply_text: str | None,
|
||||
router: AutoRouter,
|
||||
projects: ProjectsConfig,
|
||||
) -> ResolvedMessage:
|
||||
directives = _parse_directives(
|
||||
text,
|
||||
engine_ids=router.engine_ids,
|
||||
projects=projects,
|
||||
)
|
||||
reply_ctx = _parse_ctx_line(reply_text, projects=projects)
|
||||
resume_token = router.resolve_resume(directives.prompt, reply_text)
|
||||
|
||||
if resume_token is not None:
|
||||
return ResolvedMessage(
|
||||
prompt=directives.prompt,
|
||||
resume_token=resume_token,
|
||||
engine_override=None,
|
||||
context=reply_ctx,
|
||||
)
|
||||
|
||||
if reply_ctx is not None:
|
||||
engine_override = None
|
||||
if reply_ctx.project is not None:
|
||||
project = projects.projects.get(reply_ctx.project)
|
||||
if project is not None and project.default_engine is not None:
|
||||
engine_override = project.default_engine
|
||||
return ResolvedMessage(
|
||||
prompt=directives.prompt,
|
||||
resume_token=None,
|
||||
engine_override=engine_override,
|
||||
context=reply_ctx,
|
||||
)
|
||||
|
||||
project_key = directives.project
|
||||
if project_key is None and projects.default_project is not None:
|
||||
project_key = projects.default_project
|
||||
|
||||
context = None
|
||||
if project_key is not None or directives.branch is not None:
|
||||
context = RunContext(project=project_key, branch=directives.branch)
|
||||
|
||||
engine_override = directives.engine
|
||||
if engine_override is None and project_key is not None:
|
||||
project = projects.projects.get(project_key)
|
||||
if project is not None and project.default_engine is not None:
|
||||
engine_override = project.default_engine
|
||||
|
||||
return ResolvedMessage(
|
||||
prompt=directives.prompt,
|
||||
resume_token=None,
|
||||
engine_override=engine_override,
|
||||
context=context,
|
||||
)
|
||||
|
||||
|
||||
def _build_bot_commands(router: AutoRouter) -> list[dict[str, str]]:
|
||||
commands: list[dict[str, str]] = []
|
||||
seen: set[str] = set()
|
||||
@@ -232,6 +442,7 @@ class TelegramBridgeConfig:
|
||||
chat_id: int
|
||||
startup_msg: str
|
||||
exec_cfg: ExecBridgeConfig
|
||||
projects: ProjectsConfig = field(default_factory=empty_projects_config)
|
||||
|
||||
|
||||
async def _send_plain(
|
||||
@@ -281,41 +492,35 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int |
|
||||
drained += len(updates)
|
||||
|
||||
|
||||
async def poll_updates(cfg: TelegramBridgeConfig) -> AsyncIterator[dict[str, Any]]:
|
||||
async def poll_updates(
|
||||
cfg: TelegramBridgeConfig,
|
||||
) -> AsyncIterator[TransportIncomingMessage]:
|
||||
offset: int | None = None
|
||||
offset = await _drain_backlog(cfg, offset)
|
||||
await _send_startup(cfg)
|
||||
|
||||
while True:
|
||||
updates = await cfg.bot.get_updates(
|
||||
offset=offset, timeout_s=50, allowed_updates=["message"]
|
||||
)
|
||||
if updates is None:
|
||||
logger.info("loop.get_updates.failed")
|
||||
await anyio.sleep(2)
|
||||
continue
|
||||
logger.debug("loop.updates", updates=updates)
|
||||
|
||||
for upd in updates:
|
||||
offset = upd["update_id"] + 1
|
||||
msg = upd["message"]
|
||||
if "text" not in msg:
|
||||
continue
|
||||
if msg["chat"]["id"] != cfg.chat_id:
|
||||
continue
|
||||
yield msg
|
||||
async for msg in poll_incoming(cfg.bot, chat_id=cfg.chat_id, offset=offset):
|
||||
yield msg
|
||||
|
||||
|
||||
async def _handle_cancel(
|
||||
cfg: TelegramBridgeConfig,
|
||||
msg: dict[str, Any],
|
||||
msg: TransportIncomingMessage,
|
||||
running_tasks: RunningTasks,
|
||||
) -> None:
|
||||
chat_id = msg["chat"]["id"]
|
||||
user_msg_id = msg["message_id"]
|
||||
reply = msg.get("reply_to_message")
|
||||
chat_id = msg.chat_id
|
||||
user_msg_id = msg.message_id
|
||||
reply_id = msg.reply_to_message_id
|
||||
|
||||
if not reply:
|
||||
if reply_id is None:
|
||||
if msg.reply_to_text:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
text="nothing is currently running for that message.",
|
||||
)
|
||||
return
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=chat_id,
|
||||
@@ -324,17 +529,7 @@ async def _handle_cancel(
|
||||
)
|
||||
return
|
||||
|
||||
progress_id = reply.get("message_id")
|
||||
if progress_id is None:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
text="nothing is currently running for that message.",
|
||||
)
|
||||
return
|
||||
|
||||
progress_ref = MessageRef(channel_id=chat_id, message_id=progress_id)
|
||||
progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id)
|
||||
running_task = running_tasks.get(progress_ref)
|
||||
if running_task is None:
|
||||
await _send_plain(
|
||||
@@ -348,7 +543,7 @@ async def _handle_cancel(
|
||||
logger.info(
|
||||
"cancel.requested",
|
||||
chat_id=chat_id,
|
||||
progress_message_id=progress_id,
|
||||
progress_message_id=reply_id,
|
||||
)
|
||||
running_task.cancel_requested.set()
|
||||
|
||||
@@ -378,7 +573,7 @@ async def _wait_for_resume(running_task: RunningTask) -> ResumeToken | None:
|
||||
|
||||
async def _send_with_resume(
|
||||
cfg: TelegramBridgeConfig,
|
||||
enqueue: Callable[[int, int, str, ResumeToken], Awaitable[None]],
|
||||
enqueue: Callable[[int, int, str, ResumeToken, RunContext | None], Awaitable[None]],
|
||||
running_task: RunningTask,
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
@@ -394,7 +589,7 @@ async def _send_with_resume(
|
||||
notify=False,
|
||||
)
|
||||
return
|
||||
await enqueue(chat_id, user_msg_id, text, resume)
|
||||
await enqueue(chat_id, user_msg_id, text, resume, running_task.context)
|
||||
|
||||
|
||||
async def _send_runner_unavailable(
|
||||
@@ -426,7 +621,7 @@ async def _send_runner_unavailable(
|
||||
async def run_main_loop(
|
||||
cfg: TelegramBridgeConfig,
|
||||
poller: Callable[
|
||||
[TelegramBridgeConfig], AsyncIterator[dict[str, Any]]
|
||||
[TelegramBridgeConfig], AsyncIterator[TransportIncomingMessage]
|
||||
] = poll_updates,
|
||||
) -> None:
|
||||
running_tasks: RunningTasks = {}
|
||||
@@ -440,6 +635,7 @@ async def run_main_loop(
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
resume_token: ResumeToken | None,
|
||||
context: RunContext | None,
|
||||
reply_ref: MessageRef | None = None,
|
||||
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
|
||||
| None = None,
|
||||
@@ -471,27 +667,52 @@ async def run_main_loop(
|
||||
reason=reason,
|
||||
)
|
||||
return
|
||||
bind_run_context(
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
engine=entry.runner.engine,
|
||||
resume=resume_token.value if resume_token else None,
|
||||
)
|
||||
incoming = IncomingMessage(
|
||||
channel_id=chat_id,
|
||||
message_id=user_msg_id,
|
||||
text=text,
|
||||
reply_to=reply_ref,
|
||||
)
|
||||
await handle_message(
|
||||
cfg.exec_cfg,
|
||||
runner=entry.runner,
|
||||
incoming=incoming,
|
||||
resume_token=resume_token,
|
||||
strip_resume_line=cfg.router.is_resume_line,
|
||||
running_tasks=running_tasks,
|
||||
on_thread_known=on_thread_known,
|
||||
)
|
||||
try:
|
||||
cwd = resolve_run_cwd(context, projects=cfg.projects)
|
||||
except WorktreeError as exc:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
text=f"error:\n{exc}",
|
||||
)
|
||||
return
|
||||
run_base_token = set_run_base_dir(cwd)
|
||||
try:
|
||||
run_fields = {
|
||||
"chat_id": chat_id,
|
||||
"user_msg_id": user_msg_id,
|
||||
"engine": entry.runner.engine,
|
||||
"resume": resume_token.value if resume_token else None,
|
||||
}
|
||||
if context is not None:
|
||||
run_fields["project"] = context.project
|
||||
run_fields["branch"] = context.branch
|
||||
if cwd is not None:
|
||||
run_fields["cwd"] = str(cwd)
|
||||
bind_run_context(**run_fields)
|
||||
context_line = _format_context_line(
|
||||
context, projects=cfg.projects
|
||||
)
|
||||
incoming = RunnerIncomingMessage(
|
||||
channel_id=chat_id,
|
||||
message_id=user_msg_id,
|
||||
text=text,
|
||||
reply_to=reply_ref,
|
||||
)
|
||||
await handle_message(
|
||||
cfg.exec_cfg,
|
||||
runner=entry.runner,
|
||||
incoming=incoming,
|
||||
resume_token=resume_token,
|
||||
context=context,
|
||||
context_line=context_line,
|
||||
strip_resume_line=cfg.router.is_resume_line,
|
||||
running_tasks=running_tasks,
|
||||
on_thread_known=on_thread_known,
|
||||
)
|
||||
finally:
|
||||
reset_run_base_dir(run_base_token)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"handle.worker_failed",
|
||||
@@ -507,33 +728,48 @@ async def run_main_loop(
|
||||
job.user_msg_id,
|
||||
job.text,
|
||||
job.resume_token,
|
||||
job.context,
|
||||
None,
|
||||
)
|
||||
|
||||
scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
|
||||
|
||||
async for msg in poller(cfg):
|
||||
text = msg["text"]
|
||||
user_msg_id = msg["message_id"]
|
||||
chat_id = msg["chat"]["id"]
|
||||
reply_ref = None
|
||||
reply_msg = msg.get("reply_to_message")
|
||||
if reply_msg:
|
||||
reply_id = reply_msg.get("message_id")
|
||||
if reply_id is not None:
|
||||
reply_ref = MessageRef(channel_id=chat_id, message_id=reply_id)
|
||||
text = msg.text
|
||||
user_msg_id = msg.message_id
|
||||
chat_id = msg.chat_id
|
||||
reply_id = msg.reply_to_message_id
|
||||
reply_ref = (
|
||||
MessageRef(channel_id=chat_id, message_id=reply_id)
|
||||
if reply_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if _is_cancel_command(text):
|
||||
tg.start_soon(_handle_cancel, cfg, msg, running_tasks)
|
||||
continue
|
||||
|
||||
text, engine_override = _strip_engine_command(
|
||||
text, engine_ids=cfg.router.engine_ids
|
||||
)
|
||||
reply_text = msg.reply_to_text
|
||||
try:
|
||||
resolved = _resolve_message(
|
||||
text=text,
|
||||
reply_text=reply_text,
|
||||
router=cfg.router,
|
||||
projects=cfg.projects,
|
||||
)
|
||||
except DirectiveError as exc:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
text=f"error:\n{exc}",
|
||||
)
|
||||
continue
|
||||
|
||||
r = msg.get("reply_to_message") or {}
|
||||
resume_token = cfg.router.resolve_resume(text, r.get("text"))
|
||||
reply_id = r.get("message_id")
|
||||
text = resolved.prompt
|
||||
resume_token = resolved.resume_token
|
||||
engine_override = resolved.engine_override
|
||||
context = resolved.context
|
||||
if resume_token is None and reply_id is not None:
|
||||
running_task = running_tasks.get(
|
||||
MessageRef(channel_id=chat_id, message_id=reply_id)
|
||||
@@ -557,13 +793,18 @@ async def run_main_loop(
|
||||
user_msg_id,
|
||||
text,
|
||||
None,
|
||||
context,
|
||||
reply_ref,
|
||||
scheduler.note_thread_known,
|
||||
engine_override,
|
||||
)
|
||||
else:
|
||||
await scheduler.enqueue_resume(
|
||||
chat_id, user_msg_id, text, resume_token
|
||||
chat_id,
|
||||
user_msg_id,
|
||||
text,
|
||||
resume_token,
|
||||
context,
|
||||
)
|
||||
finally:
|
||||
await cfg.exec_cfg.transport.close()
|
||||
|
||||
@@ -3,13 +3,22 @@ from __future__ import annotations
|
||||
import itertools
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable, Hashable, Protocol, TYPE_CHECKING
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Hashable,
|
||||
Protocol,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
import anyio
|
||||
|
||||
from ..logging import get_logger
|
||||
from ..transport import IncomingMessage
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -34,6 +43,76 @@ def is_group_chat_id(chat_id: int) -> bool:
|
||||
return chat_id < 0
|
||||
|
||||
|
||||
def parse_incoming_update(
|
||||
update: dict[str, Any], *, chat_id: int
|
||||
) -> IncomingMessage | None:
|
||||
msg = update.get("message")
|
||||
if not isinstance(msg, dict):
|
||||
return None
|
||||
text = msg.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
chat = msg.get("chat")
|
||||
if not isinstance(chat, dict):
|
||||
return None
|
||||
msg_chat_id = chat.get("id")
|
||||
if not isinstance(msg_chat_id, int) or msg_chat_id != chat_id:
|
||||
return None
|
||||
message_id = msg.get("message_id")
|
||||
if not isinstance(message_id, int):
|
||||
return None
|
||||
reply = msg.get("reply_to_message")
|
||||
reply_to_message_id = None
|
||||
reply_to_text = None
|
||||
if isinstance(reply, dict):
|
||||
reply_to_message_id = (
|
||||
reply.get("message_id")
|
||||
if isinstance(reply.get("message_id"), int)
|
||||
else None
|
||||
)
|
||||
reply_to_text = (
|
||||
reply.get("text") if isinstance(reply.get("text"), str) else None
|
||||
)
|
||||
sender = msg.get("from")
|
||||
sender_id = (
|
||||
sender.get("id")
|
||||
if isinstance(sender, dict) and isinstance(sender.get("id"), int)
|
||||
else None
|
||||
)
|
||||
return IncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=msg_chat_id,
|
||||
message_id=message_id,
|
||||
text=text,
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
reply_to_text=reply_to_text,
|
||||
sender_id=sender_id,
|
||||
raw=msg,
|
||||
)
|
||||
|
||||
|
||||
async def poll_incoming(
|
||||
bot: BotClient,
|
||||
*,
|
||||
chat_id: int,
|
||||
offset: int | None = None,
|
||||
) -> AsyncIterator[IncomingMessage]:
|
||||
while True:
|
||||
updates = await bot.get_updates(
|
||||
offset=offset, timeout_s=50, allowed_updates=["message"]
|
||||
)
|
||||
if updates is None:
|
||||
logger.info("loop.get_updates.failed")
|
||||
await anyio.sleep(2)
|
||||
continue
|
||||
logger.debug("loop.updates", updates=updates)
|
||||
for upd in updates:
|
||||
offset = upd["update_id"] + 1
|
||||
msg = parse_incoming_update(upd, chat_id=chat_id)
|
||||
if msg is not None:
|
||||
yield msg
|
||||
|
||||
|
||||
class BotClient(Protocol):
|
||||
async def close(self) -> None: ...
|
||||
|
||||
|
||||
@@ -7,6 +7,18 @@ ChannelId: TypeAlias = int | str
|
||||
MessageId: TypeAlias = int | str
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class IncomingMessage:
|
||||
transport: str
|
||||
chat_id: int
|
||||
message_id: int
|
||||
text: str
|
||||
reply_to_message_id: int | None
|
||||
reply_to_text: str | None
|
||||
sender_id: int | None
|
||||
raw: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class MessageRef:
|
||||
channel_id: ChannelId
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _run_git(
|
||||
args: Sequence[str], *, cwd: Path
|
||||
) -> subprocess.CompletedProcess[str] | None:
|
||||
try:
|
||||
return subprocess.run(
|
||||
["git", *args],
|
||||
cwd=cwd,
|
||||
check=False,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def git_run(
|
||||
args: Sequence[str], *, cwd: Path
|
||||
) -> subprocess.CompletedProcess[str] | None:
|
||||
return _run_git(args, cwd=cwd)
|
||||
|
||||
|
||||
def git_stdout(args: Sequence[str], *, cwd: Path) -> str | None:
|
||||
result = _run_git(args, cwd=cwd)
|
||||
if result is None or result.returncode != 0:
|
||||
return None
|
||||
output = result.stdout.strip()
|
||||
return output or None
|
||||
|
||||
|
||||
def git_ok(args: Sequence[str], *, cwd: Path) -> bool:
|
||||
result = _run_git(args, cwd=cwd)
|
||||
return result is not None and result.returncode == 0
|
||||
|
||||
|
||||
def git_is_worktree(path: Path) -> bool:
|
||||
return git_stdout(["rev-parse", "--is-inside-work-tree"], cwd=path) == "true"
|
||||
|
||||
|
||||
def resolve_default_base(root: Path) -> str | None:
|
||||
origin_head = git_stdout(
|
||||
["symbolic-ref", "-q", "refs/remotes/origin/HEAD"],
|
||||
cwd=root,
|
||||
)
|
||||
if origin_head:
|
||||
prefix = "refs/remotes/origin/"
|
||||
if origin_head.startswith(prefix):
|
||||
name = origin_head[len(prefix) :].strip()
|
||||
if name:
|
||||
return name
|
||||
|
||||
current = git_stdout(["branch", "--show-current"], cwd=root)
|
||||
if current:
|
||||
return current
|
||||
|
||||
if git_ok(["show-ref", "--verify", "--quiet", "refs/heads/master"], cwd=root):
|
||||
return "master"
|
||||
if git_ok(["show-ref", "--verify", "--quiet", "refs/heads/main"], cwd=root):
|
||||
return "main"
|
||||
return None
|
||||
|
||||
|
||||
def resolve_main_worktree_root(cwd: Path) -> Path | None:
|
||||
common_dir = git_stdout(
|
||||
["rev-parse", "--path-format=absolute", "--git-common-dir"],
|
||||
cwd=cwd,
|
||||
)
|
||||
if not common_dir:
|
||||
return None
|
||||
if git_stdout(["rev-parse", "--is-bare-repository"], cwd=cwd) == "true":
|
||||
return cwd
|
||||
common_path = Path(common_dir)
|
||||
if not common_path.is_absolute():
|
||||
common_path = (cwd / common_path).resolve()
|
||||
return common_path.parent
|
||||
@@ -1,13 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from contextvars import ContextVar, Token
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_run_base_dir: ContextVar[Path | None] = ContextVar("takopi_run_base_dir", default=None)
|
||||
|
||||
|
||||
def get_run_base_dir() -> Path | None:
|
||||
return _run_base_dir.get()
|
||||
|
||||
|
||||
def set_run_base_dir(base_dir: Path | None) -> Token[Path | None]:
|
||||
return _run_base_dir.set(base_dir)
|
||||
|
||||
|
||||
def reset_run_base_dir(token: Token[Path | None]) -> None:
|
||||
_run_base_dir.reset(token)
|
||||
|
||||
|
||||
def relativize_path(value: str, *, base_dir: Path | None = None) -> str:
|
||||
if not value:
|
||||
return value
|
||||
base = Path.cwd() if base_dir is None else base_dir
|
||||
base = get_run_base_dir() if base_dir is None else base_dir
|
||||
if base is None:
|
||||
base = Path.cwd()
|
||||
base_str = str(base)
|
||||
if not base_str:
|
||||
return value
|
||||
@@ -22,6 +40,8 @@ def relativize_path(value: str, *, base_dir: Path | None = None) -> str:
|
||||
|
||||
|
||||
def relativize_command(value: str, *, base_dir: Path | None = None) -> str:
|
||||
base = Path.cwd() if base_dir is None else base_dir
|
||||
base = get_run_base_dir() if base_dir is None else base_dir
|
||||
if base is None:
|
||||
base = Path.cwd()
|
||||
base_with_sep = f"{base}{os.sep}"
|
||||
return value.replace(base_with_sep, "")
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .config import ProjectConfig, ProjectsConfig
|
||||
from .context import RunContext
|
||||
from .utils.git import git_is_worktree, git_ok, git_run, resolve_default_base
|
||||
|
||||
|
||||
class WorktreeError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def resolve_run_cwd(
|
||||
context: RunContext | None,
|
||||
*,
|
||||
projects: ProjectsConfig,
|
||||
) -> Path | None:
|
||||
if context is None or context.project is None:
|
||||
return None
|
||||
project = projects.projects.get(context.project)
|
||||
if project is None:
|
||||
raise WorktreeError(f"unknown project {context.project!r}")
|
||||
if context.branch is None:
|
||||
return project.path
|
||||
return ensure_worktree(project, context.branch)
|
||||
|
||||
|
||||
def ensure_worktree(project: ProjectConfig, branch: str) -> Path:
|
||||
root = project.path
|
||||
if not root.exists():
|
||||
raise WorktreeError(f"project path not found: {root}")
|
||||
|
||||
branch = _sanitize_branch(branch)
|
||||
worktrees_root = project.worktrees_root
|
||||
worktree_path = worktrees_root / branch
|
||||
_ensure_within_root(worktrees_root, worktree_path)
|
||||
|
||||
if worktree_path.exists():
|
||||
if not git_is_worktree(worktree_path):
|
||||
raise WorktreeError(f"{worktree_path} exists but is not a git worktree")
|
||||
return worktree_path
|
||||
|
||||
worktrees_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if git_ok(
|
||||
["show-ref", "--verify", "--quiet", f"refs/heads/{branch}"],
|
||||
cwd=root,
|
||||
):
|
||||
_git_worktree_add(root, worktree_path, branch)
|
||||
return worktree_path
|
||||
|
||||
if git_ok(
|
||||
["show-ref", "--verify", "--quiet", f"refs/remotes/origin/{branch}"],
|
||||
cwd=root,
|
||||
):
|
||||
_git_worktree_add(
|
||||
root,
|
||||
worktree_path,
|
||||
branch,
|
||||
base_ref=f"origin/{branch}",
|
||||
create_branch=True,
|
||||
)
|
||||
return worktree_path
|
||||
|
||||
base = project.worktree_base or resolve_default_base(root)
|
||||
if not base:
|
||||
raise WorktreeError("cannot determine base branch for new worktree")
|
||||
|
||||
_git_worktree_add(
|
||||
root,
|
||||
worktree_path,
|
||||
branch,
|
||||
base_ref=base,
|
||||
create_branch=True,
|
||||
)
|
||||
return worktree_path
|
||||
|
||||
|
||||
def _git_worktree_add(
|
||||
root: Path,
|
||||
worktree_path: Path,
|
||||
branch: str,
|
||||
*,
|
||||
base_ref: str | None = None,
|
||||
create_branch: bool = False,
|
||||
) -> None:
|
||||
if create_branch:
|
||||
if not base_ref:
|
||||
raise WorktreeError("missing base ref for worktree creation")
|
||||
args = ["worktree", "add", "-b", branch, str(worktree_path), base_ref]
|
||||
else:
|
||||
args = ["worktree", "add", str(worktree_path), branch]
|
||||
|
||||
result = git_run(args, cwd=root)
|
||||
if result is None:
|
||||
raise WorktreeError("git not available on PATH")
|
||||
if result.returncode != 0:
|
||||
message = result.stderr.strip() or result.stdout.strip()
|
||||
raise WorktreeError(message or "git worktree add failed")
|
||||
|
||||
|
||||
def _sanitize_branch(branch: str) -> str:
|
||||
cleaned = branch.strip()
|
||||
if not cleaned:
|
||||
raise WorktreeError("branch name cannot be empty")
|
||||
if cleaned.startswith("/"):
|
||||
raise WorktreeError("branch name cannot start with '/'")
|
||||
for part in Path(cleaned).parts:
|
||||
if part == "..":
|
||||
raise WorktreeError("branch name cannot contain '..'")
|
||||
return cleaned
|
||||
|
||||
|
||||
def _ensure_within_root(root: Path, path: Path) -> None:
|
||||
root_resolved = root.resolve(strict=False)
|
||||
path_resolved = path.resolve(strict=False)
|
||||
if not path_resolved.is_relative_to(root_resolved):
|
||||
raise WorktreeError("branch path escapes the worktrees directory")
|
||||
@@ -104,3 +104,21 @@ def test_translate_command_execution_allows_null_exit_code() -> None:
|
||||
assert isinstance(out[0], ActionEvent)
|
||||
assert out[0].ok is True
|
||||
assert out[0].action.detail["exit_code"] is None
|
||||
|
||||
|
||||
def test_translate_file_change_normalizes_changes() -> None:
|
||||
evt = {
|
||||
"type": "item.completed",
|
||||
"item": {
|
||||
"id": "item_6",
|
||||
"type": "file_change",
|
||||
"changes": [{"path": "README.md", "kind": "update"}],
|
||||
"status": "completed",
|
||||
},
|
||||
}
|
||||
|
||||
out = _translate_event(evt)
|
||||
assert len(out) == 1
|
||||
assert isinstance(out[0], ActionEvent)
|
||||
changes = out[0].action.detail["changes"]
|
||||
assert changes == [{"path": "README.md", "kind": "update"}]
|
||||
|
||||
@@ -334,6 +334,37 @@ async def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
|
||||
assert transport.send_calls[-1]["options"].replace == transport.send_calls[0]["ref"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_final_message_includes_ctx_line() -> None:
|
||||
transport = _FakeTransport()
|
||||
clock = _FakeClock()
|
||||
session_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
runner = ScriptRunner(
|
||||
[Return(answer="done")],
|
||||
engine=CODEX_ENGINE,
|
||||
resume_value=session_id,
|
||||
)
|
||||
cfg = ExecBridgeConfig(
|
||||
transport=transport,
|
||||
presenter=MarkdownPresenter(),
|
||||
final_notify=True,
|
||||
)
|
||||
|
||||
await handle_message(
|
||||
cfg,
|
||||
runner=runner,
|
||||
incoming=IncomingMessage(channel_id=123, message_id=42, text="do it"),
|
||||
resume_token=None,
|
||||
context_line="`ctx: takopi @ feat/api`",
|
||||
clock=clock,
|
||||
)
|
||||
|
||||
assert transport.send_calls
|
||||
final_text = transport.send_calls[-1]["message"].text
|
||||
assert "`ctx: takopi @ feat/api`" in final_text
|
||||
assert "codex resume" in final_text.lower()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_handle_message_cancelled_renders_cancelled_state() -> None:
|
||||
transport = _FakeTransport()
|
||||
|
||||
@@ -16,6 +16,7 @@ from takopi.markdown import (
|
||||
from takopi.model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent
|
||||
from takopi.progress import ProgressTracker
|
||||
from takopi.telegram.render import render_markdown
|
||||
from takopi.utils.paths import reset_run_base_dir, set_run_base_dir
|
||||
from tests.factories import (
|
||||
action_completed,
|
||||
action_started,
|
||||
@@ -119,6 +120,40 @@ def test_file_change_renders_relative_paths_inside_cwd() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_file_change_renders_change_objects(tmp_path: Path) -> None:
|
||||
base = tmp_path / "repo"
|
||||
base.mkdir()
|
||||
abs_path = str(base / "changelog.md")
|
||||
token = set_run_base_dir(base)
|
||||
try:
|
||||
out = render_event_cli(
|
||||
action_completed(
|
||||
"f-obj",
|
||||
"file_change",
|
||||
"ignored",
|
||||
ok=True,
|
||||
detail={"changes": [SimpleNamespace(path=abs_path, kind="update")]},
|
||||
)
|
||||
)
|
||||
finally:
|
||||
reset_run_base_dir(token)
|
||||
assert any("files: update `changelog.md`" in line for line in out)
|
||||
|
||||
|
||||
def test_file_change_title_relativizes_absolute_title(tmp_path: Path) -> None:
|
||||
base = tmp_path / "repo"
|
||||
base.mkdir()
|
||||
abs_path = str(base / "changelog.md")
|
||||
token = set_run_base_dir(base)
|
||||
try:
|
||||
out = render_event_cli(
|
||||
action_completed("f-abs", "file_change", abs_path, ok=True)
|
||||
)
|
||||
finally:
|
||||
reset_run_base_dir(token)
|
||||
assert any("files: `changelog.md`" in line for line in out)
|
||||
|
||||
|
||||
def test_progress_renderer_renders_progress_and_final() -> None:
|
||||
tracker = ProgressTracker(engine="codex")
|
||||
for evt in SAMPLE_EVENTS:
|
||||
@@ -145,6 +180,23 @@ def test_progress_renderer_renders_progress_and_final() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_progress_renderer_footer_includes_ctx_before_resume() -> None:
|
||||
tracker = ProgressTracker(engine="codex")
|
||||
for evt in SAMPLE_EVENTS:
|
||||
tracker.note_event(evt)
|
||||
|
||||
state = tracker.snapshot(
|
||||
resume_formatter=_format_resume,
|
||||
context_line="`ctx: z80 @ feat/name`",
|
||||
)
|
||||
formatter = MarkdownFormatter(max_actions=5)
|
||||
parts = formatter.render_progress_parts(state, elapsed_s=0.0)
|
||||
assert parts.footer == (
|
||||
"`ctx: z80 @ feat/name`"
|
||||
f"{HARD_BREAK}`codex resume 0199a213-81c0-7800-8aa1-bbab2a035a53`"
|
||||
)
|
||||
|
||||
|
||||
def test_progress_renderer_clamps_actions_and_ignores_unknown() -> None:
|
||||
tracker = ProgressTracker(engine="codex")
|
||||
events = [
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
from pathlib import Path
|
||||
|
||||
from takopi.utils.git import resolve_default_base, resolve_main_worktree_root
|
||||
|
||||
|
||||
def test_resolve_main_worktree_root_returns_none_when_no_git(monkeypatch) -> None:
|
||||
monkeypatch.setattr("takopi.utils.git.git_stdout", lambda *args, **kwargs: None)
|
||||
assert resolve_main_worktree_root(Path("/tmp")) is None
|
||||
|
||||
|
||||
def test_resolve_main_worktree_root_prefers_common_dir_parent(monkeypatch) -> None:
|
||||
base = Path("/repo")
|
||||
|
||||
def _fake_stdout(args, **kwargs):
|
||||
if args[:2] == ["rev-parse", "--path-format=absolute"]:
|
||||
return str(base / ".git")
|
||||
if args == ["rev-parse", "--is-bare-repository"]:
|
||||
return "false"
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout)
|
||||
assert resolve_main_worktree_root(base / ".worktrees" / "feature") == base
|
||||
|
||||
|
||||
def test_resolve_main_worktree_root_returns_cwd_for_bare_repo(monkeypatch) -> None:
|
||||
cwd = Path("/bare-repo")
|
||||
|
||||
def _fake_stdout(args, **kwargs):
|
||||
if args[:2] == ["rev-parse", "--path-format=absolute"]:
|
||||
return str(cwd / "repo.git")
|
||||
if args == ["rev-parse", "--is-bare-repository"]:
|
||||
return "true"
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout)
|
||||
assert resolve_main_worktree_root(cwd) == cwd
|
||||
|
||||
|
||||
def test_resolve_default_base_prefers_master_over_main(monkeypatch) -> None:
|
||||
def _fake_stdout(args, **kwargs):
|
||||
if args[:2] == ["symbolic-ref", "-q"]:
|
||||
return None
|
||||
if args == ["branch", "--show-current"]:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _fake_ok(args, **kwargs):
|
||||
if args == ["show-ref", "--verify", "--quiet", "refs/heads/master"]:
|
||||
return True
|
||||
if args == ["show-ref", "--verify", "--quiet", "refs/heads/main"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout)
|
||||
monkeypatch.setattr("takopi.utils.git.git_ok", _fake_ok)
|
||||
assert resolve_default_base(Path("/repo")) == "master"
|
||||
+17
-1
@@ -2,7 +2,12 @@ from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from takopi.utils.paths import relativize_command, relativize_path
|
||||
from takopi.utils.paths import (
|
||||
relativize_command,
|
||||
relativize_path,
|
||||
reset_run_base_dir,
|
||||
set_run_base_dir,
|
||||
)
|
||||
|
||||
|
||||
def test_relativize_command_rewrites_cwd_paths(tmp_path: Path) -> None:
|
||||
@@ -33,3 +38,14 @@ def test_relativize_path_inside_base(tmp_path: Path) -> None:
|
||||
base.mkdir()
|
||||
value = str(base / "src" / "app.py")
|
||||
assert relativize_path(value, base_dir=base) == "src/app.py"
|
||||
|
||||
|
||||
def test_relativize_path_uses_run_base_dir(tmp_path: Path) -> None:
|
||||
base = tmp_path / "repo"
|
||||
base.mkdir()
|
||||
token = set_run_base_dir(base)
|
||||
try:
|
||||
value = str(base / "src" / "app.py")
|
||||
assert relativize_path(value) == "src/app.py"
|
||||
finally:
|
||||
reset_run_base_dir(token)
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from takopi import cli
|
||||
from takopi.config import ConfigError, parse_projects_config
|
||||
|
||||
|
||||
def test_parse_projects_rejects_engine_alias() -> None:
|
||||
config = {"projects": {"codex": {"path": "/tmp/repo"}}}
|
||||
with pytest.raises(ConfigError, match="aliases must not match engine ids"):
|
||||
parse_projects_config(
|
||||
config,
|
||||
config_path=Path("takopi.toml"),
|
||||
engine_ids=["codex"],
|
||||
reserved=("cancel",),
|
||||
)
|
||||
|
||||
|
||||
def test_parse_projects_default_project_must_exist() -> None:
|
||||
config = {"default_project": "z80", "projects": {}}
|
||||
with pytest.raises(ConfigError, match="default_project"):
|
||||
parse_projects_config(
|
||||
config,
|
||||
config_path=Path("takopi.toml"),
|
||||
engine_ids=["codex"],
|
||||
reserved=("cancel",),
|
||||
)
|
||||
|
||||
|
||||
def test_init_writes_project(monkeypatch, tmp_path) -> None:
|
||||
config_path = tmp_path / "takopi.toml"
|
||||
monkeypatch.setattr("takopi.config.HOME_CONFIG_PATH", config_path)
|
||||
monkeypatch.setattr(cli, "resolve_default_base", lambda _: "main")
|
||||
|
||||
repo_path = tmp_path / "repo"
|
||||
repo_path.mkdir()
|
||||
monkeypatch.chdir(repo_path)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli.app, ["init", "z80"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
saved = config_path.read_text(encoding="utf-8")
|
||||
assert "[projects.z80]" in saved
|
||||
assert 'worktrees_dir = ".worktrees"' in saved
|
||||
assert 'default_engine = "codex"' in saved
|
||||
assert 'worktree_base = "main"' in saved
|
||||
+112
-42
@@ -1,3 +1,5 @@
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
@@ -7,16 +9,19 @@ from takopi.telegram.bridge import (
|
||||
_build_bot_commands,
|
||||
_handle_cancel,
|
||||
_is_cancel_command,
|
||||
_resolve_message,
|
||||
_send_with_resume,
|
||||
_strip_engine_command,
|
||||
run_main_loop,
|
||||
)
|
||||
from takopi.context import RunContext
|
||||
from takopi.config import ProjectConfig, ProjectsConfig
|
||||
from takopi.runner_bridge import ExecBridgeConfig, RunningTask
|
||||
from takopi.markdown import MarkdownPresenter
|
||||
from takopi.model import EngineId, ResumeToken
|
||||
from takopi.router import AutoRouter, RunnerEntry
|
||||
from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait
|
||||
from takopi.transport import MessageRef, RenderedMessage, SendOptions
|
||||
from takopi.transport import IncomingMessage, MessageRef, RenderedMessage, SendOptions
|
||||
|
||||
CODEX_ENGINE = EngineId("codex")
|
||||
|
||||
@@ -354,7 +359,15 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
|
||||
async def test_handle_cancel_without_reply_prompts_user() -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = _make_cfg(transport)
|
||||
msg = {"chat": {"id": 123}, "message_id": 10}
|
||||
msg = IncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=10,
|
||||
text="/cancel",
|
||||
reply_to_message_id=None,
|
||||
reply_to_text=None,
|
||||
sender_id=123,
|
||||
)
|
||||
running_tasks: dict = {}
|
||||
|
||||
await _handle_cancel(cfg, msg, running_tasks)
|
||||
@@ -367,11 +380,15 @@ async def test_handle_cancel_without_reply_prompts_user() -> None:
|
||||
async def test_handle_cancel_with_no_progress_message_says_nothing_running() -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = _make_cfg(transport)
|
||||
msg = {
|
||||
"chat": {"id": 123},
|
||||
"message_id": 10,
|
||||
"reply_to_message": {"text": "no message id"},
|
||||
}
|
||||
msg = IncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=10,
|
||||
text="/cancel",
|
||||
reply_to_message_id=None,
|
||||
reply_to_text="no message id",
|
||||
sender_id=123,
|
||||
)
|
||||
running_tasks: dict = {}
|
||||
|
||||
await _handle_cancel(cfg, msg, running_tasks)
|
||||
@@ -385,11 +402,15 @@ async def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = _make_cfg(transport)
|
||||
progress_id = 99
|
||||
msg = {
|
||||
"chat": {"id": 123},
|
||||
"message_id": 10,
|
||||
"reply_to_message": {"message_id": progress_id},
|
||||
}
|
||||
msg = IncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=10,
|
||||
text="/cancel",
|
||||
reply_to_message_id=progress_id,
|
||||
reply_to_text=None,
|
||||
sender_id=123,
|
||||
)
|
||||
running_tasks: dict = {}
|
||||
|
||||
await _handle_cancel(cfg, msg, running_tasks)
|
||||
@@ -403,11 +424,15 @@ async def test_handle_cancel_cancels_running_task() -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = _make_cfg(transport)
|
||||
progress_id = 42
|
||||
msg = {
|
||||
"chat": {"id": 123},
|
||||
"message_id": 10,
|
||||
"reply_to_message": {"message_id": progress_id},
|
||||
}
|
||||
msg = IncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=10,
|
||||
text="/cancel",
|
||||
reply_to_message_id=progress_id,
|
||||
reply_to_text=None,
|
||||
sender_id=123,
|
||||
)
|
||||
|
||||
running_task = RunningTask()
|
||||
running_tasks = {MessageRef(channel_id=123, message_id=progress_id): running_task}
|
||||
@@ -423,11 +448,15 @@ async def test_handle_cancel_only_cancels_matching_progress_message() -> None:
|
||||
cfg = _make_cfg(transport)
|
||||
task_first = RunningTask()
|
||||
task_second = RunningTask()
|
||||
msg = {
|
||||
"chat": {"id": 123},
|
||||
"message_id": 10,
|
||||
"reply_to_message": {"message_id": 1},
|
||||
}
|
||||
msg = IncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=10,
|
||||
text="/cancel",
|
||||
reply_to_message_id=1,
|
||||
reply_to_text=None,
|
||||
sender_id=123,
|
||||
)
|
||||
running_tasks = {
|
||||
MessageRef(channel_id=123, message_id=1): task_first,
|
||||
MessageRef(channel_id=123, message_id=2): task_second,
|
||||
@@ -446,16 +475,46 @@ def test_cancel_command_accepts_extra_text() -> None:
|
||||
assert _is_cancel_command("/cancelled") is False
|
||||
|
||||
|
||||
def test_resolve_message_accepts_backticked_ctx_line() -> None:
|
||||
router = _make_router(ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE))
|
||||
projects = ProjectsConfig(
|
||||
projects={
|
||||
"takopi": ProjectConfig(
|
||||
alias="takopi",
|
||||
path=Path("."),
|
||||
worktrees_dir=Path(".worktrees"),
|
||||
)
|
||||
},
|
||||
default_project=None,
|
||||
)
|
||||
|
||||
resolved = _resolve_message(
|
||||
text="do it",
|
||||
reply_text="`ctx: takopi @ feat/api`",
|
||||
router=router,
|
||||
projects=projects,
|
||||
)
|
||||
|
||||
assert resolved.prompt == "do it"
|
||||
assert resolved.resume_token is None
|
||||
assert resolved.engine_override is None
|
||||
assert resolved.context == RunContext(project="takopi", branch="feat/api")
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_send_with_resume_waits_for_token() -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = _make_cfg(transport)
|
||||
sent: list[tuple[int, int, str, ResumeToken | None]] = []
|
||||
sent: list[tuple[int, int, str, ResumeToken, RunContext | None]] = []
|
||||
|
||||
async def enqueue(
|
||||
chat_id: int, user_msg_id: int, text: str, resume: ResumeToken
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
resume: ResumeToken,
|
||||
context: RunContext | None,
|
||||
) -> None:
|
||||
sent.append((chat_id, user_msg_id, text, resume))
|
||||
sent.append((chat_id, user_msg_id, text, resume, context))
|
||||
|
||||
running_task = RunningTask()
|
||||
|
||||
@@ -476,7 +535,7 @@ async def test_send_with_resume_waits_for_token() -> None:
|
||||
)
|
||||
|
||||
assert sent == [
|
||||
(123, 10, "hello", ResumeToken(engine=CODEX_ENGINE, value="abc123"))
|
||||
(123, 10, "hello", ResumeToken(engine=CODEX_ENGINE, value="abc123"), None)
|
||||
]
|
||||
assert transport.send_calls == []
|
||||
|
||||
@@ -485,12 +544,16 @@ async def test_send_with_resume_waits_for_token() -> None:
|
||||
async def test_send_with_resume_reports_when_missing() -> None:
|
||||
transport = _FakeTransport()
|
||||
cfg = _make_cfg(transport)
|
||||
sent: list[tuple[int, int, str, ResumeToken | None]] = []
|
||||
sent: list[tuple[int, int, str, ResumeToken, RunContext | None]] = []
|
||||
|
||||
async def enqueue(
|
||||
chat_id: int, user_msg_id: int, text: str, resume: ResumeToken
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
resume: ResumeToken,
|
||||
context: RunContext | None,
|
||||
) -> None:
|
||||
sent.append((chat_id, user_msg_id, text, resume))
|
||||
sent.append((chat_id, user_msg_id, text, resume, context))
|
||||
|
||||
running_task = RunningTask()
|
||||
running_task.done.set()
|
||||
@@ -538,22 +601,29 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None:
|
||||
)
|
||||
|
||||
async def poller(_cfg: TelegramBridgeConfig):
|
||||
yield {
|
||||
"message_id": 1,
|
||||
"text": "first",
|
||||
"chat": {"id": 123},
|
||||
"from": {"id": 123},
|
||||
}
|
||||
yield IncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=1,
|
||||
text="first",
|
||||
reply_to_message_id=None,
|
||||
reply_to_text=None,
|
||||
sender_id=123,
|
||||
)
|
||||
await progress_ready.wait()
|
||||
assert transport.progress_ref is not None
|
||||
assert isinstance(transport.progress_ref.message_id, int)
|
||||
reply_id = transport.progress_ref.message_id
|
||||
reply_ready.set()
|
||||
yield {
|
||||
"message_id": 2,
|
||||
"text": "followup",
|
||||
"chat": {"id": 123},
|
||||
"from": {"id": 123},
|
||||
"reply_to_message": {"message_id": transport.progress_ref.message_id},
|
||||
}
|
||||
yield IncomingMessage(
|
||||
transport="telegram",
|
||||
chat_id=123,
|
||||
message_id=2,
|
||||
text="followup",
|
||||
reply_to_message_id=reply_id,
|
||||
reply_to_text=None,
|
||||
sender_id=123,
|
||||
)
|
||||
await stop_polling.wait()
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from takopi.telegram import parse_incoming_update
|
||||
|
||||
|
||||
def test_parse_incoming_update_maps_fields() -> None:
|
||||
update = {
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"message_id": 10,
|
||||
"text": "hello",
|
||||
"chat": {"id": 123},
|
||||
"from": {"id": 99},
|
||||
"reply_to_message": {"message_id": 5, "text": "prev"},
|
||||
},
|
||||
}
|
||||
|
||||
msg = parse_incoming_update(update, chat_id=123)
|
||||
assert msg is not None
|
||||
assert msg.transport == "telegram"
|
||||
assert msg.chat_id == 123
|
||||
assert msg.message_id == 10
|
||||
assert msg.text == "hello"
|
||||
assert msg.reply_to_message_id == 5
|
||||
assert msg.reply_to_text == "prev"
|
||||
assert msg.sender_id == 99
|
||||
assert msg.raw == update["message"]
|
||||
|
||||
|
||||
def test_parse_incoming_update_filters_non_matching_chat() -> None:
|
||||
update = {
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"message_id": 10,
|
||||
"text": "hello",
|
||||
"chat": {"id": 123},
|
||||
},
|
||||
}
|
||||
|
||||
assert parse_incoming_update(update, chat_id=999) is None
|
||||
|
||||
|
||||
def test_parse_incoming_update_filters_non_text() -> None:
|
||||
update = {
|
||||
"update_id": 1,
|
||||
"message": {"message_id": 10, "chat": {"id": 123}},
|
||||
}
|
||||
|
||||
assert parse_incoming_update(update, chat_id=123) is None
|
||||
@@ -0,0 +1,56 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from takopi.config import ProjectConfig, ProjectsConfig
|
||||
from takopi.context import RunContext
|
||||
from takopi.worktrees import WorktreeError, ensure_worktree, resolve_run_cwd
|
||||
|
||||
|
||||
def _projects_config(path: Path) -> ProjectsConfig:
|
||||
return ProjectsConfig(
|
||||
projects={
|
||||
"z80": ProjectConfig(
|
||||
alias="z80",
|
||||
path=path,
|
||||
worktrees_dir=Path(".worktrees"),
|
||||
)
|
||||
},
|
||||
default_project=None,
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_run_cwd_uses_project_root(tmp_path: Path) -> None:
|
||||
projects = _projects_config(tmp_path)
|
||||
ctx = RunContext(project="z80")
|
||||
assert resolve_run_cwd(ctx, projects=projects) == tmp_path
|
||||
|
||||
|
||||
def test_resolve_run_cwd_rejects_invalid_branch(tmp_path: Path) -> None:
|
||||
projects = _projects_config(tmp_path)
|
||||
ctx = RunContext(project="z80", branch="../oops")
|
||||
with pytest.raises(WorktreeError, match="branch name"):
|
||||
resolve_run_cwd(ctx, projects=projects)
|
||||
|
||||
|
||||
def test_ensure_worktree_creates_from_base(monkeypatch, tmp_path: Path) -> None:
|
||||
project = ProjectConfig(
|
||||
alias="z80",
|
||||
path=tmp_path,
|
||||
worktrees_dir=Path(".worktrees"),
|
||||
)
|
||||
calls: list[list[str]] = []
|
||||
|
||||
monkeypatch.setattr("takopi.worktrees.git_ok", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("takopi.worktrees.resolve_default_base", lambda *_: "main")
|
||||
|
||||
def _fake_git_run(args, cwd):
|
||||
calls.append(list(args))
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr("takopi.worktrees.git_run", _fake_git_run)
|
||||
|
||||
worktree_path = ensure_worktree(project, "feat/name")
|
||||
assert worktree_path == tmp_path / ".worktrees" / "feat" / "name"
|
||||
assert calls == [["worktree", "add", "-b", "feat/name", str(worktree_path), "main"]]
|
||||
Reference in New Issue
Block a user