diff --git a/changelog.md b/changelog.md index ad421dc..b47083c 100644 --- a/changelog.md +++ b/changelog.md @@ -10,6 +10,19 @@ - add transport/presenter protocols plus transport-agnostic `exec_bridge` - move Telegram polling + wiring into `takopi.bridges.telegram` with transport/presenter adapters +- add project configuration, directive parsing (`/project`, `@branch`), and `ctx:`-aware routing for runs +- add `takopi init` to register project aliases from the main checkout (with worktree defaults) +- resolve git worktrees on demand and run engine subprocesses in the project/worktree cwd +- list configured projects in the startup banner +- add a shared incoming message shape plus Telegram parsing helpers + +### fixes + +- render `ctx:` footer lines consistently (backticked + hard breaks) and include them in final messages + +### docs + +- add a projects/worktrees guide and document `takopi init` behavior in the README ## v0.8.0 (2026-01-05) diff --git a/docs/projects.md b/docs/projects.md new file mode 100644 index 0000000..a8dc605 --- /dev/null +++ b/docs/projects.md @@ -0,0 +1,135 @@ +# Projects and Worktrees + +This doc covers project aliases, worktree behavior, and how Takopi resolves run +context from messages. + +## Overview + +Projects let you give a repo an alias (used as `/alias` in messages) and opt into +worktree-based runs via `@branch`. + +- If no projects are configured, Takopi runs in the startup working directory. +- If a project is configured, `@branch` resolves/creates a git worktree and runs + the task in that worktree. +- Progress/final messages include a `ctx:` footer when project context is active. + +## Config schema + +All config lives in `~/.takopi/takopi.toml`. + +```toml +default_engine = "codex" # optional +default_project = "z80" # optional +bot_token = "..." # required +chat_id = 123 # required + +[projects.z80] +path = "~/dev/z80" # required (repo root) +worktrees_dir = ".worktrees" # optional, default ".worktrees" +default_engine = "codex" # optional, per-project override +worktree_base = "master" # optional, base for new branches +``` + +Validation rules: + +- `projects` is optional. +- Each project entry must include `path` (string, non-empty). +- `default_project` must match a configured project alias. +- Project aliases cannot collide with engine ids or reserved commands (`/cancel`). +- `default_engine` and per-project `default_engine` must be valid engine ids. + +## `takopi init` + +`takopi init ` registers the current repo as a project alias. + +Important behavior: + +- The stored `path` is the **main checkout** of the repo, even if you run + `takopi init` inside a worktree. Takopi resolves the repo root via the git + common dir and writes that path to `[projects.].path`. +- `worktree_base` is set from the current repo using this resolution order: + `origin/HEAD` → current branch → `master` → `main`. + +## Directives and context resolution + +Takopi parses the first non-empty line of a message for a directive prefix. + +Supported directives: + +- `/engine` or `/engine@bot`: chooses the engine +- `/project`: chooses a project alias +- `@branch`: chooses a git branch/worktree + +Rules: + +- Directives must be a contiguous prefix of the line; parsing stops at the first + non-directive token. +- At most one engine directive, one project directive, and one `@branch` are + allowed (duplicates are errors). +- If a reply contains a `ctx:` line, Takopi **ignores new directives** and uses + the reply context. + +## Context footer (`ctx:`) + +When a run has project context, Takopi appends a footer line rendered as inline +code (backticked): + +- With branch: `` `ctx: @ ` `` +- Without branch: `` `ctx: ` `` + +The `ctx:` line is parsed from replies and takes precedence over new directives. + +## Worktree resolution + +When `@branch` is present: + +``` +worktrees_root = / +worktree_path = worktrees_root / +``` + +Branch validation: + +- Must be non-empty +- Must not start with `/` +- Must not contain `..` path segments +- May include `/` (nested directories) +- The resolved worktree path must stay within `worktrees_root` + +Worktree creation rules: + +1) If `worktree_path` exists: + - It must be a git worktree or Takopi errors. +2) If it does not exist: + - If local branch exists: `git worktree add ` + - Else if remote `origin/` exists: + `git worktree add -b origin/` + - Else: + `git worktree add -b ` + +Base branch selection: + +1) `projects..worktree_base` (if set) +2) `origin/HEAD` (if present) +3) current checked out branch +4) `master` if it exists +5) `main` if it exists +6) otherwise error + +When `@branch` is omitted: + +- Takopi runs in `` (the main checkout). + +## Examples + +Start a new thread in a worktree: + +``` +/z80 @feat/streaming fix flaky test +``` + +Reply to a progress message to continue in the same context: + +``` +ctx: z80 @ feat/streaming +``` diff --git a/readme.md b/readme.md index 05f4ecf..b54ecc6 100644 --- a/readme.md +++ b/readme.md @@ -76,6 +76,28 @@ provider = "openai" extra_args = ["--no-color"] ``` +## projects (optional) + +register the current repo as a project alias: + +```sh +takopi init z80 +``` + +`takopi init` writes the repo root to `[projects.].path`. if you run it inside a git worktree, it resolves the main checkout and records that path instead of the worktree. + +example: + +```toml +default_project = "z80" + +[projects.z80] +path = "~/dev/z80" +worktrees_dir = ".worktrees" +default_engine = "codex" +worktree_base = "master" +``` + ## usage start takopi in the repo you want to work on: diff --git a/src/takopi/cli.py b/src/takopi/cli.py index cdff1f7..127d349 100644 --- a/src/takopi/cli.py +++ b/src/takopi/cli.py @@ -11,10 +11,17 @@ import typer from . import __version__ from .backends import EngineBackend -from .config import ConfigError +from .config import ( + ConfigError, + load_or_init_config, + parse_projects_config, + write_config, +) from .engines import get_backend, get_engine_config, list_backends from .lockfile import LockError, LockHandle, acquire_lock, token_fingerprint from .logging import get_logger, setup_logging +from .router import AutoRouter, RunnerEntry +from .runner_bridge import ExecBridgeConfig from .telegram.bridge import ( TelegramBridgeConfig, TelegramPresenter, @@ -24,8 +31,7 @@ from .telegram.bridge import ( from .telegram.client import TelegramClient from .telegram.config import load_telegram_config from .telegram.onboarding import SetupResult, check_setup, interactive_setup -from .router import AutoRouter, RunnerEntry -from .runner_bridge import ExecBridgeConfig +from .utils.git import resolve_default_base, resolve_main_worktree_root logger = get_logger(__name__) @@ -195,6 +201,12 @@ def _parse_bridge_config( startup_pwd = os.getcwd() backends = list_backends() + projects = parse_projects_config( + config, + config_path=config_path, + engine_ids=[backend.id for backend in backends], + reserved=("cancel",), + ) default_engine = _resolve_default_engine( override=default_engine_override, config=config, @@ -212,10 +224,16 @@ def _parse_bridge_config( engine_list = ", ".join(available_engines) if available_engines else "none" if missing_engines: engine_list = f"{engine_list} (not installed: {', '.join(missing_engines)})" + project_aliases = sorted( + {project.alias for project in projects.projects.values()}, + key=str.lower, + ) + project_list = ", ".join(project_aliases) if project_aliases else "none" startup_msg = ( f"\N{OCTOPUS} **takopi is ready**\n\n" f"default: `{router.default_engine}` \n" f"agents: `{engine_list}` \n" + f"projects: `{project_list}` \n" f"working in: `{startup_pwd}`" ) @@ -234,6 +252,7 @@ def _parse_bridge_config( chat_id=chat_id, startup_msg=startup_msg, exec_cfg=exec_cfg, + projects=projects, ) @@ -320,6 +339,107 @@ def _run_auto_router( lock_handle.release() +def _prompt_alias(value: str | None, *, default_alias: str | None = None) -> str: + if value is not None: + alias = value + elif default_alias: + alias = typer.prompt("project alias", default=default_alias) + else: + alias = typer.prompt("project alias") + alias = alias.strip() + if not alias: + typer.echo("error: project alias cannot be empty", err=True) + raise typer.Exit(code=1) + return alias + + +def _default_alias_from_path(path: Path) -> str | None: + name = path.name + if not name: + return None + if name.endswith(".git"): + name = name[: -len(".git")] + return name or None + + +def _ensure_projects_table(config: dict, config_path: Path) -> dict: + projects = config.get("projects") + if projects is None: + projects = {} + config["projects"] = projects + if not isinstance(projects, dict): + raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.") + return projects + + +def init( + alias: str | None = typer.Argument( + None, help="Project alias (used as /alias in messages)." + ), + default: bool = typer.Option( + False, + "--default", + help="Set this project as the default_project.", + ), +) -> None: + """Register the current repo as a Takopi project.""" + config, config_path = load_or_init_config() + + cwd = Path.cwd() + project_path = resolve_main_worktree_root(cwd) or cwd + default_alias = _default_alias_from_path(project_path) + alias = _prompt_alias(alias, default_alias=default_alias) + + engine_ids = [backend.id for backend in list_backends()] + projects_cfg = parse_projects_config( + config, + config_path=config_path, + engine_ids=engine_ids, + reserved=("cancel",), + ) + + alias_key = alias.lower() + if alias_key in {engine.lower() for engine in engine_ids}: + raise ConfigError( + f"Invalid project alias {alias!r}; aliases must not match engine ids." + ) + if alias_key == "cancel": + raise ConfigError( + f"Invalid project alias {alias!r}; aliases must not match reserved commands." + ) + + existing = projects_cfg.projects.get(alias_key) + if existing is not None: + overwrite = typer.confirm( + f"project {existing.alias!r} already exists, overwrite?", + default=False, + ) + if not overwrite: + raise typer.Exit(code=1) + + projects = _ensure_projects_table(config, config_path) + if existing is not None and existing.alias in projects: + projects.pop(existing.alias, None) + + default_engine = _default_engine_for_setup(None) + worktree_base = resolve_default_base(project_path) + + entry: dict[str, object] = { + "path": str(project_path), + "worktrees_dir": ".worktrees", + "default_engine": default_engine, + } + if worktree_base: + entry["worktree_base"] = worktree_base + + projects[alias] = entry + if default: + config["default_project"] = alias + + write_config(config, config_path) + typer.echo(f"saved project {alias!r} to {_config_path_display(config_path)}") + + app = typer.Typer( add_completion=False, invoke_without_command=True, @@ -327,6 +447,9 @@ app = typer.Typer( ) +app.command(name="init")(init) + + @app.callback() def app_main( ctx: typer.Context, diff --git a/src/takopi/config.py b/src/takopi/config.py index dc00535..5c09b83 100644 --- a/src/takopi/config.py +++ b/src/takopi/config.py @@ -1,5 +1,260 @@ from __future__ import annotations +import tomllib +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Iterable + +HOME_CONFIG_PATH = Path.home() / ".takopi" / "takopi.toml" + class ConfigError(RuntimeError): pass + + +def _read_config(cfg_path: Path) -> dict: + try: + raw = cfg_path.read_text(encoding="utf-8") + except FileNotFoundError: + raise ConfigError(f"Missing config file {cfg_path}.") from None + except OSError as e: + raise ConfigError(f"Failed to read config file {cfg_path}: {e}") from e + try: + return tomllib.loads(raw) + except tomllib.TOMLDecodeError as e: + raise ConfigError(f"Malformed TOML in {cfg_path}: {e}") from None + + +def load_or_init_config(path: str | Path | None = None) -> tuple[dict, Path]: + cfg_path = Path(path).expanduser() if path else HOME_CONFIG_PATH + if cfg_path.exists() and not cfg_path.is_file(): + raise ConfigError(f"Config path {cfg_path} exists but is not a file.") from None + if not cfg_path.exists(): + return {}, cfg_path + return _read_config(cfg_path), cfg_path + + +@dataclass(frozen=True, slots=True) +class ProjectConfig: + alias: str + path: Path + worktrees_dir: Path + default_engine: str | None = None + worktree_base: str | None = None + + @property + def worktrees_root(self) -> Path: + if self.worktrees_dir.is_absolute(): + return self.worktrees_dir + return self.path / self.worktrees_dir + + +@dataclass(frozen=True, slots=True) +class ProjectsConfig: + projects: dict[str, ProjectConfig] + default_project: str | None = None + + def resolve(self, alias: str | None) -> ProjectConfig | None: + if alias is None: + if self.default_project is None: + return None + return self.projects.get(self.default_project) + return self.projects.get(alias.lower()) + + +def empty_projects_config() -> ProjectsConfig: + return ProjectsConfig(projects={}, default_project=None) + + +def _normalize_engine_id( + value: str, + *, + engine_ids: Iterable[str], + config_path: Path, + label: str, +) -> str: + engine_map = {engine.lower(): engine for engine in engine_ids} + cleaned = value.strip() + if not cleaned: + raise ConfigError(f"Invalid `{label}` in {config_path}; expected a string.") + engine = engine_map.get(cleaned.lower()) + if engine is None: + available = ", ".join(sorted(engine_map.values())) + raise ConfigError( + f"Unknown `{label}` {cleaned!r} in {config_path}. Available: {available}." + ) + return engine + + +def _normalize_project_path(value: str, *, config_path: Path) -> Path: + path = Path(value).expanduser() + if not path.is_absolute(): + path = config_path.parent / path + return path + + +def parse_projects_config( + config: dict[str, Any], + *, + config_path: Path, + engine_ids: Iterable[str], + reserved: Iterable[str] = ("cancel",), +) -> ProjectsConfig: + default_project_raw = config.get("default_project") + default_project = None + if default_project_raw is not None: + if not isinstance(default_project_raw, str) or not default_project_raw.strip(): + raise ConfigError( + f"Invalid `default_project` in {config_path}; expected a non-empty string." + ) + default_project = default_project_raw.strip() + + projects_raw = config.get("projects") or {} + if not isinstance(projects_raw, dict): + raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.") + + reserved_lower = {value.lower() for value in reserved} + engine_lower = {value.lower() for value in engine_ids} + projects: dict[str, ProjectConfig] = {} + + for raw_alias, raw_entry in projects_raw.items(): + if not isinstance(raw_alias, str) or not raw_alias.strip(): + raise ConfigError( + f"Invalid project alias in {config_path}; expected a non-empty string." + ) + alias = raw_alias.strip() + alias_key = alias.lower() + if alias_key in engine_lower or alias_key in reserved_lower: + raise ConfigError( + f"Invalid project alias {alias!r} in {config_path}; " + "aliases must not match engine ids or reserved commands." + ) + if alias_key in projects: + raise ConfigError(f"Duplicate project alias {alias!r} in {config_path}.") + if not isinstance(raw_entry, dict): + raise ConfigError( + f"Invalid project entry for {alias!r} in {config_path}; expected a table." + ) + + path_value = raw_entry.get("path") + if not isinstance(path_value, str) or not path_value.strip(): + raise ConfigError(f"Missing `path` for project {alias!r} in {config_path}.") + path = _normalize_project_path(path_value.strip(), config_path=config_path) + + worktrees_dir_raw = raw_entry.get("worktrees_dir", ".worktrees") + if not isinstance(worktrees_dir_raw, str) or not worktrees_dir_raw.strip(): + raise ConfigError( + f"Invalid `worktrees_dir` for project {alias!r} in {config_path}." + ) + worktrees_dir = Path(worktrees_dir_raw.strip()) + + default_engine_raw = raw_entry.get("default_engine") + default_engine = None + if default_engine_raw is not None: + if not isinstance(default_engine_raw, str): + raise ConfigError( + f"Invalid `projects.{alias}.default_engine` in {config_path}; " + "expected a string." + ) + default_engine = _normalize_engine_id( + default_engine_raw, + engine_ids=engine_ids, + config_path=config_path, + label=f"projects.{alias}.default_engine", + ) + + worktree_base_raw = raw_entry.get("worktree_base") + worktree_base = None + if worktree_base_raw is not None: + if not isinstance(worktree_base_raw, str) or not worktree_base_raw.strip(): + raise ConfigError( + f"Invalid `projects.{alias}.worktree_base` in {config_path}; " + "expected a string." + ) + worktree_base = worktree_base_raw.strip() + + projects[alias_key] = ProjectConfig( + alias=alias, + path=path, + worktrees_dir=worktrees_dir, + default_engine=default_engine, + worktree_base=worktree_base, + ) + + if default_project is not None: + default_key = default_project.lower() + if default_key not in projects: + raise ConfigError( + f"Invalid `default_project` {default_project!r} in {config_path}; " + "no matching project alias found." + ) + default_project = default_key + + return ProjectsConfig(projects=projects, default_project=default_project) + + +def _toml_escape(value: str) -> str: + return value.replace("\\", "\\\\").replace('"', '\\"') + + +def _format_toml_value(value: Any) -> str: + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, int): + return str(value) + if isinstance(value, float): + return repr(value) + if isinstance(value, str): + return f'"{_toml_escape(value)}"' + if isinstance(value, (list, tuple)): + inner = ", ".join(_format_toml_value(item) for item in value) + return f"[{inner}]" + raise ConfigError(f"Unsupported config value {value!r}") + + +def _table_has_scalars(table: dict[str, Any]) -> bool: + return any(not isinstance(value, dict) for value in table.values()) + + +def dump_toml(config: dict[str, Any]) -> str: + lines: list[str] = [] + + def write_kv(key: str, value: Any) -> None: + lines.append(f"{key} = {_format_toml_value(value)}") + + def write_table(name: str, table: dict[str, Any]) -> None: + if lines and lines[-1] != "": + lines.append("") + lines.append(f"[{name}]") + for key, value in table.items(): + if isinstance(value, dict): + continue + write_kv(key, value) + for key, value in table.items(): + if isinstance(value, dict): + write_table(f"{name}.{key}", value) + + for key, value in config.items(): + if isinstance(value, dict): + continue + write_kv(key, value) + + for key, value in config.items(): + if not isinstance(value, dict): + continue + if _table_has_scalars(value): + write_table(key, value) + continue + for subkey, subvalue in value.items(): + if isinstance(subvalue, dict): + write_table(f"{key}.{subkey}", subvalue) + else: + write_table(key, value) + break + + return "\n".join(lines) + "\n" + + +def write_config(config: dict[str, Any], path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(dump_toml(config), encoding="utf-8") diff --git a/src/takopi/context.py b/src/takopi/context.py new file mode 100644 index 0000000..a4efe07 --- /dev/null +++ b/src/takopi/context.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class RunContext: + project: str | None = None + branch: str | None = None diff --git a/src/takopi/markdown.py b/src/takopi/markdown.py index a184e0d..dea09f9 100644 --- a/src/takopi/markdown.py +++ b/src/takopi/markdown.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import textwrap from dataclasses import dataclass from pathlib import Path @@ -94,12 +95,16 @@ def format_file_change_title(action: Action, *, command_width: int | None) -> st if isinstance(changes, list) and changes: rendered: list[str] = [] for raw in changes: - if not isinstance(raw, dict): - continue - path = raw.get("path") + path: str | None + kind: str | None + if isinstance(raw, dict): + path = raw.get("path") + kind = raw.get("kind") + else: + path = getattr(raw, "path", None) + kind = getattr(raw, "kind", None) if not isinstance(path, str) or not path: continue - kind = raw.get("kind") verb = kind if isinstance(kind, str) and kind else "update" rendered.append(f"{verb} {format_changed_file_path(path)}") @@ -110,7 +115,15 @@ def format_file_change_title(action: Action, *, command_width: int | None) -> st inline = shorten(", ".join(rendered), command_width) return f"files: {inline}" - return f"files: {shorten(title, command_width)}" + fallback = title + relativized = relativize_path(fallback) + was_relativized = relativized != fallback + if was_relativized: + fallback = relativized + if fallback and not (fallback.startswith("`") and fallback.endswith("`")): + if was_relativized or os.sep in fallback or "/" in fallback: + fallback = f"`{fallback}`" + return f"files: {shorten(fallback, command_width)}" def format_action_title(action: Action, *, command_width: int | None) -> str: @@ -197,7 +210,9 @@ class MarkdownFormatter: engine=state.engine, ) body = self._assemble_body(self._format_actions(state)) - return MarkdownParts(header=header, body=body, footer=state.resume_line) + return MarkdownParts( + header=header, body=body, footer=self._format_footer(state) + ) def render_final_parts( self, @@ -216,7 +231,19 @@ class MarkdownFormatter: ) answer = (answer or "").strip() body = answer if answer else None - return MarkdownParts(header=header, body=body, footer=state.resume_line) + return MarkdownParts( + header=header, body=body, footer=self._format_footer(state) + ) + + def _format_footer(self, state: ProgressState) -> str | None: + lines: list[str] = [] + if state.context_line: + lines.append(state.context_line) + if state.resume_line: + lines.append(state.resume_line) + if not lines: + return None + return HARD_BREAK.join(lines) def _format_actions(self, state: ProgressState) -> list[str]: actions = list(state.actions) diff --git a/src/takopi/progress.py b/src/takopi/progress.py index d286d8c..763f576 100644 --- a/src/takopi/progress.py +++ b/src/takopi/progress.py @@ -24,6 +24,7 @@ class ProgressState: actions: tuple[ActionState, ...] resume: ResumeToken | None resume_line: str | None + context_line: str | None class ProgressTracker: @@ -80,6 +81,7 @@ class ProgressTracker: self, *, resume_formatter: Callable[[ResumeToken], str] | None = None, + context_line: str | None = None, ) -> ProgressState: resume_line: str | None = None if self.resume is not None and resume_formatter is not None: @@ -93,4 +95,5 @@ class ProgressTracker: actions=actions, resume=self.resume, resume_line=resume_line, + context_line=context_line, ) diff --git a/src/takopi/runner.py b/src/takopi/runner.py index 32642a8..265b067 100644 --- a/src/takopi/runner.py +++ b/src/takopi/runner.py @@ -22,6 +22,7 @@ from .model import ( StartedEvent, TakopiEvent, ) +from .utils.paths import get_run_base_dir from .utils.streams import drain_stderr, iter_bytes_lines from .utils.subprocess import manage_subprocess @@ -358,12 +359,15 @@ class JsonlSubprocessRunner(BaseRunner): prompt_len=len(prompt), ) + cwd = get_run_base_dir() + async with manage_subprocess( cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, + cwd=cwd, ) as proc: if proc.stdout is None or proc.stderr is None: raise RuntimeError(self.pipes_error_message()) diff --git a/src/takopi/runner_bridge.py b/src/takopi/runner_bridge.py index 9ed6ae5..e065eec 100644 --- a/src/takopi/runner_bridge.py +++ b/src/takopi/runner_bridge.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field import anyio +from .context import RunContext from .logging import bind_run_context, get_logger from .model import CompletedEvent, ResumeToken, StartedEvent, TakopiEvent from .presenter import Presenter @@ -93,6 +94,7 @@ class RunningTask: resume_ready: anyio.Event = field(default_factory=anyio.Event) cancel_requested: anyio.Event = field(default_factory=anyio.Event) done: anyio.Event = field(default_factory=anyio.Event) + context: RunContext | None = None RunningTasks = dict[MessageRef, RunningTask] @@ -152,6 +154,7 @@ class ProgressEdits: last_rendered: RenderedMessage | None, resume_formatter: Callable[[ResumeToken], str] | None = None, label: str = "working", + context_line: str | None = None, ) -> None: self.transport = transport self.presenter = presenter @@ -163,6 +166,7 @@ class ProgressEdits: self.last_rendered = last_rendered self.resume_formatter = resume_formatter self.label = label + self.context_line = context_line self.event_seq = 0 self.rendered_seq = 0 self.signal_send, self.signal_recv = anyio.create_memory_object_stream(1) @@ -179,7 +183,10 @@ class ProgressEdits: seq_at_render = self.event_seq now = self.clock() - state = self.tracker.snapshot(resume_formatter=self.resume_formatter) + state = self.tracker.snapshot( + resume_formatter=self.resume_formatter, + context_line=self.context_line, + ) rendered = self.presenter.render_progress( state, elapsed_s=now - self.started_at, label=self.label ) @@ -228,11 +235,15 @@ async def send_initial_progress( label: str, tracker: ProgressTracker, resume_formatter: Callable[[ResumeToken], str] | None = None, + context_line: str | None = None, ) -> ProgressMessageState: progress_ref: MessageRef | None = None last_rendered: RenderedMessage | None = None - state = tracker.snapshot(resume_formatter=resume_formatter) + state = tracker.snapshot( + resume_formatter=resume_formatter, + context_line=context_line, + ) initial_rendered = cfg.presenter.render_progress( state, elapsed_s=0.0, @@ -364,6 +375,8 @@ async def handle_message( runner: Runner, incoming: IncomingMessage, resume_token: ResumeToken | None, + context: RunContext | None = None, + context_line: str | None = None, strip_resume_line: Callable[[str], bool] | None = None, running_tasks: RunningTasks | None = None, on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] @@ -395,6 +408,7 @@ async def handle_message( label="starting", tracker=progress_tracker, resume_formatter=runner.format_resume, + context_line=context_line, ) progress_ref = progress_state.ref @@ -408,11 +422,12 @@ async def handle_message( clock=clock, last_rendered=progress_state.last_rendered, resume_formatter=runner.format_resume, + context_line=context_line, ) running_task: RunningTask | None = None if running_tasks is not None and progress_ref is not None: - running_task = RunningTask() + running_task = RunningTask(context=context) running_tasks[progress_ref] = running_task cancel_exc_type = anyio.get_cancelled_exc_class() @@ -464,7 +479,10 @@ async def handle_message( if error is not None: sync_resume_token(progress_tracker, outcome.resume) err_body = _format_error(error) - state = progress_tracker.snapshot(resume_formatter=runner.format_resume) + state = progress_tracker.snapshot( + resume_formatter=runner.format_resume, + context_line=context_line, + ) final_rendered = cfg.presenter.render_final( state, elapsed_s=elapsed, @@ -496,7 +514,10 @@ async def handle_message( resume=resume.value if resume else None, elapsed_s=elapsed, ) - state = progress_tracker.snapshot(resume_formatter=runner.format_resume) + state = progress_tracker.snapshot( + resume_formatter=runner.format_resume, + context_line=context_line, + ) final_rendered = cfg.presenter.render_progress( state, elapsed_s=elapsed, @@ -546,7 +567,10 @@ async def handle_message( resume=resume_value, ) sync_resume_token(progress_tracker, completed.resume or outcome.resume) - state = progress_tracker.snapshot(resume_formatter=runner.format_resume) + state = progress_tracker.snapshot( + resume_formatter=runner.format_resume, + context_line=context_line, + ) final_rendered = cfg.presenter.render_final( state, elapsed_s=elapsed, diff --git a/src/takopi/runners/codex.py b/src/takopi/runners/codex.py index bc6401f..cc04d15 100644 --- a/src/takopi/runners/codex.py +++ b/src/takopi/runners/codex.py @@ -76,6 +76,26 @@ def _summarize_tool_result(result: Any) -> dict[str, Any] | None: return None +def _normalize_change_list(changes: list[Any]) -> list[dict[str, str]]: + normalized: list[dict[str, str]] = [] + for change in changes: + path: str | None = None + kind: str | None = None + if isinstance(change, codex_schema.FileUpdateChange): + path = change.path + kind = change.kind + elif isinstance(change, dict): + path = change.get("path") + kind = change.get("kind") + if not isinstance(path, str) or not path: + continue + entry = {"path": path} + if isinstance(kind, str) and kind: + entry["kind"] = kind + normalized.append(entry) + return normalized + + def _format_change_summary(changes: list[Any]) -> str: paths: list[str] = [] for change in changes: @@ -260,8 +280,9 @@ def _translate_item_event( if phase != "completed": return [] title = _format_change_summary(changes) + normalized_changes = _normalize_change_list(changes) detail = { - "changes": changes, + "changes": normalized_changes, "status": status, "error": None, } diff --git a/src/takopi/scheduler.py b/src/takopi/scheduler.py index 9becdf0..ead25d7 100644 --- a/src/takopi/scheduler.py +++ b/src/takopi/scheduler.py @@ -6,6 +6,7 @@ from typing import Any, Awaitable, Callable, Protocol import anyio +from .context import RunContext from .model import ResumeToken @@ -15,6 +16,7 @@ class ThreadJob: user_msg_id: int text: str resume_token: ResumeToken + context: RunContext | None = None RunJob = Callable[[ThreadJob], Awaitable[None]] @@ -66,6 +68,7 @@ class ThreadScheduler: user_msg_id: int, text: str, resume_token: ResumeToken, + context: RunContext | None = None, ) -> None: await self.enqueue( ThreadJob( @@ -73,6 +76,7 @@ class ThreadScheduler: user_msg_id=user_msg_id, text=text, resume_token=resume_token, + context=context, ) ) diff --git a/src/takopi/telegram/__init__.py b/src/takopi/telegram/__init__.py index f620aaa..1b583fa 100644 --- a/src/takopi/telegram/__init__.py +++ b/src/takopi/telegram/__init__.py @@ -1 +1,5 @@ """Telegram-specific clients and adapters.""" + +from .client import parse_incoming_update, poll_incoming + +__all__ = ["parse_incoming_update", "poll_incoming"] diff --git a/src/takopi/telegram/bridge.py b/src/takopi/telegram/bridge.py index 27d5fcf..767f65d 100644 --- a/src/takopi/telegram/bridge.py +++ b/src/takopi/telegram/bridge.py @@ -1,14 +1,16 @@ from __future__ import annotations from collections.abc import AsyncIterator, Awaitable, Callable -from dataclasses import dataclass -from typing import Any +from dataclasses import dataclass, field + import anyio +from ..config import ProjectsConfig, empty_projects_config +from ..context import RunContext from ..runner_bridge import ( ExecBridgeConfig, - IncomingMessage, + IncomingMessage as RunnerIncomingMessage, RunningTask, RunningTasks, handle_message, @@ -20,8 +22,16 @@ from ..progress import ProgressState, ProgressTracker from ..router import AutoRouter, RunnerUnavailableError from ..runner import Runner from ..scheduler import ThreadJob, ThreadScheduler -from ..transport import MessageRef, RenderedMessage, SendOptions, Transport -from .client import BotClient +from ..transport import ( + IncomingMessage as TransportIncomingMessage, + MessageRef, + RenderedMessage, + SendOptions, + Transport, +) +from ..utils.paths import reset_run_base_dir, set_run_base_dir +from ..worktrees import WorktreeError, resolve_run_cwd +from .client import BotClient, poll_incoming from .render import prepare_telegram logger = get_logger(__name__) @@ -70,6 +80,206 @@ def _strip_engine_command( return "\n".join(lines).strip(), engine +@dataclass(frozen=True, slots=True) +class ParsedDirectives: + prompt: str + engine: EngineId | None + project: str | None + branch: str | None + + +@dataclass(frozen=True, slots=True) +class ResolvedMessage: + prompt: str + resume_token: ResumeToken | None + engine_override: EngineId | None + context: RunContext | None + + +class DirectiveError(RuntimeError): + pass + + +def _parse_directives( + text: str, + *, + engine_ids: tuple[EngineId, ...], + projects: ProjectsConfig, +) -> ParsedDirectives: + if not text: + return ParsedDirectives(prompt="", engine=None, project=None, branch=None) + + lines = text.splitlines() + idx = next((i for i, line in enumerate(lines) if line.strip()), None) + if idx is None: + return ParsedDirectives(prompt=text, engine=None, project=None, branch=None) + + line = lines[idx].lstrip() + tokens = line.split() + if not tokens: + return ParsedDirectives(prompt=text, engine=None, project=None, branch=None) + + engine_map = {engine.lower(): engine for engine in engine_ids} + project_map = {alias.lower(): alias for alias in projects.projects} + + engine: EngineId | None = None + project: str | None = None + branch: str | None = None + consumed = 0 + + for token in tokens: + if token.startswith("/"): + name = token[1:] + if "@" in name: + name = name.split("@", 1)[0] + if not name: + break + key = name.lower() + engine_candidate = engine_map.get(key) + project_candidate = project_map.get(key) + if engine_candidate is not None: + if engine is not None: + raise DirectiveError("multiple engine directives") + engine = engine_candidate + consumed += 1 + continue + if project_candidate is not None: + if project is not None: + raise DirectiveError("multiple project directives") + project = project_candidate + consumed += 1 + continue + break + if token.startswith("@"): + value = token[1:] + if not value: + break + if branch is not None: + raise DirectiveError("multiple @branch directives") + branch = value + consumed += 1 + continue + break + + if consumed == 0: + return ParsedDirectives(prompt=text, engine=None, project=None, branch=None) + + if consumed < len(tokens): + remainder = " ".join(tokens[consumed:]) + lines[idx] = remainder + else: + lines.pop(idx) + + prompt = "\n".join(lines).strip() + return ParsedDirectives( + prompt=prompt, engine=engine, project=project, branch=branch + ) + + +def _parse_ctx_line(text: str | None, *, projects: ProjectsConfig) -> RunContext | None: + if not text: + return None + ctx: RunContext | None = None + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("`") and stripped.endswith("`") and len(stripped) > 1: + stripped = stripped[1:-1].strip() + elif stripped.startswith("`"): + stripped = stripped[1:].strip() + elif stripped.endswith("`"): + stripped = stripped[:-1].strip() + if not stripped.lower().startswith("ctx:"): + continue + content = stripped.split(":", 1)[1].strip() + if not content: + continue + tokens = content.split() + if not tokens: + continue + project = tokens[0] + branch = None + if len(tokens) >= 2: + if tokens[1] == "@" and len(tokens) >= 3: + branch = tokens[2] + elif tokens[1].startswith("@"): + branch = tokens[1][1:] + project_key = project.lower() + if project_key not in projects.projects: + raise DirectiveError(f"unknown project {project!r} in ctx line") + ctx = RunContext(project=project_key, branch=branch) + return ctx + + +def _format_context_line( + context: RunContext | None, *, projects: ProjectsConfig +) -> str | None: + if context is None or context.project is None: + return None + project_cfg = projects.projects.get(context.project) + alias = project_cfg.alias if project_cfg is not None else context.project + if context.branch: + return f"`ctx: {alias} @ {context.branch}`" + return f"`ctx: {alias}`" + + +def _resolve_message( + *, + text: str, + reply_text: str | None, + router: AutoRouter, + projects: ProjectsConfig, +) -> ResolvedMessage: + directives = _parse_directives( + text, + engine_ids=router.engine_ids, + projects=projects, + ) + reply_ctx = _parse_ctx_line(reply_text, projects=projects) + resume_token = router.resolve_resume(directives.prompt, reply_text) + + if resume_token is not None: + return ResolvedMessage( + prompt=directives.prompt, + resume_token=resume_token, + engine_override=None, + context=reply_ctx, + ) + + if reply_ctx is not None: + engine_override = None + if reply_ctx.project is not None: + project = projects.projects.get(reply_ctx.project) + if project is not None and project.default_engine is not None: + engine_override = project.default_engine + return ResolvedMessage( + prompt=directives.prompt, + resume_token=None, + engine_override=engine_override, + context=reply_ctx, + ) + + project_key = directives.project + if project_key is None and projects.default_project is not None: + project_key = projects.default_project + + context = None + if project_key is not None or directives.branch is not None: + context = RunContext(project=project_key, branch=directives.branch) + + engine_override = directives.engine + if engine_override is None and project_key is not None: + project = projects.projects.get(project_key) + if project is not None and project.default_engine is not None: + engine_override = project.default_engine + + return ResolvedMessage( + prompt=directives.prompt, + resume_token=None, + engine_override=engine_override, + context=context, + ) + + def _build_bot_commands(router: AutoRouter) -> list[dict[str, str]]: commands: list[dict[str, str]] = [] seen: set[str] = set() @@ -232,6 +442,7 @@ class TelegramBridgeConfig: chat_id: int startup_msg: str exec_cfg: ExecBridgeConfig + projects: ProjectsConfig = field(default_factory=empty_projects_config) async def _send_plain( @@ -281,41 +492,35 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int | drained += len(updates) -async def poll_updates(cfg: TelegramBridgeConfig) -> AsyncIterator[dict[str, Any]]: +async def poll_updates( + cfg: TelegramBridgeConfig, +) -> AsyncIterator[TransportIncomingMessage]: offset: int | None = None offset = await _drain_backlog(cfg, offset) await _send_startup(cfg) - while True: - updates = await cfg.bot.get_updates( - offset=offset, timeout_s=50, allowed_updates=["message"] - ) - if updates is None: - logger.info("loop.get_updates.failed") - await anyio.sleep(2) - continue - logger.debug("loop.updates", updates=updates) - - for upd in updates: - offset = upd["update_id"] + 1 - msg = upd["message"] - if "text" not in msg: - continue - if msg["chat"]["id"] != cfg.chat_id: - continue - yield msg + async for msg in poll_incoming(cfg.bot, chat_id=cfg.chat_id, offset=offset): + yield msg async def _handle_cancel( cfg: TelegramBridgeConfig, - msg: dict[str, Any], + msg: TransportIncomingMessage, running_tasks: RunningTasks, ) -> None: - chat_id = msg["chat"]["id"] - user_msg_id = msg["message_id"] - reply = msg.get("reply_to_message") + chat_id = msg.chat_id + user_msg_id = msg.message_id + reply_id = msg.reply_to_message_id - if not reply: + if reply_id is None: + if msg.reply_to_text: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=chat_id, + user_msg_id=user_msg_id, + text="nothing is currently running for that message.", + ) + return await _send_plain( cfg.exec_cfg.transport, chat_id=chat_id, @@ -324,17 +529,7 @@ async def _handle_cancel( ) return - progress_id = reply.get("message_id") - if progress_id is None: - await _send_plain( - cfg.exec_cfg.transport, - chat_id=chat_id, - user_msg_id=user_msg_id, - text="nothing is currently running for that message.", - ) - return - - progress_ref = MessageRef(channel_id=chat_id, message_id=progress_id) + progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id) running_task = running_tasks.get(progress_ref) if running_task is None: await _send_plain( @@ -348,7 +543,7 @@ async def _handle_cancel( logger.info( "cancel.requested", chat_id=chat_id, - progress_message_id=progress_id, + progress_message_id=reply_id, ) running_task.cancel_requested.set() @@ -378,7 +573,7 @@ async def _wait_for_resume(running_task: RunningTask) -> ResumeToken | None: async def _send_with_resume( cfg: TelegramBridgeConfig, - enqueue: Callable[[int, int, str, ResumeToken], Awaitable[None]], + enqueue: Callable[[int, int, str, ResumeToken, RunContext | None], Awaitable[None]], running_task: RunningTask, chat_id: int, user_msg_id: int, @@ -394,7 +589,7 @@ async def _send_with_resume( notify=False, ) return - await enqueue(chat_id, user_msg_id, text, resume) + await enqueue(chat_id, user_msg_id, text, resume, running_task.context) async def _send_runner_unavailable( @@ -426,7 +621,7 @@ async def _send_runner_unavailable( async def run_main_loop( cfg: TelegramBridgeConfig, poller: Callable[ - [TelegramBridgeConfig], AsyncIterator[dict[str, Any]] + [TelegramBridgeConfig], AsyncIterator[TransportIncomingMessage] ] = poll_updates, ) -> None: running_tasks: RunningTasks = {} @@ -440,6 +635,7 @@ async def run_main_loop( user_msg_id: int, text: str, resume_token: ResumeToken | None, + context: RunContext | None, reply_ref: MessageRef | None = None, on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None = None, @@ -471,27 +667,52 @@ async def run_main_loop( reason=reason, ) return - bind_run_context( - chat_id=chat_id, - user_msg_id=user_msg_id, - engine=entry.runner.engine, - resume=resume_token.value if resume_token else None, - ) - incoming = IncomingMessage( - channel_id=chat_id, - message_id=user_msg_id, - text=text, - reply_to=reply_ref, - ) - await handle_message( - cfg.exec_cfg, - runner=entry.runner, - incoming=incoming, - resume_token=resume_token, - strip_resume_line=cfg.router.is_resume_line, - running_tasks=running_tasks, - on_thread_known=on_thread_known, - ) + try: + cwd = resolve_run_cwd(context, projects=cfg.projects) + except WorktreeError as exc: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=chat_id, + user_msg_id=user_msg_id, + text=f"error:\n{exc}", + ) + return + run_base_token = set_run_base_dir(cwd) + try: + run_fields = { + "chat_id": chat_id, + "user_msg_id": user_msg_id, + "engine": entry.runner.engine, + "resume": resume_token.value if resume_token else None, + } + if context is not None: + run_fields["project"] = context.project + run_fields["branch"] = context.branch + if cwd is not None: + run_fields["cwd"] = str(cwd) + bind_run_context(**run_fields) + context_line = _format_context_line( + context, projects=cfg.projects + ) + incoming = RunnerIncomingMessage( + channel_id=chat_id, + message_id=user_msg_id, + text=text, + reply_to=reply_ref, + ) + await handle_message( + cfg.exec_cfg, + runner=entry.runner, + incoming=incoming, + resume_token=resume_token, + context=context, + context_line=context_line, + strip_resume_line=cfg.router.is_resume_line, + running_tasks=running_tasks, + on_thread_known=on_thread_known, + ) + finally: + reset_run_base_dir(run_base_token) except Exception as exc: logger.exception( "handle.worker_failed", @@ -507,33 +728,48 @@ async def run_main_loop( job.user_msg_id, job.text, job.resume_token, + job.context, None, ) scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job) async for msg in poller(cfg): - text = msg["text"] - user_msg_id = msg["message_id"] - chat_id = msg["chat"]["id"] - reply_ref = None - reply_msg = msg.get("reply_to_message") - if reply_msg: - reply_id = reply_msg.get("message_id") - if reply_id is not None: - reply_ref = MessageRef(channel_id=chat_id, message_id=reply_id) + text = msg.text + user_msg_id = msg.message_id + chat_id = msg.chat_id + reply_id = msg.reply_to_message_id + reply_ref = ( + MessageRef(channel_id=chat_id, message_id=reply_id) + if reply_id is not None + else None + ) if _is_cancel_command(text): tg.start_soon(_handle_cancel, cfg, msg, running_tasks) continue - text, engine_override = _strip_engine_command( - text, engine_ids=cfg.router.engine_ids - ) + reply_text = msg.reply_to_text + try: + resolved = _resolve_message( + text=text, + reply_text=reply_text, + router=cfg.router, + projects=cfg.projects, + ) + except DirectiveError as exc: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=chat_id, + user_msg_id=user_msg_id, + text=f"error:\n{exc}", + ) + continue - r = msg.get("reply_to_message") or {} - resume_token = cfg.router.resolve_resume(text, r.get("text")) - reply_id = r.get("message_id") + text = resolved.prompt + resume_token = resolved.resume_token + engine_override = resolved.engine_override + context = resolved.context if resume_token is None and reply_id is not None: running_task = running_tasks.get( MessageRef(channel_id=chat_id, message_id=reply_id) @@ -557,13 +793,18 @@ async def run_main_loop( user_msg_id, text, None, + context, reply_ref, scheduler.note_thread_known, engine_override, ) else: await scheduler.enqueue_resume( - chat_id, user_msg_id, text, resume_token + chat_id, + user_msg_id, + text, + resume_token, + context, ) finally: await cfg.exec_cfg.transport.close() diff --git a/src/takopi/telegram/client.py b/src/takopi/telegram/client.py index 52e3c69..83f44a5 100644 --- a/src/takopi/telegram/client.py +++ b/src/takopi/telegram/client.py @@ -3,13 +3,22 @@ from __future__ import annotations import itertools import time from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, Hashable, Protocol, TYPE_CHECKING +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Hashable, + Protocol, + TYPE_CHECKING, +) import httpx import anyio from ..logging import get_logger +from ..transport import IncomingMessage logger = get_logger(__name__) @@ -34,6 +43,76 @@ def is_group_chat_id(chat_id: int) -> bool: return chat_id < 0 +def parse_incoming_update( + update: dict[str, Any], *, chat_id: int +) -> IncomingMessage | None: + msg = update.get("message") + if not isinstance(msg, dict): + return None + text = msg.get("text") + if not isinstance(text, str): + return None + chat = msg.get("chat") + if not isinstance(chat, dict): + return None + msg_chat_id = chat.get("id") + if not isinstance(msg_chat_id, int) or msg_chat_id != chat_id: + return None + message_id = msg.get("message_id") + if not isinstance(message_id, int): + return None + reply = msg.get("reply_to_message") + reply_to_message_id = None + reply_to_text = None + if isinstance(reply, dict): + reply_to_message_id = ( + reply.get("message_id") + if isinstance(reply.get("message_id"), int) + else None + ) + reply_to_text = ( + reply.get("text") if isinstance(reply.get("text"), str) else None + ) + sender = msg.get("from") + sender_id = ( + sender.get("id") + if isinstance(sender, dict) and isinstance(sender.get("id"), int) + else None + ) + return IncomingMessage( + transport="telegram", + chat_id=msg_chat_id, + message_id=message_id, + text=text, + reply_to_message_id=reply_to_message_id, + reply_to_text=reply_to_text, + sender_id=sender_id, + raw=msg, + ) + + +async def poll_incoming( + bot: BotClient, + *, + chat_id: int, + offset: int | None = None, +) -> AsyncIterator[IncomingMessage]: + while True: + updates = await bot.get_updates( + offset=offset, timeout_s=50, allowed_updates=["message"] + ) + if updates is None: + logger.info("loop.get_updates.failed") + await anyio.sleep(2) + continue + logger.debug("loop.updates", updates=updates) + for upd in updates: + offset = upd["update_id"] + 1 + msg = parse_incoming_update(upd, chat_id=chat_id) + if msg is not None: + yield msg + + class BotClient(Protocol): async def close(self) -> None: ... diff --git a/src/takopi/transport.py b/src/takopi/transport.py index b0a8789..fab11d8 100644 --- a/src/takopi/transport.py +++ b/src/takopi/transport.py @@ -7,6 +7,18 @@ ChannelId: TypeAlias = int | str MessageId: TypeAlias = int | str +@dataclass(frozen=True, slots=True) +class IncomingMessage: + transport: str + chat_id: int + message_id: int + text: str + reply_to_message_id: int | None + reply_to_text: str | None + sender_id: int | None + raw: dict[str, Any] | None = None + + @dataclass(frozen=True, slots=True) class MessageRef: channel_id: ChannelId diff --git a/src/takopi/utils/git.py b/src/takopi/utils/git.py new file mode 100644 index 0000000..75d1028 --- /dev/null +++ b/src/takopi/utils/git.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import subprocess +from collections.abc import Sequence +from pathlib import Path + + +def _run_git( + args: Sequence[str], *, cwd: Path +) -> subprocess.CompletedProcess[str] | None: + try: + return subprocess.run( + ["git", *args], + cwd=cwd, + check=False, + text=True, + capture_output=True, + ) + except FileNotFoundError: + return None + + +def git_run( + args: Sequence[str], *, cwd: Path +) -> subprocess.CompletedProcess[str] | None: + return _run_git(args, cwd=cwd) + + +def git_stdout(args: Sequence[str], *, cwd: Path) -> str | None: + result = _run_git(args, cwd=cwd) + if result is None or result.returncode != 0: + return None + output = result.stdout.strip() + return output or None + + +def git_ok(args: Sequence[str], *, cwd: Path) -> bool: + result = _run_git(args, cwd=cwd) + return result is not None and result.returncode == 0 + + +def git_is_worktree(path: Path) -> bool: + return git_stdout(["rev-parse", "--is-inside-work-tree"], cwd=path) == "true" + + +def resolve_default_base(root: Path) -> str | None: + origin_head = git_stdout( + ["symbolic-ref", "-q", "refs/remotes/origin/HEAD"], + cwd=root, + ) + if origin_head: + prefix = "refs/remotes/origin/" + if origin_head.startswith(prefix): + name = origin_head[len(prefix) :].strip() + if name: + return name + + current = git_stdout(["branch", "--show-current"], cwd=root) + if current: + return current + + if git_ok(["show-ref", "--verify", "--quiet", "refs/heads/master"], cwd=root): + return "master" + if git_ok(["show-ref", "--verify", "--quiet", "refs/heads/main"], cwd=root): + return "main" + return None + + +def resolve_main_worktree_root(cwd: Path) -> Path | None: + common_dir = git_stdout( + ["rev-parse", "--path-format=absolute", "--git-common-dir"], + cwd=cwd, + ) + if not common_dir: + return None + if git_stdout(["rev-parse", "--is-bare-repository"], cwd=cwd) == "true": + return cwd + common_path = Path(common_dir) + if not common_path.is_absolute(): + common_path = (cwd / common_path).resolve() + return common_path.parent diff --git a/src/takopi/utils/paths.py b/src/takopi/utils/paths.py index a4518dc..6f1312a 100644 --- a/src/takopi/utils/paths.py +++ b/src/takopi/utils/paths.py @@ -1,13 +1,31 @@ from __future__ import annotations import os +from contextvars import ContextVar, Token from pathlib import Path +_run_base_dir: ContextVar[Path | None] = ContextVar("takopi_run_base_dir", default=None) + + +def get_run_base_dir() -> Path | None: + return _run_base_dir.get() + + +def set_run_base_dir(base_dir: Path | None) -> Token[Path | None]: + return _run_base_dir.set(base_dir) + + +def reset_run_base_dir(token: Token[Path | None]) -> None: + _run_base_dir.reset(token) + + def relativize_path(value: str, *, base_dir: Path | None = None) -> str: if not value: return value - base = Path.cwd() if base_dir is None else base_dir + base = get_run_base_dir() if base_dir is None else base_dir + if base is None: + base = Path.cwd() base_str = str(base) if not base_str: return value @@ -22,6 +40,8 @@ def relativize_path(value: str, *, base_dir: Path | None = None) -> str: def relativize_command(value: str, *, base_dir: Path | None = None) -> str: - base = Path.cwd() if base_dir is None else base_dir + base = get_run_base_dir() if base_dir is None else base_dir + if base is None: + base = Path.cwd() base_with_sep = f"{base}{os.sep}" return value.replace(base_with_sep, "") diff --git a/src/takopi/worktrees.py b/src/takopi/worktrees.py new file mode 100644 index 0000000..96d6e2c --- /dev/null +++ b/src/takopi/worktrees.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from pathlib import Path + +from .config import ProjectConfig, ProjectsConfig +from .context import RunContext +from .utils.git import git_is_worktree, git_ok, git_run, resolve_default_base + + +class WorktreeError(RuntimeError): + pass + + +def resolve_run_cwd( + context: RunContext | None, + *, + projects: ProjectsConfig, +) -> Path | None: + if context is None or context.project is None: + return None + project = projects.projects.get(context.project) + if project is None: + raise WorktreeError(f"unknown project {context.project!r}") + if context.branch is None: + return project.path + return ensure_worktree(project, context.branch) + + +def ensure_worktree(project: ProjectConfig, branch: str) -> Path: + root = project.path + if not root.exists(): + raise WorktreeError(f"project path not found: {root}") + + branch = _sanitize_branch(branch) + worktrees_root = project.worktrees_root + worktree_path = worktrees_root / branch + _ensure_within_root(worktrees_root, worktree_path) + + if worktree_path.exists(): + if not git_is_worktree(worktree_path): + raise WorktreeError(f"{worktree_path} exists but is not a git worktree") + return worktree_path + + worktrees_root.mkdir(parents=True, exist_ok=True) + + if git_ok( + ["show-ref", "--verify", "--quiet", f"refs/heads/{branch}"], + cwd=root, + ): + _git_worktree_add(root, worktree_path, branch) + return worktree_path + + if git_ok( + ["show-ref", "--verify", "--quiet", f"refs/remotes/origin/{branch}"], + cwd=root, + ): + _git_worktree_add( + root, + worktree_path, + branch, + base_ref=f"origin/{branch}", + create_branch=True, + ) + return worktree_path + + base = project.worktree_base or resolve_default_base(root) + if not base: + raise WorktreeError("cannot determine base branch for new worktree") + + _git_worktree_add( + root, + worktree_path, + branch, + base_ref=base, + create_branch=True, + ) + return worktree_path + + +def _git_worktree_add( + root: Path, + worktree_path: Path, + branch: str, + *, + base_ref: str | None = None, + create_branch: bool = False, +) -> None: + if create_branch: + if not base_ref: + raise WorktreeError("missing base ref for worktree creation") + args = ["worktree", "add", "-b", branch, str(worktree_path), base_ref] + else: + args = ["worktree", "add", str(worktree_path), branch] + + result = git_run(args, cwd=root) + if result is None: + raise WorktreeError("git not available on PATH") + if result.returncode != 0: + message = result.stderr.strip() or result.stdout.strip() + raise WorktreeError(message or "git worktree add failed") + + +def _sanitize_branch(branch: str) -> str: + cleaned = branch.strip() + if not cleaned: + raise WorktreeError("branch name cannot be empty") + if cleaned.startswith("/"): + raise WorktreeError("branch name cannot start with '/'") + for part in Path(cleaned).parts: + if part == "..": + raise WorktreeError("branch name cannot contain '..'") + return cleaned + + +def _ensure_within_root(root: Path, path: Path) -> None: + root_resolved = root.resolve(strict=False) + path_resolved = path.resolve(strict=False) + if not path_resolved.is_relative_to(root_resolved): + raise WorktreeError("branch path escapes the worktrees directory") diff --git a/tests/test_codex_tool_result_summary.py b/tests/test_codex_tool_result_summary.py index 8fc3cb3..fccd9d1 100644 --- a/tests/test_codex_tool_result_summary.py +++ b/tests/test_codex_tool_result_summary.py @@ -104,3 +104,21 @@ def test_translate_command_execution_allows_null_exit_code() -> None: assert isinstance(out[0], ActionEvent) assert out[0].ok is True assert out[0].action.detail["exit_code"] is None + + +def test_translate_file_change_normalizes_changes() -> None: + evt = { + "type": "item.completed", + "item": { + "id": "item_6", + "type": "file_change", + "changes": [{"path": "README.md", "kind": "update"}], + "status": "completed", + }, + } + + out = _translate_event(evt) + assert len(out) == 1 + assert isinstance(out[0], ActionEvent) + changes = out[0].action.detail["changes"] + assert changes == [{"path": "README.md", "kind": "update"}] diff --git a/tests/test_exec_bridge.py b/tests/test_exec_bridge.py index 5636990..b36c476 100644 --- a/tests/test_exec_bridge.py +++ b/tests/test_exec_bridge.py @@ -334,6 +334,37 @@ async def test_bridge_flow_sends_progress_edits_and_final_resume() -> None: assert transport.send_calls[-1]["options"].replace == transport.send_calls[0]["ref"] +@pytest.mark.anyio +async def test_final_message_includes_ctx_line() -> None: + transport = _FakeTransport() + clock = _FakeClock() + session_id = "123e4567-e89b-12d3-a456-426614174000" + runner = ScriptRunner( + [Return(answer="done")], + engine=CODEX_ENGINE, + resume_value=session_id, + ) + cfg = ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ) + + await handle_message( + cfg, + runner=runner, + incoming=IncomingMessage(channel_id=123, message_id=42, text="do it"), + resume_token=None, + context_line="`ctx: takopi @ feat/api`", + clock=clock, + ) + + assert transport.send_calls + final_text = transport.send_calls[-1]["message"].text + assert "`ctx: takopi @ feat/api`" in final_text + assert "codex resume" in final_text.lower() + + @pytest.mark.anyio async def test_handle_message_cancelled_renders_cancelled_state() -> None: transport = _FakeTransport() diff --git a/tests/test_exec_render.py b/tests/test_exec_render.py index c1b42b1..a72cacd 100644 --- a/tests/test_exec_render.py +++ b/tests/test_exec_render.py @@ -16,6 +16,7 @@ from takopi.markdown import ( from takopi.model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent from takopi.progress import ProgressTracker from takopi.telegram.render import render_markdown +from takopi.utils.paths import reset_run_base_dir, set_run_base_dir from tests.factories import ( action_completed, action_started, @@ -119,6 +120,40 @@ def test_file_change_renders_relative_paths_inside_cwd() -> None: ) +def test_file_change_renders_change_objects(tmp_path: Path) -> None: + base = tmp_path / "repo" + base.mkdir() + abs_path = str(base / "changelog.md") + token = set_run_base_dir(base) + try: + out = render_event_cli( + action_completed( + "f-obj", + "file_change", + "ignored", + ok=True, + detail={"changes": [SimpleNamespace(path=abs_path, kind="update")]}, + ) + ) + finally: + reset_run_base_dir(token) + assert any("files: update `changelog.md`" in line for line in out) + + +def test_file_change_title_relativizes_absolute_title(tmp_path: Path) -> None: + base = tmp_path / "repo" + base.mkdir() + abs_path = str(base / "changelog.md") + token = set_run_base_dir(base) + try: + out = render_event_cli( + action_completed("f-abs", "file_change", abs_path, ok=True) + ) + finally: + reset_run_base_dir(token) + assert any("files: `changelog.md`" in line for line in out) + + def test_progress_renderer_renders_progress_and_final() -> None: tracker = ProgressTracker(engine="codex") for evt in SAMPLE_EVENTS: @@ -145,6 +180,23 @@ def test_progress_renderer_renders_progress_and_final() -> None: ) +def test_progress_renderer_footer_includes_ctx_before_resume() -> None: + tracker = ProgressTracker(engine="codex") + for evt in SAMPLE_EVENTS: + tracker.note_event(evt) + + state = tracker.snapshot( + resume_formatter=_format_resume, + context_line="`ctx: z80 @ feat/name`", + ) + formatter = MarkdownFormatter(max_actions=5) + parts = formatter.render_progress_parts(state, elapsed_s=0.0) + assert parts.footer == ( + "`ctx: z80 @ feat/name`" + f"{HARD_BREAK}`codex resume 0199a213-81c0-7800-8aa1-bbab2a035a53`" + ) + + def test_progress_renderer_clamps_actions_and_ignores_unknown() -> None: tracker = ProgressTracker(engine="codex") events = [ diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py new file mode 100644 index 0000000..06834bd --- /dev/null +++ b/tests/test_git_utils.py @@ -0,0 +1,56 @@ +from pathlib import Path + +from takopi.utils.git import resolve_default_base, resolve_main_worktree_root + + +def test_resolve_main_worktree_root_returns_none_when_no_git(monkeypatch) -> None: + monkeypatch.setattr("takopi.utils.git.git_stdout", lambda *args, **kwargs: None) + assert resolve_main_worktree_root(Path("/tmp")) is None + + +def test_resolve_main_worktree_root_prefers_common_dir_parent(monkeypatch) -> None: + base = Path("/repo") + + def _fake_stdout(args, **kwargs): + if args[:2] == ["rev-parse", "--path-format=absolute"]: + return str(base / ".git") + if args == ["rev-parse", "--is-bare-repository"]: + return "false" + return None + + monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout) + assert resolve_main_worktree_root(base / ".worktrees" / "feature") == base + + +def test_resolve_main_worktree_root_returns_cwd_for_bare_repo(monkeypatch) -> None: + cwd = Path("/bare-repo") + + def _fake_stdout(args, **kwargs): + if args[:2] == ["rev-parse", "--path-format=absolute"]: + return str(cwd / "repo.git") + if args == ["rev-parse", "--is-bare-repository"]: + return "true" + return None + + monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout) + assert resolve_main_worktree_root(cwd) == cwd + + +def test_resolve_default_base_prefers_master_over_main(monkeypatch) -> None: + def _fake_stdout(args, **kwargs): + if args[:2] == ["symbolic-ref", "-q"]: + return None + if args == ["branch", "--show-current"]: + return None + return None + + def _fake_ok(args, **kwargs): + if args == ["show-ref", "--verify", "--quiet", "refs/heads/master"]: + return True + if args == ["show-ref", "--verify", "--quiet", "refs/heads/main"]: + return True + return False + + monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout) + monkeypatch.setattr("takopi.utils.git.git_ok", _fake_ok) + assert resolve_default_base(Path("/repo")) == "master" diff --git a/tests/test_paths.py b/tests/test_paths.py index 0f3d772..6ab5392 100644 --- a/tests/test_paths.py +++ b/tests/test_paths.py @@ -2,7 +2,12 @@ from __future__ import annotations from pathlib import Path -from takopi.utils.paths import relativize_command, relativize_path +from takopi.utils.paths import ( + relativize_command, + relativize_path, + reset_run_base_dir, + set_run_base_dir, +) def test_relativize_command_rewrites_cwd_paths(tmp_path: Path) -> None: @@ -33,3 +38,14 @@ def test_relativize_path_inside_base(tmp_path: Path) -> None: base.mkdir() value = str(base / "src" / "app.py") assert relativize_path(value, base_dir=base) == "src/app.py" + + +def test_relativize_path_uses_run_base_dir(tmp_path: Path) -> None: + base = tmp_path / "repo" + base.mkdir() + token = set_run_base_dir(base) + try: + value = str(base / "src" / "app.py") + assert relativize_path(value) == "src/app.py" + finally: + reset_run_base_dir(token) diff --git a/tests/test_projects_config.py b/tests/test_projects_config.py new file mode 100644 index 0000000..fb78225 --- /dev/null +++ b/tests/test_projects_config.py @@ -0,0 +1,49 @@ +from pathlib import Path + +import pytest +from typer.testing import CliRunner + +from takopi import cli +from takopi.config import ConfigError, parse_projects_config + + +def test_parse_projects_rejects_engine_alias() -> None: + config = {"projects": {"codex": {"path": "/tmp/repo"}}} + with pytest.raises(ConfigError, match="aliases must not match engine ids"): + parse_projects_config( + config, + config_path=Path("takopi.toml"), + engine_ids=["codex"], + reserved=("cancel",), + ) + + +def test_parse_projects_default_project_must_exist() -> None: + config = {"default_project": "z80", "projects": {}} + with pytest.raises(ConfigError, match="default_project"): + parse_projects_config( + config, + config_path=Path("takopi.toml"), + engine_ids=["codex"], + reserved=("cancel",), + ) + + +def test_init_writes_project(monkeypatch, tmp_path) -> None: + config_path = tmp_path / "takopi.toml" + monkeypatch.setattr("takopi.config.HOME_CONFIG_PATH", config_path) + monkeypatch.setattr(cli, "resolve_default_base", lambda _: "main") + + repo_path = tmp_path / "repo" + repo_path.mkdir() + monkeypatch.chdir(repo_path) + + runner = CliRunner() + result = runner.invoke(cli.app, ["init", "z80"]) + assert result.exit_code == 0 + + saved = config_path.read_text(encoding="utf-8") + assert "[projects.z80]" in saved + assert 'worktrees_dir = ".worktrees"' in saved + assert 'default_engine = "codex"' in saved + assert 'worktree_base = "main"' in saved diff --git a/tests/test_telegram_bridge.py b/tests/test_telegram_bridge.py index 6e34300..8616bb0 100644 --- a/tests/test_telegram_bridge.py +++ b/tests/test_telegram_bridge.py @@ -1,3 +1,5 @@ +from pathlib import Path + import anyio import pytest @@ -7,16 +9,19 @@ from takopi.telegram.bridge import ( _build_bot_commands, _handle_cancel, _is_cancel_command, + _resolve_message, _send_with_resume, _strip_engine_command, run_main_loop, ) +from takopi.context import RunContext +from takopi.config import ProjectConfig, ProjectsConfig from takopi.runner_bridge import ExecBridgeConfig, RunningTask from takopi.markdown import MarkdownPresenter from takopi.model import EngineId, ResumeToken from takopi.router import AutoRouter, RunnerEntry from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait -from takopi.transport import MessageRef, RenderedMessage, SendOptions +from takopi.transport import IncomingMessage, MessageRef, RenderedMessage, SendOptions CODEX_ENGINE = EngineId("codex") @@ -354,7 +359,15 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: async def test_handle_cancel_without_reply_prompts_user() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) - msg = {"chat": {"id": 123}, "message_id": 10} + msg = IncomingMessage( + transport="telegram", + chat_id=123, + message_id=10, + text="/cancel", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + ) running_tasks: dict = {} await _handle_cancel(cfg, msg, running_tasks) @@ -367,11 +380,15 @@ async def test_handle_cancel_without_reply_prompts_user() -> None: async def test_handle_cancel_with_no_progress_message_says_nothing_running() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) - msg = { - "chat": {"id": 123}, - "message_id": 10, - "reply_to_message": {"text": "no message id"}, - } + msg = IncomingMessage( + transport="telegram", + chat_id=123, + message_id=10, + text="/cancel", + reply_to_message_id=None, + reply_to_text="no message id", + sender_id=123, + ) running_tasks: dict = {} await _handle_cancel(cfg, msg, running_tasks) @@ -385,11 +402,15 @@ async def test_handle_cancel_with_finished_task_says_nothing_running() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) progress_id = 99 - msg = { - "chat": {"id": 123}, - "message_id": 10, - "reply_to_message": {"message_id": progress_id}, - } + msg = IncomingMessage( + transport="telegram", + chat_id=123, + message_id=10, + text="/cancel", + reply_to_message_id=progress_id, + reply_to_text=None, + sender_id=123, + ) running_tasks: dict = {} await _handle_cancel(cfg, msg, running_tasks) @@ -403,11 +424,15 @@ async def test_handle_cancel_cancels_running_task() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) progress_id = 42 - msg = { - "chat": {"id": 123}, - "message_id": 10, - "reply_to_message": {"message_id": progress_id}, - } + msg = IncomingMessage( + transport="telegram", + chat_id=123, + message_id=10, + text="/cancel", + reply_to_message_id=progress_id, + reply_to_text=None, + sender_id=123, + ) running_task = RunningTask() running_tasks = {MessageRef(channel_id=123, message_id=progress_id): running_task} @@ -423,11 +448,15 @@ async def test_handle_cancel_only_cancels_matching_progress_message() -> None: cfg = _make_cfg(transport) task_first = RunningTask() task_second = RunningTask() - msg = { - "chat": {"id": 123}, - "message_id": 10, - "reply_to_message": {"message_id": 1}, - } + msg = IncomingMessage( + transport="telegram", + chat_id=123, + message_id=10, + text="/cancel", + reply_to_message_id=1, + reply_to_text=None, + sender_id=123, + ) running_tasks = { MessageRef(channel_id=123, message_id=1): task_first, MessageRef(channel_id=123, message_id=2): task_second, @@ -446,16 +475,46 @@ def test_cancel_command_accepts_extra_text() -> None: assert _is_cancel_command("/cancelled") is False +def test_resolve_message_accepts_backticked_ctx_line() -> None: + router = _make_router(ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)) + projects = ProjectsConfig( + projects={ + "takopi": ProjectConfig( + alias="takopi", + path=Path("."), + worktrees_dir=Path(".worktrees"), + ) + }, + default_project=None, + ) + + resolved = _resolve_message( + text="do it", + reply_text="`ctx: takopi @ feat/api`", + router=router, + projects=projects, + ) + + assert resolved.prompt == "do it" + assert resolved.resume_token is None + assert resolved.engine_override is None + assert resolved.context == RunContext(project="takopi", branch="feat/api") + + @pytest.mark.anyio async def test_send_with_resume_waits_for_token() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) - sent: list[tuple[int, int, str, ResumeToken | None]] = [] + sent: list[tuple[int, int, str, ResumeToken, RunContext | None]] = [] async def enqueue( - chat_id: int, user_msg_id: int, text: str, resume: ResumeToken + chat_id: int, + user_msg_id: int, + text: str, + resume: ResumeToken, + context: RunContext | None, ) -> None: - sent.append((chat_id, user_msg_id, text, resume)) + sent.append((chat_id, user_msg_id, text, resume, context)) running_task = RunningTask() @@ -476,7 +535,7 @@ async def test_send_with_resume_waits_for_token() -> None: ) assert sent == [ - (123, 10, "hello", ResumeToken(engine=CODEX_ENGINE, value="abc123")) + (123, 10, "hello", ResumeToken(engine=CODEX_ENGINE, value="abc123"), None) ] assert transport.send_calls == [] @@ -485,12 +544,16 @@ async def test_send_with_resume_waits_for_token() -> None: async def test_send_with_resume_reports_when_missing() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) - sent: list[tuple[int, int, str, ResumeToken | None]] = [] + sent: list[tuple[int, int, str, ResumeToken, RunContext | None]] = [] async def enqueue( - chat_id: int, user_msg_id: int, text: str, resume: ResumeToken + chat_id: int, + user_msg_id: int, + text: str, + resume: ResumeToken, + context: RunContext | None, ) -> None: - sent.append((chat_id, user_msg_id, text, resume)) + sent.append((chat_id, user_msg_id, text, resume, context)) running_task = RunningTask() running_task.done.set() @@ -538,22 +601,29 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None: ) async def poller(_cfg: TelegramBridgeConfig): - yield { - "message_id": 1, - "text": "first", - "chat": {"id": 123}, - "from": {"id": 123}, - } + yield IncomingMessage( + transport="telegram", + chat_id=123, + message_id=1, + text="first", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + ) await progress_ready.wait() assert transport.progress_ref is not None + assert isinstance(transport.progress_ref.message_id, int) + reply_id = transport.progress_ref.message_id reply_ready.set() - yield { - "message_id": 2, - "text": "followup", - "chat": {"id": 123}, - "from": {"id": 123}, - "reply_to_message": {"message_id": transport.progress_ref.message_id}, - } + yield IncomingMessage( + transport="telegram", + chat_id=123, + message_id=2, + text="followup", + reply_to_message_id=reply_id, + reply_to_text=None, + sender_id=123, + ) await stop_polling.wait() async with anyio.create_task_group() as tg: diff --git a/tests/test_telegram_incoming.py b/tests/test_telegram_incoming.py new file mode 100644 index 0000000..c527792 --- /dev/null +++ b/tests/test_telegram_incoming.py @@ -0,0 +1,47 @@ +from takopi.telegram import parse_incoming_update + + +def test_parse_incoming_update_maps_fields() -> None: + update = { + "update_id": 1, + "message": { + "message_id": 10, + "text": "hello", + "chat": {"id": 123}, + "from": {"id": 99}, + "reply_to_message": {"message_id": 5, "text": "prev"}, + }, + } + + msg = parse_incoming_update(update, chat_id=123) + assert msg is not None + assert msg.transport == "telegram" + assert msg.chat_id == 123 + assert msg.message_id == 10 + assert msg.text == "hello" + assert msg.reply_to_message_id == 5 + assert msg.reply_to_text == "prev" + assert msg.sender_id == 99 + assert msg.raw == update["message"] + + +def test_parse_incoming_update_filters_non_matching_chat() -> None: + update = { + "update_id": 1, + "message": { + "message_id": 10, + "text": "hello", + "chat": {"id": 123}, + }, + } + + assert parse_incoming_update(update, chat_id=999) is None + + +def test_parse_incoming_update_filters_non_text() -> None: + update = { + "update_id": 1, + "message": {"message_id": 10, "chat": {"id": 123}}, + } + + assert parse_incoming_update(update, chat_id=123) is None diff --git a/tests/test_worktrees.py b/tests/test_worktrees.py new file mode 100644 index 0000000..33fcfe8 --- /dev/null +++ b/tests/test_worktrees.py @@ -0,0 +1,56 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from takopi.config import ProjectConfig, ProjectsConfig +from takopi.context import RunContext +from takopi.worktrees import WorktreeError, ensure_worktree, resolve_run_cwd + + +def _projects_config(path: Path) -> ProjectsConfig: + return ProjectsConfig( + projects={ + "z80": ProjectConfig( + alias="z80", + path=path, + worktrees_dir=Path(".worktrees"), + ) + }, + default_project=None, + ) + + +def test_resolve_run_cwd_uses_project_root(tmp_path: Path) -> None: + projects = _projects_config(tmp_path) + ctx = RunContext(project="z80") + assert resolve_run_cwd(ctx, projects=projects) == tmp_path + + +def test_resolve_run_cwd_rejects_invalid_branch(tmp_path: Path) -> None: + projects = _projects_config(tmp_path) + ctx = RunContext(project="z80", branch="../oops") + with pytest.raises(WorktreeError, match="branch name"): + resolve_run_cwd(ctx, projects=projects) + + +def test_ensure_worktree_creates_from_base(monkeypatch, tmp_path: Path) -> None: + project = ProjectConfig( + alias="z80", + path=tmp_path, + worktrees_dir=Path(".worktrees"), + ) + calls: list[list[str]] = [] + + monkeypatch.setattr("takopi.worktrees.git_ok", lambda *args, **kwargs: False) + monkeypatch.setattr("takopi.worktrees.resolve_default_base", lambda *_: "main") + + def _fake_git_run(args, cwd): + calls.append(list(args)) + return SimpleNamespace(returncode=0, stdout="", stderr="") + + monkeypatch.setattr("takopi.worktrees.git_run", _fake_git_run) + + worktree_path = ensure_worktree(project, "feat/name") + assert worktree_path == tmp_path / ".worktrees" / "feat" / "name" + assert calls == [["worktree", "add", "-b", "feat/name", str(worktree_path), "main"]]