feat: projects and worktree management (#62)

This commit is contained in:
banteg
2026-01-07 17:45:05 +04:00
committed by GitHub
parent 1178b738df
commit aa078258ea
28 changed files with 1735 additions and 144 deletions
+13
View File
@@ -10,6 +10,19 @@
- add transport/presenter protocols plus transport-agnostic `exec_bridge` - add transport/presenter protocols plus transport-agnostic `exec_bridge`
- move Telegram polling + wiring into `takopi.bridges.telegram` with transport/presenter adapters - move Telegram polling + wiring into `takopi.bridges.telegram` with transport/presenter adapters
- add project configuration, directive parsing (`/project`, `@branch`), and `ctx:`-aware routing for runs
- add `takopi init` to register project aliases from the main checkout (with worktree defaults)
- resolve git worktrees on demand and run engine subprocesses in the project/worktree cwd
- list configured projects in the startup banner
- add a shared incoming message shape plus Telegram parsing helpers
### fixes
- render `ctx:` footer lines consistently (backticked + hard breaks) and include them in final messages
### docs
- add a projects/worktrees guide and document `takopi init` behavior in the README
## v0.8.0 (2026-01-05) ## v0.8.0 (2026-01-05)
+135
View File
@@ -0,0 +1,135 @@
# Projects and Worktrees
This doc covers project aliases, worktree behavior, and how Takopi resolves run
context from messages.
## Overview
Projects let you give a repo an alias (used as `/alias` in messages) and opt into
worktree-based runs via `@branch`.
- If no projects are configured, Takopi runs in the startup working directory.
- If a project is configured, `@branch` resolves/creates a git worktree and runs
the task in that worktree.
- Progress/final messages include a `ctx:` footer when project context is active.
## Config schema
All config lives in `~/.takopi/takopi.toml`.
```toml
default_engine = "codex" # optional
default_project = "z80" # optional
bot_token = "..." # required
chat_id = 123 # required
[projects.z80]
path = "~/dev/z80" # required (repo root)
worktrees_dir = ".worktrees" # optional, default ".worktrees"
default_engine = "codex" # optional, per-project override
worktree_base = "master" # optional, base for new branches
```
Validation rules:
- `projects` is optional.
- Each project entry must include `path` (string, non-empty).
- `default_project` must match a configured project alias.
- Project aliases cannot collide with engine ids or reserved commands (`/cancel`).
- `default_engine` and per-project `default_engine` must be valid engine ids.
## `takopi init`
`takopi init <alias>` registers the current repo as a project alias.
Important behavior:
- The stored `path` is the **main checkout** of the repo, even if you run
`takopi init` inside a worktree. Takopi resolves the repo root via the git
common dir and writes that path to `[projects.<alias>].path`.
- `worktree_base` is set from the current repo using this resolution order:
`origin/HEAD` → current branch → `master``main`.
## Directives and context resolution
Takopi parses the first non-empty line of a message for a directive prefix.
Supported directives:
- `/engine` or `/engine@bot`: chooses the engine
- `/project`: chooses a project alias
- `@branch`: chooses a git branch/worktree
Rules:
- Directives must be a contiguous prefix of the line; parsing stops at the first
non-directive token.
- At most one engine directive, one project directive, and one `@branch` are
allowed (duplicates are errors).
- If a reply contains a `ctx:` line, Takopi **ignores new directives** and uses
the reply context.
## Context footer (`ctx:`)
When a run has project context, Takopi appends a footer line rendered as inline
code (backticked):
- With branch: `` `ctx: <project> @ <branch>` ``
- Without branch: `` `ctx: <project>` ``
The `ctx:` line is parsed from replies and takes precedence over new directives.
## Worktree resolution
When `@branch` is present:
```
worktrees_root = <project.path> / <worktrees_dir>
worktree_path = worktrees_root / <branch>
```
Branch validation:
- Must be non-empty
- Must not start with `/`
- Must not contain `..` path segments
- May include `/` (nested directories)
- The resolved worktree path must stay within `worktrees_root`
Worktree creation rules:
1) If `worktree_path` exists:
- It must be a git worktree or Takopi errors.
2) If it does not exist:
- If local branch exists: `git worktree add <path> <branch>`
- Else if remote `origin/<branch>` exists:
`git worktree add -b <branch> <path> origin/<branch>`
- Else:
`git worktree add -b <branch> <path> <base>`
Base branch selection:
1) `projects.<alias>.worktree_base` (if set)
2) `origin/HEAD` (if present)
3) current checked out branch
4) `master` if it exists
5) `main` if it exists
6) otherwise error
When `@branch` is omitted:
- Takopi runs in `<project.path>` (the main checkout).
## Examples
Start a new thread in a worktree:
```
/z80 @feat/streaming fix flaky test
```
Reply to a progress message to continue in the same context:
```
ctx: z80 @ feat/streaming
```
+22
View File
@@ -76,6 +76,28 @@ provider = "openai"
extra_args = ["--no-color"] extra_args = ["--no-color"]
``` ```
## projects (optional)
register the current repo as a project alias:
```sh
takopi init z80
```
`takopi init` writes the repo root to `[projects.<alias>].path`. if you run it inside a git worktree, it resolves the main checkout and records that path instead of the worktree.
example:
```toml
default_project = "z80"
[projects.z80]
path = "~/dev/z80"
worktrees_dir = ".worktrees"
default_engine = "codex"
worktree_base = "master"
```
## usage ## usage
start takopi in the repo you want to work on: start takopi in the repo you want to work on:
+126 -3
View File
@@ -11,10 +11,17 @@ import typer
from . import __version__ from . import __version__
from .backends import EngineBackend from .backends import EngineBackend
from .config import ConfigError from .config import (
ConfigError,
load_or_init_config,
parse_projects_config,
write_config,
)
from .engines import get_backend, get_engine_config, list_backends from .engines import get_backend, get_engine_config, list_backends
from .lockfile import LockError, LockHandle, acquire_lock, token_fingerprint from .lockfile import LockError, LockHandle, acquire_lock, token_fingerprint
from .logging import get_logger, setup_logging from .logging import get_logger, setup_logging
from .router import AutoRouter, RunnerEntry
from .runner_bridge import ExecBridgeConfig
from .telegram.bridge import ( from .telegram.bridge import (
TelegramBridgeConfig, TelegramBridgeConfig,
TelegramPresenter, TelegramPresenter,
@@ -24,8 +31,7 @@ from .telegram.bridge import (
from .telegram.client import TelegramClient from .telegram.client import TelegramClient
from .telegram.config import load_telegram_config from .telegram.config import load_telegram_config
from .telegram.onboarding import SetupResult, check_setup, interactive_setup from .telegram.onboarding import SetupResult, check_setup, interactive_setup
from .router import AutoRouter, RunnerEntry from .utils.git import resolve_default_base, resolve_main_worktree_root
from .runner_bridge import ExecBridgeConfig
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -195,6 +201,12 @@ def _parse_bridge_config(
startup_pwd = os.getcwd() startup_pwd = os.getcwd()
backends = list_backends() backends = list_backends()
projects = parse_projects_config(
config,
config_path=config_path,
engine_ids=[backend.id for backend in backends],
reserved=("cancel",),
)
default_engine = _resolve_default_engine( default_engine = _resolve_default_engine(
override=default_engine_override, override=default_engine_override,
config=config, config=config,
@@ -212,10 +224,16 @@ def _parse_bridge_config(
engine_list = ", ".join(available_engines) if available_engines else "none" engine_list = ", ".join(available_engines) if available_engines else "none"
if missing_engines: if missing_engines:
engine_list = f"{engine_list} (not installed: {', '.join(missing_engines)})" engine_list = f"{engine_list} (not installed: {', '.join(missing_engines)})"
project_aliases = sorted(
{project.alias for project in projects.projects.values()},
key=str.lower,
)
project_list = ", ".join(project_aliases) if project_aliases else "none"
startup_msg = ( startup_msg = (
f"\N{OCTOPUS} **takopi is ready**\n\n" f"\N{OCTOPUS} **takopi is ready**\n\n"
f"default: `{router.default_engine}` \n" f"default: `{router.default_engine}` \n"
f"agents: `{engine_list}` \n" f"agents: `{engine_list}` \n"
f"projects: `{project_list}` \n"
f"working in: `{startup_pwd}`" f"working in: `{startup_pwd}`"
) )
@@ -234,6 +252,7 @@ def _parse_bridge_config(
chat_id=chat_id, chat_id=chat_id,
startup_msg=startup_msg, startup_msg=startup_msg,
exec_cfg=exec_cfg, exec_cfg=exec_cfg,
projects=projects,
) )
@@ -320,6 +339,107 @@ def _run_auto_router(
lock_handle.release() lock_handle.release()
def _prompt_alias(value: str | None, *, default_alias: str | None = None) -> str:
if value is not None:
alias = value
elif default_alias:
alias = typer.prompt("project alias", default=default_alias)
else:
alias = typer.prompt("project alias")
alias = alias.strip()
if not alias:
typer.echo("error: project alias cannot be empty", err=True)
raise typer.Exit(code=1)
return alias
def _default_alias_from_path(path: Path) -> str | None:
name = path.name
if not name:
return None
if name.endswith(".git"):
name = name[: -len(".git")]
return name or None
def _ensure_projects_table(config: dict, config_path: Path) -> dict:
projects = config.get("projects")
if projects is None:
projects = {}
config["projects"] = projects
if not isinstance(projects, dict):
raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.")
return projects
def init(
alias: str | None = typer.Argument(
None, help="Project alias (used as /alias in messages)."
),
default: bool = typer.Option(
False,
"--default",
help="Set this project as the default_project.",
),
) -> None:
"""Register the current repo as a Takopi project."""
config, config_path = load_or_init_config()
cwd = Path.cwd()
project_path = resolve_main_worktree_root(cwd) or cwd
default_alias = _default_alias_from_path(project_path)
alias = _prompt_alias(alias, default_alias=default_alias)
engine_ids = [backend.id for backend in list_backends()]
projects_cfg = parse_projects_config(
config,
config_path=config_path,
engine_ids=engine_ids,
reserved=("cancel",),
)
alias_key = alias.lower()
if alias_key in {engine.lower() for engine in engine_ids}:
raise ConfigError(
f"Invalid project alias {alias!r}; aliases must not match engine ids."
)
if alias_key == "cancel":
raise ConfigError(
f"Invalid project alias {alias!r}; aliases must not match reserved commands."
)
existing = projects_cfg.projects.get(alias_key)
if existing is not None:
overwrite = typer.confirm(
f"project {existing.alias!r} already exists, overwrite?",
default=False,
)
if not overwrite:
raise typer.Exit(code=1)
projects = _ensure_projects_table(config, config_path)
if existing is not None and existing.alias in projects:
projects.pop(existing.alias, None)
default_engine = _default_engine_for_setup(None)
worktree_base = resolve_default_base(project_path)
entry: dict[str, object] = {
"path": str(project_path),
"worktrees_dir": ".worktrees",
"default_engine": default_engine,
}
if worktree_base:
entry["worktree_base"] = worktree_base
projects[alias] = entry
if default:
config["default_project"] = alias
write_config(config, config_path)
typer.echo(f"saved project {alias!r} to {_config_path_display(config_path)}")
app = typer.Typer( app = typer.Typer(
add_completion=False, add_completion=False,
invoke_without_command=True, invoke_without_command=True,
@@ -327,6 +447,9 @@ app = typer.Typer(
) )
app.command(name="init")(init)
@app.callback() @app.callback()
def app_main( def app_main(
ctx: typer.Context, ctx: typer.Context,
+255
View File
@@ -1,5 +1,260 @@
from __future__ import annotations from __future__ import annotations
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable
HOME_CONFIG_PATH = Path.home() / ".takopi" / "takopi.toml"
class ConfigError(RuntimeError): class ConfigError(RuntimeError):
pass pass
def _read_config(cfg_path: Path) -> dict:
try:
raw = cfg_path.read_text(encoding="utf-8")
except FileNotFoundError:
raise ConfigError(f"Missing config file {cfg_path}.") from None
except OSError as e:
raise ConfigError(f"Failed to read config file {cfg_path}: {e}") from e
try:
return tomllib.loads(raw)
except tomllib.TOMLDecodeError as e:
raise ConfigError(f"Malformed TOML in {cfg_path}: {e}") from None
def load_or_init_config(path: str | Path | None = None) -> tuple[dict, Path]:
cfg_path = Path(path).expanduser() if path else HOME_CONFIG_PATH
if cfg_path.exists() and not cfg_path.is_file():
raise ConfigError(f"Config path {cfg_path} exists but is not a file.") from None
if not cfg_path.exists():
return {}, cfg_path
return _read_config(cfg_path), cfg_path
@dataclass(frozen=True, slots=True)
class ProjectConfig:
alias: str
path: Path
worktrees_dir: Path
default_engine: str | None = None
worktree_base: str | None = None
@property
def worktrees_root(self) -> Path:
if self.worktrees_dir.is_absolute():
return self.worktrees_dir
return self.path / self.worktrees_dir
@dataclass(frozen=True, slots=True)
class ProjectsConfig:
projects: dict[str, ProjectConfig]
default_project: str | None = None
def resolve(self, alias: str | None) -> ProjectConfig | None:
if alias is None:
if self.default_project is None:
return None
return self.projects.get(self.default_project)
return self.projects.get(alias.lower())
def empty_projects_config() -> ProjectsConfig:
return ProjectsConfig(projects={}, default_project=None)
def _normalize_engine_id(
value: str,
*,
engine_ids: Iterable[str],
config_path: Path,
label: str,
) -> str:
engine_map = {engine.lower(): engine for engine in engine_ids}
cleaned = value.strip()
if not cleaned:
raise ConfigError(f"Invalid `{label}` in {config_path}; expected a string.")
engine = engine_map.get(cleaned.lower())
if engine is None:
available = ", ".join(sorted(engine_map.values()))
raise ConfigError(
f"Unknown `{label}` {cleaned!r} in {config_path}. Available: {available}."
)
return engine
def _normalize_project_path(value: str, *, config_path: Path) -> Path:
path = Path(value).expanduser()
if not path.is_absolute():
path = config_path.parent / path
return path
def parse_projects_config(
config: dict[str, Any],
*,
config_path: Path,
engine_ids: Iterable[str],
reserved: Iterable[str] = ("cancel",),
) -> ProjectsConfig:
default_project_raw = config.get("default_project")
default_project = None
if default_project_raw is not None:
if not isinstance(default_project_raw, str) or not default_project_raw.strip():
raise ConfigError(
f"Invalid `default_project` in {config_path}; expected a non-empty string."
)
default_project = default_project_raw.strip()
projects_raw = config.get("projects") or {}
if not isinstance(projects_raw, dict):
raise ConfigError(f"Invalid `projects` in {config_path}; expected a table.")
reserved_lower = {value.lower() for value in reserved}
engine_lower = {value.lower() for value in engine_ids}
projects: dict[str, ProjectConfig] = {}
for raw_alias, raw_entry in projects_raw.items():
if not isinstance(raw_alias, str) or not raw_alias.strip():
raise ConfigError(
f"Invalid project alias in {config_path}; expected a non-empty string."
)
alias = raw_alias.strip()
alias_key = alias.lower()
if alias_key in engine_lower or alias_key in reserved_lower:
raise ConfigError(
f"Invalid project alias {alias!r} in {config_path}; "
"aliases must not match engine ids or reserved commands."
)
if alias_key in projects:
raise ConfigError(f"Duplicate project alias {alias!r} in {config_path}.")
if not isinstance(raw_entry, dict):
raise ConfigError(
f"Invalid project entry for {alias!r} in {config_path}; expected a table."
)
path_value = raw_entry.get("path")
if not isinstance(path_value, str) or not path_value.strip():
raise ConfigError(f"Missing `path` for project {alias!r} in {config_path}.")
path = _normalize_project_path(path_value.strip(), config_path=config_path)
worktrees_dir_raw = raw_entry.get("worktrees_dir", ".worktrees")
if not isinstance(worktrees_dir_raw, str) or not worktrees_dir_raw.strip():
raise ConfigError(
f"Invalid `worktrees_dir` for project {alias!r} in {config_path}."
)
worktrees_dir = Path(worktrees_dir_raw.strip())
default_engine_raw = raw_entry.get("default_engine")
default_engine = None
if default_engine_raw is not None:
if not isinstance(default_engine_raw, str):
raise ConfigError(
f"Invalid `projects.{alias}.default_engine` in {config_path}; "
"expected a string."
)
default_engine = _normalize_engine_id(
default_engine_raw,
engine_ids=engine_ids,
config_path=config_path,
label=f"projects.{alias}.default_engine",
)
worktree_base_raw = raw_entry.get("worktree_base")
worktree_base = None
if worktree_base_raw is not None:
if not isinstance(worktree_base_raw, str) or not worktree_base_raw.strip():
raise ConfigError(
f"Invalid `projects.{alias}.worktree_base` in {config_path}; "
"expected a string."
)
worktree_base = worktree_base_raw.strip()
projects[alias_key] = ProjectConfig(
alias=alias,
path=path,
worktrees_dir=worktrees_dir,
default_engine=default_engine,
worktree_base=worktree_base,
)
if default_project is not None:
default_key = default_project.lower()
if default_key not in projects:
raise ConfigError(
f"Invalid `default_project` {default_project!r} in {config_path}; "
"no matching project alias found."
)
default_project = default_key
return ProjectsConfig(projects=projects, default_project=default_project)
def _toml_escape(value: str) -> str:
return value.replace("\\", "\\\\").replace('"', '\\"')
def _format_toml_value(value: Any) -> str:
if isinstance(value, bool):
return "true" if value else "false"
if isinstance(value, int):
return str(value)
if isinstance(value, float):
return repr(value)
if isinstance(value, str):
return f'"{_toml_escape(value)}"'
if isinstance(value, (list, tuple)):
inner = ", ".join(_format_toml_value(item) for item in value)
return f"[{inner}]"
raise ConfigError(f"Unsupported config value {value!r}")
def _table_has_scalars(table: dict[str, Any]) -> bool:
return any(not isinstance(value, dict) for value in table.values())
def dump_toml(config: dict[str, Any]) -> str:
lines: list[str] = []
def write_kv(key: str, value: Any) -> None:
lines.append(f"{key} = {_format_toml_value(value)}")
def write_table(name: str, table: dict[str, Any]) -> None:
if lines and lines[-1] != "":
lines.append("")
lines.append(f"[{name}]")
for key, value in table.items():
if isinstance(value, dict):
continue
write_kv(key, value)
for key, value in table.items():
if isinstance(value, dict):
write_table(f"{name}.{key}", value)
for key, value in config.items():
if isinstance(value, dict):
continue
write_kv(key, value)
for key, value in config.items():
if not isinstance(value, dict):
continue
if _table_has_scalars(value):
write_table(key, value)
continue
for subkey, subvalue in value.items():
if isinstance(subvalue, dict):
write_table(f"{key}.{subkey}", subvalue)
else:
write_table(key, value)
break
return "\n".join(lines) + "\n"
def write_config(config: dict[str, Any], path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(dump_toml(config), encoding="utf-8")
+9
View File
@@ -0,0 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class RunContext:
project: str | None = None
branch: str | None = None
+34 -7
View File
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import os
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
@@ -94,12 +95,16 @@ def format_file_change_title(action: Action, *, command_width: int | None) -> st
if isinstance(changes, list) and changes: if isinstance(changes, list) and changes:
rendered: list[str] = [] rendered: list[str] = []
for raw in changes: for raw in changes:
if not isinstance(raw, dict): path: str | None
continue kind: str | None
path = raw.get("path") if isinstance(raw, dict):
path = raw.get("path")
kind = raw.get("kind")
else:
path = getattr(raw, "path", None)
kind = getattr(raw, "kind", None)
if not isinstance(path, str) or not path: if not isinstance(path, str) or not path:
continue continue
kind = raw.get("kind")
verb = kind if isinstance(kind, str) and kind else "update" verb = kind if isinstance(kind, str) and kind else "update"
rendered.append(f"{verb} {format_changed_file_path(path)}") rendered.append(f"{verb} {format_changed_file_path(path)}")
@@ -110,7 +115,15 @@ def format_file_change_title(action: Action, *, command_width: int | None) -> st
inline = shorten(", ".join(rendered), command_width) inline = shorten(", ".join(rendered), command_width)
return f"files: {inline}" return f"files: {inline}"
return f"files: {shorten(title, command_width)}" fallback = title
relativized = relativize_path(fallback)
was_relativized = relativized != fallback
if was_relativized:
fallback = relativized
if fallback and not (fallback.startswith("`") and fallback.endswith("`")):
if was_relativized or os.sep in fallback or "/" in fallback:
fallback = f"`{fallback}`"
return f"files: {shorten(fallback, command_width)}"
def format_action_title(action: Action, *, command_width: int | None) -> str: def format_action_title(action: Action, *, command_width: int | None) -> str:
@@ -197,7 +210,9 @@ class MarkdownFormatter:
engine=state.engine, engine=state.engine,
) )
body = self._assemble_body(self._format_actions(state)) body = self._assemble_body(self._format_actions(state))
return MarkdownParts(header=header, body=body, footer=state.resume_line) return MarkdownParts(
header=header, body=body, footer=self._format_footer(state)
)
def render_final_parts( def render_final_parts(
self, self,
@@ -216,7 +231,19 @@ class MarkdownFormatter:
) )
answer = (answer or "").strip() answer = (answer or "").strip()
body = answer if answer else None body = answer if answer else None
return MarkdownParts(header=header, body=body, footer=state.resume_line) return MarkdownParts(
header=header, body=body, footer=self._format_footer(state)
)
def _format_footer(self, state: ProgressState) -> str | None:
lines: list[str] = []
if state.context_line:
lines.append(state.context_line)
if state.resume_line:
lines.append(state.resume_line)
if not lines:
return None
return HARD_BREAK.join(lines)
def _format_actions(self, state: ProgressState) -> list[str]: def _format_actions(self, state: ProgressState) -> list[str]:
actions = list(state.actions) actions = list(state.actions)
+3
View File
@@ -24,6 +24,7 @@ class ProgressState:
actions: tuple[ActionState, ...] actions: tuple[ActionState, ...]
resume: ResumeToken | None resume: ResumeToken | None
resume_line: str | None resume_line: str | None
context_line: str | None
class ProgressTracker: class ProgressTracker:
@@ -80,6 +81,7 @@ class ProgressTracker:
self, self,
*, *,
resume_formatter: Callable[[ResumeToken], str] | None = None, resume_formatter: Callable[[ResumeToken], str] | None = None,
context_line: str | None = None,
) -> ProgressState: ) -> ProgressState:
resume_line: str | None = None resume_line: str | None = None
if self.resume is not None and resume_formatter is not None: if self.resume is not None and resume_formatter is not None:
@@ -93,4 +95,5 @@ class ProgressTracker:
actions=actions, actions=actions,
resume=self.resume, resume=self.resume,
resume_line=resume_line, resume_line=resume_line,
context_line=context_line,
) )
+4
View File
@@ -22,6 +22,7 @@ from .model import (
StartedEvent, StartedEvent,
TakopiEvent, TakopiEvent,
) )
from .utils.paths import get_run_base_dir
from .utils.streams import drain_stderr, iter_bytes_lines from .utils.streams import drain_stderr, iter_bytes_lines
from .utils.subprocess import manage_subprocess from .utils.subprocess import manage_subprocess
@@ -358,12 +359,15 @@ class JsonlSubprocessRunner(BaseRunner):
prompt_len=len(prompt), prompt_len=len(prompt),
) )
cwd = get_run_base_dir()
async with manage_subprocess( async with manage_subprocess(
cmd, cmd,
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
env=env, env=env,
cwd=cwd,
) as proc: ) as proc:
if proc.stdout is None or proc.stderr is None: if proc.stdout is None or proc.stderr is None:
raise RuntimeError(self.pipes_error_message()) raise RuntimeError(self.pipes_error_message())
+30 -6
View File
@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
import anyio import anyio
from .context import RunContext
from .logging import bind_run_context, get_logger from .logging import bind_run_context, get_logger
from .model import CompletedEvent, ResumeToken, StartedEvent, TakopiEvent from .model import CompletedEvent, ResumeToken, StartedEvent, TakopiEvent
from .presenter import Presenter from .presenter import Presenter
@@ -93,6 +94,7 @@ class RunningTask:
resume_ready: anyio.Event = field(default_factory=anyio.Event) resume_ready: anyio.Event = field(default_factory=anyio.Event)
cancel_requested: anyio.Event = field(default_factory=anyio.Event) cancel_requested: anyio.Event = field(default_factory=anyio.Event)
done: anyio.Event = field(default_factory=anyio.Event) done: anyio.Event = field(default_factory=anyio.Event)
context: RunContext | None = None
RunningTasks = dict[MessageRef, RunningTask] RunningTasks = dict[MessageRef, RunningTask]
@@ -152,6 +154,7 @@ class ProgressEdits:
last_rendered: RenderedMessage | None, last_rendered: RenderedMessage | None,
resume_formatter: Callable[[ResumeToken], str] | None = None, resume_formatter: Callable[[ResumeToken], str] | None = None,
label: str = "working", label: str = "working",
context_line: str | None = None,
) -> None: ) -> None:
self.transport = transport self.transport = transport
self.presenter = presenter self.presenter = presenter
@@ -163,6 +166,7 @@ class ProgressEdits:
self.last_rendered = last_rendered self.last_rendered = last_rendered
self.resume_formatter = resume_formatter self.resume_formatter = resume_formatter
self.label = label self.label = label
self.context_line = context_line
self.event_seq = 0 self.event_seq = 0
self.rendered_seq = 0 self.rendered_seq = 0
self.signal_send, self.signal_recv = anyio.create_memory_object_stream(1) self.signal_send, self.signal_recv = anyio.create_memory_object_stream(1)
@@ -179,7 +183,10 @@ class ProgressEdits:
seq_at_render = self.event_seq seq_at_render = self.event_seq
now = self.clock() now = self.clock()
state = self.tracker.snapshot(resume_formatter=self.resume_formatter) state = self.tracker.snapshot(
resume_formatter=self.resume_formatter,
context_line=self.context_line,
)
rendered = self.presenter.render_progress( rendered = self.presenter.render_progress(
state, elapsed_s=now - self.started_at, label=self.label state, elapsed_s=now - self.started_at, label=self.label
) )
@@ -228,11 +235,15 @@ async def send_initial_progress(
label: str, label: str,
tracker: ProgressTracker, tracker: ProgressTracker,
resume_formatter: Callable[[ResumeToken], str] | None = None, resume_formatter: Callable[[ResumeToken], str] | None = None,
context_line: str | None = None,
) -> ProgressMessageState: ) -> ProgressMessageState:
progress_ref: MessageRef | None = None progress_ref: MessageRef | None = None
last_rendered: RenderedMessage | None = None last_rendered: RenderedMessage | None = None
state = tracker.snapshot(resume_formatter=resume_formatter) state = tracker.snapshot(
resume_formatter=resume_formatter,
context_line=context_line,
)
initial_rendered = cfg.presenter.render_progress( initial_rendered = cfg.presenter.render_progress(
state, state,
elapsed_s=0.0, elapsed_s=0.0,
@@ -364,6 +375,8 @@ async def handle_message(
runner: Runner, runner: Runner,
incoming: IncomingMessage, incoming: IncomingMessage,
resume_token: ResumeToken | None, resume_token: ResumeToken | None,
context: RunContext | None = None,
context_line: str | None = None,
strip_resume_line: Callable[[str], bool] | None = None, strip_resume_line: Callable[[str], bool] | None = None,
running_tasks: RunningTasks | None = None, running_tasks: RunningTasks | None = None,
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
@@ -395,6 +408,7 @@ async def handle_message(
label="starting", label="starting",
tracker=progress_tracker, tracker=progress_tracker,
resume_formatter=runner.format_resume, resume_formatter=runner.format_resume,
context_line=context_line,
) )
progress_ref = progress_state.ref progress_ref = progress_state.ref
@@ -408,11 +422,12 @@ async def handle_message(
clock=clock, clock=clock,
last_rendered=progress_state.last_rendered, last_rendered=progress_state.last_rendered,
resume_formatter=runner.format_resume, resume_formatter=runner.format_resume,
context_line=context_line,
) )
running_task: RunningTask | None = None running_task: RunningTask | None = None
if running_tasks is not None and progress_ref is not None: if running_tasks is not None and progress_ref is not None:
running_task = RunningTask() running_task = RunningTask(context=context)
running_tasks[progress_ref] = running_task running_tasks[progress_ref] = running_task
cancel_exc_type = anyio.get_cancelled_exc_class() cancel_exc_type = anyio.get_cancelled_exc_class()
@@ -464,7 +479,10 @@ async def handle_message(
if error is not None: if error is not None:
sync_resume_token(progress_tracker, outcome.resume) sync_resume_token(progress_tracker, outcome.resume)
err_body = _format_error(error) err_body = _format_error(error)
state = progress_tracker.snapshot(resume_formatter=runner.format_resume) state = progress_tracker.snapshot(
resume_formatter=runner.format_resume,
context_line=context_line,
)
final_rendered = cfg.presenter.render_final( final_rendered = cfg.presenter.render_final(
state, state,
elapsed_s=elapsed, elapsed_s=elapsed,
@@ -496,7 +514,10 @@ async def handle_message(
resume=resume.value if resume else None, resume=resume.value if resume else None,
elapsed_s=elapsed, elapsed_s=elapsed,
) )
state = progress_tracker.snapshot(resume_formatter=runner.format_resume) state = progress_tracker.snapshot(
resume_formatter=runner.format_resume,
context_line=context_line,
)
final_rendered = cfg.presenter.render_progress( final_rendered = cfg.presenter.render_progress(
state, state,
elapsed_s=elapsed, elapsed_s=elapsed,
@@ -546,7 +567,10 @@ async def handle_message(
resume=resume_value, resume=resume_value,
) )
sync_resume_token(progress_tracker, completed.resume or outcome.resume) sync_resume_token(progress_tracker, completed.resume or outcome.resume)
state = progress_tracker.snapshot(resume_formatter=runner.format_resume) state = progress_tracker.snapshot(
resume_formatter=runner.format_resume,
context_line=context_line,
)
final_rendered = cfg.presenter.render_final( final_rendered = cfg.presenter.render_final(
state, state,
elapsed_s=elapsed, elapsed_s=elapsed,
+22 -1
View File
@@ -76,6 +76,26 @@ def _summarize_tool_result(result: Any) -> dict[str, Any] | None:
return None return None
def _normalize_change_list(changes: list[Any]) -> list[dict[str, str]]:
normalized: list[dict[str, str]] = []
for change in changes:
path: str | None = None
kind: str | None = None
if isinstance(change, codex_schema.FileUpdateChange):
path = change.path
kind = change.kind
elif isinstance(change, dict):
path = change.get("path")
kind = change.get("kind")
if not isinstance(path, str) or not path:
continue
entry = {"path": path}
if isinstance(kind, str) and kind:
entry["kind"] = kind
normalized.append(entry)
return normalized
def _format_change_summary(changes: list[Any]) -> str: def _format_change_summary(changes: list[Any]) -> str:
paths: list[str] = [] paths: list[str] = []
for change in changes: for change in changes:
@@ -260,8 +280,9 @@ def _translate_item_event(
if phase != "completed": if phase != "completed":
return [] return []
title = _format_change_summary(changes) title = _format_change_summary(changes)
normalized_changes = _normalize_change_list(changes)
detail = { detail = {
"changes": changes, "changes": normalized_changes,
"status": status, "status": status,
"error": None, "error": None,
} }
+4
View File
@@ -6,6 +6,7 @@ from typing import Any, Awaitable, Callable, Protocol
import anyio import anyio
from .context import RunContext
from .model import ResumeToken from .model import ResumeToken
@@ -15,6 +16,7 @@ class ThreadJob:
user_msg_id: int user_msg_id: int
text: str text: str
resume_token: ResumeToken resume_token: ResumeToken
context: RunContext | None = None
RunJob = Callable[[ThreadJob], Awaitable[None]] RunJob = Callable[[ThreadJob], Awaitable[None]]
@@ -66,6 +68,7 @@ class ThreadScheduler:
user_msg_id: int, user_msg_id: int,
text: str, text: str,
resume_token: ResumeToken, resume_token: ResumeToken,
context: RunContext | None = None,
) -> None: ) -> None:
await self.enqueue( await self.enqueue(
ThreadJob( ThreadJob(
@@ -73,6 +76,7 @@ class ThreadScheduler:
user_msg_id=user_msg_id, user_msg_id=user_msg_id,
text=text, text=text,
resume_token=resume_token, resume_token=resume_token,
context=context,
) )
) )
+4
View File
@@ -1 +1,5 @@
"""Telegram-specific clients and adapters.""" """Telegram-specific clients and adapters."""
from .client import parse_incoming_update, poll_incoming
__all__ = ["parse_incoming_update", "poll_incoming"]
+322 -81
View File
@@ -1,14 +1,16 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncIterator, Awaitable, Callable from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any
import anyio import anyio
from ..config import ProjectsConfig, empty_projects_config
from ..context import RunContext
from ..runner_bridge import ( from ..runner_bridge import (
ExecBridgeConfig, ExecBridgeConfig,
IncomingMessage, IncomingMessage as RunnerIncomingMessage,
RunningTask, RunningTask,
RunningTasks, RunningTasks,
handle_message, handle_message,
@@ -20,8 +22,16 @@ from ..progress import ProgressState, ProgressTracker
from ..router import AutoRouter, RunnerUnavailableError from ..router import AutoRouter, RunnerUnavailableError
from ..runner import Runner from ..runner import Runner
from ..scheduler import ThreadJob, ThreadScheduler from ..scheduler import ThreadJob, ThreadScheduler
from ..transport import MessageRef, RenderedMessage, SendOptions, Transport from ..transport import (
from .client import BotClient IncomingMessage as TransportIncomingMessage,
MessageRef,
RenderedMessage,
SendOptions,
Transport,
)
from ..utils.paths import reset_run_base_dir, set_run_base_dir
from ..worktrees import WorktreeError, resolve_run_cwd
from .client import BotClient, poll_incoming
from .render import prepare_telegram from .render import prepare_telegram
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -70,6 +80,206 @@ def _strip_engine_command(
return "\n".join(lines).strip(), engine return "\n".join(lines).strip(), engine
@dataclass(frozen=True, slots=True)
class ParsedDirectives:
prompt: str
engine: EngineId | None
project: str | None
branch: str | None
@dataclass(frozen=True, slots=True)
class ResolvedMessage:
prompt: str
resume_token: ResumeToken | None
engine_override: EngineId | None
context: RunContext | None
class DirectiveError(RuntimeError):
pass
def _parse_directives(
text: str,
*,
engine_ids: tuple[EngineId, ...],
projects: ProjectsConfig,
) -> ParsedDirectives:
if not text:
return ParsedDirectives(prompt="", engine=None, project=None, branch=None)
lines = text.splitlines()
idx = next((i for i, line in enumerate(lines) if line.strip()), None)
if idx is None:
return ParsedDirectives(prompt=text, engine=None, project=None, branch=None)
line = lines[idx].lstrip()
tokens = line.split()
if not tokens:
return ParsedDirectives(prompt=text, engine=None, project=None, branch=None)
engine_map = {engine.lower(): engine for engine in engine_ids}
project_map = {alias.lower(): alias for alias in projects.projects}
engine: EngineId | None = None
project: str | None = None
branch: str | None = None
consumed = 0
for token in tokens:
if token.startswith("/"):
name = token[1:]
if "@" in name:
name = name.split("@", 1)[0]
if not name:
break
key = name.lower()
engine_candidate = engine_map.get(key)
project_candidate = project_map.get(key)
if engine_candidate is not None:
if engine is not None:
raise DirectiveError("multiple engine directives")
engine = engine_candidate
consumed += 1
continue
if project_candidate is not None:
if project is not None:
raise DirectiveError("multiple project directives")
project = project_candidate
consumed += 1
continue
break
if token.startswith("@"):
value = token[1:]
if not value:
break
if branch is not None:
raise DirectiveError("multiple @branch directives")
branch = value
consumed += 1
continue
break
if consumed == 0:
return ParsedDirectives(prompt=text, engine=None, project=None, branch=None)
if consumed < len(tokens):
remainder = " ".join(tokens[consumed:])
lines[idx] = remainder
else:
lines.pop(idx)
prompt = "\n".join(lines).strip()
return ParsedDirectives(
prompt=prompt, engine=engine, project=project, branch=branch
)
def _parse_ctx_line(text: str | None, *, projects: ProjectsConfig) -> RunContext | None:
if not text:
return None
ctx: RunContext | None = None
for line in text.splitlines():
stripped = line.strip()
if stripped.startswith("`") and stripped.endswith("`") and len(stripped) > 1:
stripped = stripped[1:-1].strip()
elif stripped.startswith("`"):
stripped = stripped[1:].strip()
elif stripped.endswith("`"):
stripped = stripped[:-1].strip()
if not stripped.lower().startswith("ctx:"):
continue
content = stripped.split(":", 1)[1].strip()
if not content:
continue
tokens = content.split()
if not tokens:
continue
project = tokens[0]
branch = None
if len(tokens) >= 2:
if tokens[1] == "@" and len(tokens) >= 3:
branch = tokens[2]
elif tokens[1].startswith("@"):
branch = tokens[1][1:]
project_key = project.lower()
if project_key not in projects.projects:
raise DirectiveError(f"unknown project {project!r} in ctx line")
ctx = RunContext(project=project_key, branch=branch)
return ctx
def _format_context_line(
context: RunContext | None, *, projects: ProjectsConfig
) -> str | None:
if context is None or context.project is None:
return None
project_cfg = projects.projects.get(context.project)
alias = project_cfg.alias if project_cfg is not None else context.project
if context.branch:
return f"`ctx: {alias} @ {context.branch}`"
return f"`ctx: {alias}`"
def _resolve_message(
*,
text: str,
reply_text: str | None,
router: AutoRouter,
projects: ProjectsConfig,
) -> ResolvedMessage:
directives = _parse_directives(
text,
engine_ids=router.engine_ids,
projects=projects,
)
reply_ctx = _parse_ctx_line(reply_text, projects=projects)
resume_token = router.resolve_resume(directives.prompt, reply_text)
if resume_token is not None:
return ResolvedMessage(
prompt=directives.prompt,
resume_token=resume_token,
engine_override=None,
context=reply_ctx,
)
if reply_ctx is not None:
engine_override = None
if reply_ctx.project is not None:
project = projects.projects.get(reply_ctx.project)
if project is not None and project.default_engine is not None:
engine_override = project.default_engine
return ResolvedMessage(
prompt=directives.prompt,
resume_token=None,
engine_override=engine_override,
context=reply_ctx,
)
project_key = directives.project
if project_key is None and projects.default_project is not None:
project_key = projects.default_project
context = None
if project_key is not None or directives.branch is not None:
context = RunContext(project=project_key, branch=directives.branch)
engine_override = directives.engine
if engine_override is None and project_key is not None:
project = projects.projects.get(project_key)
if project is not None and project.default_engine is not None:
engine_override = project.default_engine
return ResolvedMessage(
prompt=directives.prompt,
resume_token=None,
engine_override=engine_override,
context=context,
)
def _build_bot_commands(router: AutoRouter) -> list[dict[str, str]]: def _build_bot_commands(router: AutoRouter) -> list[dict[str, str]]:
commands: list[dict[str, str]] = [] commands: list[dict[str, str]] = []
seen: set[str] = set() seen: set[str] = set()
@@ -232,6 +442,7 @@ class TelegramBridgeConfig:
chat_id: int chat_id: int
startup_msg: str startup_msg: str
exec_cfg: ExecBridgeConfig exec_cfg: ExecBridgeConfig
projects: ProjectsConfig = field(default_factory=empty_projects_config)
async def _send_plain( async def _send_plain(
@@ -281,41 +492,35 @@ async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int |
drained += len(updates) drained += len(updates)
async def poll_updates(cfg: TelegramBridgeConfig) -> AsyncIterator[dict[str, Any]]: async def poll_updates(
cfg: TelegramBridgeConfig,
) -> AsyncIterator[TransportIncomingMessage]:
offset: int | None = None offset: int | None = None
offset = await _drain_backlog(cfg, offset) offset = await _drain_backlog(cfg, offset)
await _send_startup(cfg) await _send_startup(cfg)
while True: async for msg in poll_incoming(cfg.bot, chat_id=cfg.chat_id, offset=offset):
updates = await cfg.bot.get_updates( yield msg
offset=offset, timeout_s=50, allowed_updates=["message"]
)
if updates is None:
logger.info("loop.get_updates.failed")
await anyio.sleep(2)
continue
logger.debug("loop.updates", updates=updates)
for upd in updates:
offset = upd["update_id"] + 1
msg = upd["message"]
if "text" not in msg:
continue
if msg["chat"]["id"] != cfg.chat_id:
continue
yield msg
async def _handle_cancel( async def _handle_cancel(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
msg: dict[str, Any], msg: TransportIncomingMessage,
running_tasks: RunningTasks, running_tasks: RunningTasks,
) -> None: ) -> None:
chat_id = msg["chat"]["id"] chat_id = msg.chat_id
user_msg_id = msg["message_id"] user_msg_id = msg.message_id
reply = msg.get("reply_to_message") reply_id = msg.reply_to_message_id
if not reply: if reply_id is None:
if msg.reply_to_text:
await _send_plain(
cfg.exec_cfg.transport,
chat_id=chat_id,
user_msg_id=user_msg_id,
text="nothing is currently running for that message.",
)
return
await _send_plain( await _send_plain(
cfg.exec_cfg.transport, cfg.exec_cfg.transport,
chat_id=chat_id, chat_id=chat_id,
@@ -324,17 +529,7 @@ async def _handle_cancel(
) )
return return
progress_id = reply.get("message_id") progress_ref = MessageRef(channel_id=chat_id, message_id=reply_id)
if progress_id is None:
await _send_plain(
cfg.exec_cfg.transport,
chat_id=chat_id,
user_msg_id=user_msg_id,
text="nothing is currently running for that message.",
)
return
progress_ref = MessageRef(channel_id=chat_id, message_id=progress_id)
running_task = running_tasks.get(progress_ref) running_task = running_tasks.get(progress_ref)
if running_task is None: if running_task is None:
await _send_plain( await _send_plain(
@@ -348,7 +543,7 @@ async def _handle_cancel(
logger.info( logger.info(
"cancel.requested", "cancel.requested",
chat_id=chat_id, chat_id=chat_id,
progress_message_id=progress_id, progress_message_id=reply_id,
) )
running_task.cancel_requested.set() running_task.cancel_requested.set()
@@ -378,7 +573,7 @@ async def _wait_for_resume(running_task: RunningTask) -> ResumeToken | None:
async def _send_with_resume( async def _send_with_resume(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
enqueue: Callable[[int, int, str, ResumeToken], Awaitable[None]], enqueue: Callable[[int, int, str, ResumeToken, RunContext | None], Awaitable[None]],
running_task: RunningTask, running_task: RunningTask,
chat_id: int, chat_id: int,
user_msg_id: int, user_msg_id: int,
@@ -394,7 +589,7 @@ async def _send_with_resume(
notify=False, notify=False,
) )
return return
await enqueue(chat_id, user_msg_id, text, resume) await enqueue(chat_id, user_msg_id, text, resume, running_task.context)
async def _send_runner_unavailable( async def _send_runner_unavailable(
@@ -426,7 +621,7 @@ async def _send_runner_unavailable(
async def run_main_loop( async def run_main_loop(
cfg: TelegramBridgeConfig, cfg: TelegramBridgeConfig,
poller: Callable[ poller: Callable[
[TelegramBridgeConfig], AsyncIterator[dict[str, Any]] [TelegramBridgeConfig], AsyncIterator[TransportIncomingMessage]
] = poll_updates, ] = poll_updates,
) -> None: ) -> None:
running_tasks: RunningTasks = {} running_tasks: RunningTasks = {}
@@ -440,6 +635,7 @@ async def run_main_loop(
user_msg_id: int, user_msg_id: int,
text: str, text: str,
resume_token: ResumeToken | None, resume_token: ResumeToken | None,
context: RunContext | None,
reply_ref: MessageRef | None = None, reply_ref: MessageRef | None = None,
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]]
| None = None, | None = None,
@@ -471,27 +667,52 @@ async def run_main_loop(
reason=reason, reason=reason,
) )
return return
bind_run_context( try:
chat_id=chat_id, cwd = resolve_run_cwd(context, projects=cfg.projects)
user_msg_id=user_msg_id, except WorktreeError as exc:
engine=entry.runner.engine, await _send_plain(
resume=resume_token.value if resume_token else None, cfg.exec_cfg.transport,
) chat_id=chat_id,
incoming = IncomingMessage( user_msg_id=user_msg_id,
channel_id=chat_id, text=f"error:\n{exc}",
message_id=user_msg_id, )
text=text, return
reply_to=reply_ref, run_base_token = set_run_base_dir(cwd)
) try:
await handle_message( run_fields = {
cfg.exec_cfg, "chat_id": chat_id,
runner=entry.runner, "user_msg_id": user_msg_id,
incoming=incoming, "engine": entry.runner.engine,
resume_token=resume_token, "resume": resume_token.value if resume_token else None,
strip_resume_line=cfg.router.is_resume_line, }
running_tasks=running_tasks, if context is not None:
on_thread_known=on_thread_known, run_fields["project"] = context.project
) run_fields["branch"] = context.branch
if cwd is not None:
run_fields["cwd"] = str(cwd)
bind_run_context(**run_fields)
context_line = _format_context_line(
context, projects=cfg.projects
)
incoming = RunnerIncomingMessage(
channel_id=chat_id,
message_id=user_msg_id,
text=text,
reply_to=reply_ref,
)
await handle_message(
cfg.exec_cfg,
runner=entry.runner,
incoming=incoming,
resume_token=resume_token,
context=context,
context_line=context_line,
strip_resume_line=cfg.router.is_resume_line,
running_tasks=running_tasks,
on_thread_known=on_thread_known,
)
finally:
reset_run_base_dir(run_base_token)
except Exception as exc: except Exception as exc:
logger.exception( logger.exception(
"handle.worker_failed", "handle.worker_failed",
@@ -507,33 +728,48 @@ async def run_main_loop(
job.user_msg_id, job.user_msg_id,
job.text, job.text,
job.resume_token, job.resume_token,
job.context,
None, None,
) )
scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job) scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
async for msg in poller(cfg): async for msg in poller(cfg):
text = msg["text"] text = msg.text
user_msg_id = msg["message_id"] user_msg_id = msg.message_id
chat_id = msg["chat"]["id"] chat_id = msg.chat_id
reply_ref = None reply_id = msg.reply_to_message_id
reply_msg = msg.get("reply_to_message") reply_ref = (
if reply_msg: MessageRef(channel_id=chat_id, message_id=reply_id)
reply_id = reply_msg.get("message_id") if reply_id is not None
if reply_id is not None: else None
reply_ref = MessageRef(channel_id=chat_id, message_id=reply_id) )
if _is_cancel_command(text): if _is_cancel_command(text):
tg.start_soon(_handle_cancel, cfg, msg, running_tasks) tg.start_soon(_handle_cancel, cfg, msg, running_tasks)
continue continue
text, engine_override = _strip_engine_command( reply_text = msg.reply_to_text
text, engine_ids=cfg.router.engine_ids try:
) resolved = _resolve_message(
text=text,
reply_text=reply_text,
router=cfg.router,
projects=cfg.projects,
)
except DirectiveError as exc:
await _send_plain(
cfg.exec_cfg.transport,
chat_id=chat_id,
user_msg_id=user_msg_id,
text=f"error:\n{exc}",
)
continue
r = msg.get("reply_to_message") or {} text = resolved.prompt
resume_token = cfg.router.resolve_resume(text, r.get("text")) resume_token = resolved.resume_token
reply_id = r.get("message_id") engine_override = resolved.engine_override
context = resolved.context
if resume_token is None and reply_id is not None: if resume_token is None and reply_id is not None:
running_task = running_tasks.get( running_task = running_tasks.get(
MessageRef(channel_id=chat_id, message_id=reply_id) MessageRef(channel_id=chat_id, message_id=reply_id)
@@ -557,13 +793,18 @@ async def run_main_loop(
user_msg_id, user_msg_id,
text, text,
None, None,
context,
reply_ref, reply_ref,
scheduler.note_thread_known, scheduler.note_thread_known,
engine_override, engine_override,
) )
else: else:
await scheduler.enqueue_resume( await scheduler.enqueue_resume(
chat_id, user_msg_id, text, resume_token chat_id,
user_msg_id,
text,
resume_token,
context,
) )
finally: finally:
await cfg.exec_cfg.transport.close() await cfg.exec_cfg.transport.close()
+80 -1
View File
@@ -3,13 +3,22 @@ from __future__ import annotations
import itertools import itertools
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Hashable, Protocol, TYPE_CHECKING from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Hashable,
Protocol,
TYPE_CHECKING,
)
import httpx import httpx
import anyio import anyio
from ..logging import get_logger from ..logging import get_logger
from ..transport import IncomingMessage
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -34,6 +43,76 @@ def is_group_chat_id(chat_id: int) -> bool:
return chat_id < 0 return chat_id < 0
def parse_incoming_update(
update: dict[str, Any], *, chat_id: int
) -> IncomingMessage | None:
msg = update.get("message")
if not isinstance(msg, dict):
return None
text = msg.get("text")
if not isinstance(text, str):
return None
chat = msg.get("chat")
if not isinstance(chat, dict):
return None
msg_chat_id = chat.get("id")
if not isinstance(msg_chat_id, int) or msg_chat_id != chat_id:
return None
message_id = msg.get("message_id")
if not isinstance(message_id, int):
return None
reply = msg.get("reply_to_message")
reply_to_message_id = None
reply_to_text = None
if isinstance(reply, dict):
reply_to_message_id = (
reply.get("message_id")
if isinstance(reply.get("message_id"), int)
else None
)
reply_to_text = (
reply.get("text") if isinstance(reply.get("text"), str) else None
)
sender = msg.get("from")
sender_id = (
sender.get("id")
if isinstance(sender, dict) and isinstance(sender.get("id"), int)
else None
)
return IncomingMessage(
transport="telegram",
chat_id=msg_chat_id,
message_id=message_id,
text=text,
reply_to_message_id=reply_to_message_id,
reply_to_text=reply_to_text,
sender_id=sender_id,
raw=msg,
)
async def poll_incoming(
bot: BotClient,
*,
chat_id: int,
offset: int | None = None,
) -> AsyncIterator[IncomingMessage]:
while True:
updates = await bot.get_updates(
offset=offset, timeout_s=50, allowed_updates=["message"]
)
if updates is None:
logger.info("loop.get_updates.failed")
await anyio.sleep(2)
continue
logger.debug("loop.updates", updates=updates)
for upd in updates:
offset = upd["update_id"] + 1
msg = parse_incoming_update(upd, chat_id=chat_id)
if msg is not None:
yield msg
class BotClient(Protocol): class BotClient(Protocol):
async def close(self) -> None: ... async def close(self) -> None: ...
+12
View File
@@ -7,6 +7,18 @@ ChannelId: TypeAlias = int | str
MessageId: TypeAlias = int | str MessageId: TypeAlias = int | str
@dataclass(frozen=True, slots=True)
class IncomingMessage:
transport: str
chat_id: int
message_id: int
text: str
reply_to_message_id: int | None
reply_to_text: str | None
sender_id: int | None
raw: dict[str, Any] | None = None
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
class MessageRef: class MessageRef:
channel_id: ChannelId channel_id: ChannelId
+81
View File
@@ -0,0 +1,81 @@
from __future__ import annotations
import subprocess
from collections.abc import Sequence
from pathlib import Path
def _run_git(
args: Sequence[str], *, cwd: Path
) -> subprocess.CompletedProcess[str] | None:
try:
return subprocess.run(
["git", *args],
cwd=cwd,
check=False,
text=True,
capture_output=True,
)
except FileNotFoundError:
return None
def git_run(
args: Sequence[str], *, cwd: Path
) -> subprocess.CompletedProcess[str] | None:
return _run_git(args, cwd=cwd)
def git_stdout(args: Sequence[str], *, cwd: Path) -> str | None:
result = _run_git(args, cwd=cwd)
if result is None or result.returncode != 0:
return None
output = result.stdout.strip()
return output or None
def git_ok(args: Sequence[str], *, cwd: Path) -> bool:
result = _run_git(args, cwd=cwd)
return result is not None and result.returncode == 0
def git_is_worktree(path: Path) -> bool:
return git_stdout(["rev-parse", "--is-inside-work-tree"], cwd=path) == "true"
def resolve_default_base(root: Path) -> str | None:
origin_head = git_stdout(
["symbolic-ref", "-q", "refs/remotes/origin/HEAD"],
cwd=root,
)
if origin_head:
prefix = "refs/remotes/origin/"
if origin_head.startswith(prefix):
name = origin_head[len(prefix) :].strip()
if name:
return name
current = git_stdout(["branch", "--show-current"], cwd=root)
if current:
return current
if git_ok(["show-ref", "--verify", "--quiet", "refs/heads/master"], cwd=root):
return "master"
if git_ok(["show-ref", "--verify", "--quiet", "refs/heads/main"], cwd=root):
return "main"
return None
def resolve_main_worktree_root(cwd: Path) -> Path | None:
common_dir = git_stdout(
["rev-parse", "--path-format=absolute", "--git-common-dir"],
cwd=cwd,
)
if not common_dir:
return None
if git_stdout(["rev-parse", "--is-bare-repository"], cwd=cwd) == "true":
return cwd
common_path = Path(common_dir)
if not common_path.is_absolute():
common_path = (cwd / common_path).resolve()
return common_path.parent
+22 -2
View File
@@ -1,13 +1,31 @@
from __future__ import annotations from __future__ import annotations
import os import os
from contextvars import ContextVar, Token
from pathlib import Path from pathlib import Path
_run_base_dir: ContextVar[Path | None] = ContextVar("takopi_run_base_dir", default=None)
def get_run_base_dir() -> Path | None:
return _run_base_dir.get()
def set_run_base_dir(base_dir: Path | None) -> Token[Path | None]:
return _run_base_dir.set(base_dir)
def reset_run_base_dir(token: Token[Path | None]) -> None:
_run_base_dir.reset(token)
def relativize_path(value: str, *, base_dir: Path | None = None) -> str: def relativize_path(value: str, *, base_dir: Path | None = None) -> str:
if not value: if not value:
return value return value
base = Path.cwd() if base_dir is None else base_dir base = get_run_base_dir() if base_dir is None else base_dir
if base is None:
base = Path.cwd()
base_str = str(base) base_str = str(base)
if not base_str: if not base_str:
return value return value
@@ -22,6 +40,8 @@ def relativize_path(value: str, *, base_dir: Path | None = None) -> str:
def relativize_command(value: str, *, base_dir: Path | None = None) -> str: def relativize_command(value: str, *, base_dir: Path | None = None) -> str:
base = Path.cwd() if base_dir is None else base_dir base = get_run_base_dir() if base_dir is None else base_dir
if base is None:
base = Path.cwd()
base_with_sep = f"{base}{os.sep}" base_with_sep = f"{base}{os.sep}"
return value.replace(base_with_sep, "") return value.replace(base_with_sep, "")
+119
View File
@@ -0,0 +1,119 @@
from __future__ import annotations
from pathlib import Path
from .config import ProjectConfig, ProjectsConfig
from .context import RunContext
from .utils.git import git_is_worktree, git_ok, git_run, resolve_default_base
class WorktreeError(RuntimeError):
pass
def resolve_run_cwd(
context: RunContext | None,
*,
projects: ProjectsConfig,
) -> Path | None:
if context is None or context.project is None:
return None
project = projects.projects.get(context.project)
if project is None:
raise WorktreeError(f"unknown project {context.project!r}")
if context.branch is None:
return project.path
return ensure_worktree(project, context.branch)
def ensure_worktree(project: ProjectConfig, branch: str) -> Path:
root = project.path
if not root.exists():
raise WorktreeError(f"project path not found: {root}")
branch = _sanitize_branch(branch)
worktrees_root = project.worktrees_root
worktree_path = worktrees_root / branch
_ensure_within_root(worktrees_root, worktree_path)
if worktree_path.exists():
if not git_is_worktree(worktree_path):
raise WorktreeError(f"{worktree_path} exists but is not a git worktree")
return worktree_path
worktrees_root.mkdir(parents=True, exist_ok=True)
if git_ok(
["show-ref", "--verify", "--quiet", f"refs/heads/{branch}"],
cwd=root,
):
_git_worktree_add(root, worktree_path, branch)
return worktree_path
if git_ok(
["show-ref", "--verify", "--quiet", f"refs/remotes/origin/{branch}"],
cwd=root,
):
_git_worktree_add(
root,
worktree_path,
branch,
base_ref=f"origin/{branch}",
create_branch=True,
)
return worktree_path
base = project.worktree_base or resolve_default_base(root)
if not base:
raise WorktreeError("cannot determine base branch for new worktree")
_git_worktree_add(
root,
worktree_path,
branch,
base_ref=base,
create_branch=True,
)
return worktree_path
def _git_worktree_add(
root: Path,
worktree_path: Path,
branch: str,
*,
base_ref: str | None = None,
create_branch: bool = False,
) -> None:
if create_branch:
if not base_ref:
raise WorktreeError("missing base ref for worktree creation")
args = ["worktree", "add", "-b", branch, str(worktree_path), base_ref]
else:
args = ["worktree", "add", str(worktree_path), branch]
result = git_run(args, cwd=root)
if result is None:
raise WorktreeError("git not available on PATH")
if result.returncode != 0:
message = result.stderr.strip() or result.stdout.strip()
raise WorktreeError(message or "git worktree add failed")
def _sanitize_branch(branch: str) -> str:
cleaned = branch.strip()
if not cleaned:
raise WorktreeError("branch name cannot be empty")
if cleaned.startswith("/"):
raise WorktreeError("branch name cannot start with '/'")
for part in Path(cleaned).parts:
if part == "..":
raise WorktreeError("branch name cannot contain '..'")
return cleaned
def _ensure_within_root(root: Path, path: Path) -> None:
root_resolved = root.resolve(strict=False)
path_resolved = path.resolve(strict=False)
if not path_resolved.is_relative_to(root_resolved):
raise WorktreeError("branch path escapes the worktrees directory")
+18
View File
@@ -104,3 +104,21 @@ def test_translate_command_execution_allows_null_exit_code() -> None:
assert isinstance(out[0], ActionEvent) assert isinstance(out[0], ActionEvent)
assert out[0].ok is True assert out[0].ok is True
assert out[0].action.detail["exit_code"] is None assert out[0].action.detail["exit_code"] is None
def test_translate_file_change_normalizes_changes() -> None:
evt = {
"type": "item.completed",
"item": {
"id": "item_6",
"type": "file_change",
"changes": [{"path": "README.md", "kind": "update"}],
"status": "completed",
},
}
out = _translate_event(evt)
assert len(out) == 1
assert isinstance(out[0], ActionEvent)
changes = out[0].action.detail["changes"]
assert changes == [{"path": "README.md", "kind": "update"}]
+31
View File
@@ -334,6 +334,37 @@ async def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
assert transport.send_calls[-1]["options"].replace == transport.send_calls[0]["ref"] assert transport.send_calls[-1]["options"].replace == transport.send_calls[0]["ref"]
@pytest.mark.anyio
async def test_final_message_includes_ctx_line() -> None:
transport = _FakeTransport()
clock = _FakeClock()
session_id = "123e4567-e89b-12d3-a456-426614174000"
runner = ScriptRunner(
[Return(answer="done")],
engine=CODEX_ENGINE,
resume_value=session_id,
)
cfg = ExecBridgeConfig(
transport=transport,
presenter=MarkdownPresenter(),
final_notify=True,
)
await handle_message(
cfg,
runner=runner,
incoming=IncomingMessage(channel_id=123, message_id=42, text="do it"),
resume_token=None,
context_line="`ctx: takopi @ feat/api`",
clock=clock,
)
assert transport.send_calls
final_text = transport.send_calls[-1]["message"].text
assert "`ctx: takopi @ feat/api`" in final_text
assert "codex resume" in final_text.lower()
@pytest.mark.anyio @pytest.mark.anyio
async def test_handle_message_cancelled_renders_cancelled_state() -> None: async def test_handle_message_cancelled_renders_cancelled_state() -> None:
transport = _FakeTransport() transport = _FakeTransport()
+52
View File
@@ -16,6 +16,7 @@ from takopi.markdown import (
from takopi.model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent from takopi.model import Action, ActionEvent, ResumeToken, StartedEvent, TakopiEvent
from takopi.progress import ProgressTracker from takopi.progress import ProgressTracker
from takopi.telegram.render import render_markdown from takopi.telegram.render import render_markdown
from takopi.utils.paths import reset_run_base_dir, set_run_base_dir
from tests.factories import ( from tests.factories import (
action_completed, action_completed,
action_started, action_started,
@@ -119,6 +120,40 @@ def test_file_change_renders_relative_paths_inside_cwd() -> None:
) )
def test_file_change_renders_change_objects(tmp_path: Path) -> None:
base = tmp_path / "repo"
base.mkdir()
abs_path = str(base / "changelog.md")
token = set_run_base_dir(base)
try:
out = render_event_cli(
action_completed(
"f-obj",
"file_change",
"ignored",
ok=True,
detail={"changes": [SimpleNamespace(path=abs_path, kind="update")]},
)
)
finally:
reset_run_base_dir(token)
assert any("files: update `changelog.md`" in line for line in out)
def test_file_change_title_relativizes_absolute_title(tmp_path: Path) -> None:
base = tmp_path / "repo"
base.mkdir()
abs_path = str(base / "changelog.md")
token = set_run_base_dir(base)
try:
out = render_event_cli(
action_completed("f-abs", "file_change", abs_path, ok=True)
)
finally:
reset_run_base_dir(token)
assert any("files: `changelog.md`" in line for line in out)
def test_progress_renderer_renders_progress_and_final() -> None: def test_progress_renderer_renders_progress_and_final() -> None:
tracker = ProgressTracker(engine="codex") tracker = ProgressTracker(engine="codex")
for evt in SAMPLE_EVENTS: for evt in SAMPLE_EVENTS:
@@ -145,6 +180,23 @@ def test_progress_renderer_renders_progress_and_final() -> None:
) )
def test_progress_renderer_footer_includes_ctx_before_resume() -> None:
tracker = ProgressTracker(engine="codex")
for evt in SAMPLE_EVENTS:
tracker.note_event(evt)
state = tracker.snapshot(
resume_formatter=_format_resume,
context_line="`ctx: z80 @ feat/name`",
)
formatter = MarkdownFormatter(max_actions=5)
parts = formatter.render_progress_parts(state, elapsed_s=0.0)
assert parts.footer == (
"`ctx: z80 @ feat/name`"
f"{HARD_BREAK}`codex resume 0199a213-81c0-7800-8aa1-bbab2a035a53`"
)
def test_progress_renderer_clamps_actions_and_ignores_unknown() -> None: def test_progress_renderer_clamps_actions_and_ignores_unknown() -> None:
tracker = ProgressTracker(engine="codex") tracker = ProgressTracker(engine="codex")
events = [ events = [
+56
View File
@@ -0,0 +1,56 @@
from pathlib import Path
from takopi.utils.git import resolve_default_base, resolve_main_worktree_root
def test_resolve_main_worktree_root_returns_none_when_no_git(monkeypatch) -> None:
monkeypatch.setattr("takopi.utils.git.git_stdout", lambda *args, **kwargs: None)
assert resolve_main_worktree_root(Path("/tmp")) is None
def test_resolve_main_worktree_root_prefers_common_dir_parent(monkeypatch) -> None:
base = Path("/repo")
def _fake_stdout(args, **kwargs):
if args[:2] == ["rev-parse", "--path-format=absolute"]:
return str(base / ".git")
if args == ["rev-parse", "--is-bare-repository"]:
return "false"
return None
monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout)
assert resolve_main_worktree_root(base / ".worktrees" / "feature") == base
def test_resolve_main_worktree_root_returns_cwd_for_bare_repo(monkeypatch) -> None:
cwd = Path("/bare-repo")
def _fake_stdout(args, **kwargs):
if args[:2] == ["rev-parse", "--path-format=absolute"]:
return str(cwd / "repo.git")
if args == ["rev-parse", "--is-bare-repository"]:
return "true"
return None
monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout)
assert resolve_main_worktree_root(cwd) == cwd
def test_resolve_default_base_prefers_master_over_main(monkeypatch) -> None:
def _fake_stdout(args, **kwargs):
if args[:2] == ["symbolic-ref", "-q"]:
return None
if args == ["branch", "--show-current"]:
return None
return None
def _fake_ok(args, **kwargs):
if args == ["show-ref", "--verify", "--quiet", "refs/heads/master"]:
return True
if args == ["show-ref", "--verify", "--quiet", "refs/heads/main"]:
return True
return False
monkeypatch.setattr("takopi.utils.git.git_stdout", _fake_stdout)
monkeypatch.setattr("takopi.utils.git.git_ok", _fake_ok)
assert resolve_default_base(Path("/repo")) == "master"
+17 -1
View File
@@ -2,7 +2,12 @@ from __future__ import annotations
from pathlib import Path from pathlib import Path
from takopi.utils.paths import relativize_command, relativize_path from takopi.utils.paths import (
relativize_command,
relativize_path,
reset_run_base_dir,
set_run_base_dir,
)
def test_relativize_command_rewrites_cwd_paths(tmp_path: Path) -> None: def test_relativize_command_rewrites_cwd_paths(tmp_path: Path) -> None:
@@ -33,3 +38,14 @@ def test_relativize_path_inside_base(tmp_path: Path) -> None:
base.mkdir() base.mkdir()
value = str(base / "src" / "app.py") value = str(base / "src" / "app.py")
assert relativize_path(value, base_dir=base) == "src/app.py" assert relativize_path(value, base_dir=base) == "src/app.py"
def test_relativize_path_uses_run_base_dir(tmp_path: Path) -> None:
base = tmp_path / "repo"
base.mkdir()
token = set_run_base_dir(base)
try:
value = str(base / "src" / "app.py")
assert relativize_path(value) == "src/app.py"
finally:
reset_run_base_dir(token)
+49
View File
@@ -0,0 +1,49 @@
from pathlib import Path
import pytest
from typer.testing import CliRunner
from takopi import cli
from takopi.config import ConfigError, parse_projects_config
def test_parse_projects_rejects_engine_alias() -> None:
config = {"projects": {"codex": {"path": "/tmp/repo"}}}
with pytest.raises(ConfigError, match="aliases must not match engine ids"):
parse_projects_config(
config,
config_path=Path("takopi.toml"),
engine_ids=["codex"],
reserved=("cancel",),
)
def test_parse_projects_default_project_must_exist() -> None:
config = {"default_project": "z80", "projects": {}}
with pytest.raises(ConfigError, match="default_project"):
parse_projects_config(
config,
config_path=Path("takopi.toml"),
engine_ids=["codex"],
reserved=("cancel",),
)
def test_init_writes_project(monkeypatch, tmp_path) -> None:
config_path = tmp_path / "takopi.toml"
monkeypatch.setattr("takopi.config.HOME_CONFIG_PATH", config_path)
monkeypatch.setattr(cli, "resolve_default_base", lambda _: "main")
repo_path = tmp_path / "repo"
repo_path.mkdir()
monkeypatch.chdir(repo_path)
runner = CliRunner()
result = runner.invoke(cli.app, ["init", "z80"])
assert result.exit_code == 0
saved = config_path.read_text(encoding="utf-8")
assert "[projects.z80]" in saved
assert 'worktrees_dir = ".worktrees"' in saved
assert 'default_engine = "codex"' in saved
assert 'worktree_base = "main"' in saved
+112 -42
View File
@@ -1,3 +1,5 @@
from pathlib import Path
import anyio import anyio
import pytest import pytest
@@ -7,16 +9,19 @@ from takopi.telegram.bridge import (
_build_bot_commands, _build_bot_commands,
_handle_cancel, _handle_cancel,
_is_cancel_command, _is_cancel_command,
_resolve_message,
_send_with_resume, _send_with_resume,
_strip_engine_command, _strip_engine_command,
run_main_loop, run_main_loop,
) )
from takopi.context import RunContext
from takopi.config import ProjectConfig, ProjectsConfig
from takopi.runner_bridge import ExecBridgeConfig, RunningTask from takopi.runner_bridge import ExecBridgeConfig, RunningTask
from takopi.markdown import MarkdownPresenter from takopi.markdown import MarkdownPresenter
from takopi.model import EngineId, ResumeToken from takopi.model import EngineId, ResumeToken
from takopi.router import AutoRouter, RunnerEntry from takopi.router import AutoRouter, RunnerEntry
from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait from takopi.runners.mock import Return, ScriptRunner, Sleep, Wait
from takopi.transport import MessageRef, RenderedMessage, SendOptions from takopi.transport import IncomingMessage, MessageRef, RenderedMessage, SendOptions
CODEX_ENGINE = EngineId("codex") CODEX_ENGINE = EngineId("codex")
@@ -354,7 +359,15 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None:
async def test_handle_cancel_without_reply_prompts_user() -> None: async def test_handle_cancel_without_reply_prompts_user() -> None:
transport = _FakeTransport() transport = _FakeTransport()
cfg = _make_cfg(transport) cfg = _make_cfg(transport)
msg = {"chat": {"id": 123}, "message_id": 10} msg = IncomingMessage(
transport="telegram",
chat_id=123,
message_id=10,
text="/cancel",
reply_to_message_id=None,
reply_to_text=None,
sender_id=123,
)
running_tasks: dict = {} running_tasks: dict = {}
await _handle_cancel(cfg, msg, running_tasks) await _handle_cancel(cfg, msg, running_tasks)
@@ -367,11 +380,15 @@ async def test_handle_cancel_without_reply_prompts_user() -> None:
async def test_handle_cancel_with_no_progress_message_says_nothing_running() -> None: async def test_handle_cancel_with_no_progress_message_says_nothing_running() -> None:
transport = _FakeTransport() transport = _FakeTransport()
cfg = _make_cfg(transport) cfg = _make_cfg(transport)
msg = { msg = IncomingMessage(
"chat": {"id": 123}, transport="telegram",
"message_id": 10, chat_id=123,
"reply_to_message": {"text": "no message id"}, message_id=10,
} text="/cancel",
reply_to_message_id=None,
reply_to_text="no message id",
sender_id=123,
)
running_tasks: dict = {} running_tasks: dict = {}
await _handle_cancel(cfg, msg, running_tasks) await _handle_cancel(cfg, msg, running_tasks)
@@ -385,11 +402,15 @@ async def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
transport = _FakeTransport() transport = _FakeTransport()
cfg = _make_cfg(transport) cfg = _make_cfg(transport)
progress_id = 99 progress_id = 99
msg = { msg = IncomingMessage(
"chat": {"id": 123}, transport="telegram",
"message_id": 10, chat_id=123,
"reply_to_message": {"message_id": progress_id}, message_id=10,
} text="/cancel",
reply_to_message_id=progress_id,
reply_to_text=None,
sender_id=123,
)
running_tasks: dict = {} running_tasks: dict = {}
await _handle_cancel(cfg, msg, running_tasks) await _handle_cancel(cfg, msg, running_tasks)
@@ -403,11 +424,15 @@ async def test_handle_cancel_cancels_running_task() -> None:
transport = _FakeTransport() transport = _FakeTransport()
cfg = _make_cfg(transport) cfg = _make_cfg(transport)
progress_id = 42 progress_id = 42
msg = { msg = IncomingMessage(
"chat": {"id": 123}, transport="telegram",
"message_id": 10, chat_id=123,
"reply_to_message": {"message_id": progress_id}, message_id=10,
} text="/cancel",
reply_to_message_id=progress_id,
reply_to_text=None,
sender_id=123,
)
running_task = RunningTask() running_task = RunningTask()
running_tasks = {MessageRef(channel_id=123, message_id=progress_id): running_task} running_tasks = {MessageRef(channel_id=123, message_id=progress_id): running_task}
@@ -423,11 +448,15 @@ async def test_handle_cancel_only_cancels_matching_progress_message() -> None:
cfg = _make_cfg(transport) cfg = _make_cfg(transport)
task_first = RunningTask() task_first = RunningTask()
task_second = RunningTask() task_second = RunningTask()
msg = { msg = IncomingMessage(
"chat": {"id": 123}, transport="telegram",
"message_id": 10, chat_id=123,
"reply_to_message": {"message_id": 1}, message_id=10,
} text="/cancel",
reply_to_message_id=1,
reply_to_text=None,
sender_id=123,
)
running_tasks = { running_tasks = {
MessageRef(channel_id=123, message_id=1): task_first, MessageRef(channel_id=123, message_id=1): task_first,
MessageRef(channel_id=123, message_id=2): task_second, MessageRef(channel_id=123, message_id=2): task_second,
@@ -446,16 +475,46 @@ def test_cancel_command_accepts_extra_text() -> None:
assert _is_cancel_command("/cancelled") is False assert _is_cancel_command("/cancelled") is False
def test_resolve_message_accepts_backticked_ctx_line() -> None:
router = _make_router(ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE))
projects = ProjectsConfig(
projects={
"takopi": ProjectConfig(
alias="takopi",
path=Path("."),
worktrees_dir=Path(".worktrees"),
)
},
default_project=None,
)
resolved = _resolve_message(
text="do it",
reply_text="`ctx: takopi @ feat/api`",
router=router,
projects=projects,
)
assert resolved.prompt == "do it"
assert resolved.resume_token is None
assert resolved.engine_override is None
assert resolved.context == RunContext(project="takopi", branch="feat/api")
@pytest.mark.anyio @pytest.mark.anyio
async def test_send_with_resume_waits_for_token() -> None: async def test_send_with_resume_waits_for_token() -> None:
transport = _FakeTransport() transport = _FakeTransport()
cfg = _make_cfg(transport) cfg = _make_cfg(transport)
sent: list[tuple[int, int, str, ResumeToken | None]] = [] sent: list[tuple[int, int, str, ResumeToken, RunContext | None]] = []
async def enqueue( async def enqueue(
chat_id: int, user_msg_id: int, text: str, resume: ResumeToken chat_id: int,
user_msg_id: int,
text: str,
resume: ResumeToken,
context: RunContext | None,
) -> None: ) -> None:
sent.append((chat_id, user_msg_id, text, resume)) sent.append((chat_id, user_msg_id, text, resume, context))
running_task = RunningTask() running_task = RunningTask()
@@ -476,7 +535,7 @@ async def test_send_with_resume_waits_for_token() -> None:
) )
assert sent == [ assert sent == [
(123, 10, "hello", ResumeToken(engine=CODEX_ENGINE, value="abc123")) (123, 10, "hello", ResumeToken(engine=CODEX_ENGINE, value="abc123"), None)
] ]
assert transport.send_calls == [] assert transport.send_calls == []
@@ -485,12 +544,16 @@ async def test_send_with_resume_waits_for_token() -> None:
async def test_send_with_resume_reports_when_missing() -> None: async def test_send_with_resume_reports_when_missing() -> None:
transport = _FakeTransport() transport = _FakeTransport()
cfg = _make_cfg(transport) cfg = _make_cfg(transport)
sent: list[tuple[int, int, str, ResumeToken | None]] = [] sent: list[tuple[int, int, str, ResumeToken, RunContext | None]] = []
async def enqueue( async def enqueue(
chat_id: int, user_msg_id: int, text: str, resume: ResumeToken chat_id: int,
user_msg_id: int,
text: str,
resume: ResumeToken,
context: RunContext | None,
) -> None: ) -> None:
sent.append((chat_id, user_msg_id, text, resume)) sent.append((chat_id, user_msg_id, text, resume, context))
running_task = RunningTask() running_task = RunningTask()
running_task.done.set() running_task.done.set()
@@ -538,22 +601,29 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None:
) )
async def poller(_cfg: TelegramBridgeConfig): async def poller(_cfg: TelegramBridgeConfig):
yield { yield IncomingMessage(
"message_id": 1, transport="telegram",
"text": "first", chat_id=123,
"chat": {"id": 123}, message_id=1,
"from": {"id": 123}, text="first",
} reply_to_message_id=None,
reply_to_text=None,
sender_id=123,
)
await progress_ready.wait() await progress_ready.wait()
assert transport.progress_ref is not None assert transport.progress_ref is not None
assert isinstance(transport.progress_ref.message_id, int)
reply_id = transport.progress_ref.message_id
reply_ready.set() reply_ready.set()
yield { yield IncomingMessage(
"message_id": 2, transport="telegram",
"text": "followup", chat_id=123,
"chat": {"id": 123}, message_id=2,
"from": {"id": 123}, text="followup",
"reply_to_message": {"message_id": transport.progress_ref.message_id}, reply_to_message_id=reply_id,
} reply_to_text=None,
sender_id=123,
)
await stop_polling.wait() await stop_polling.wait()
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
+47
View File
@@ -0,0 +1,47 @@
from takopi.telegram import parse_incoming_update
def test_parse_incoming_update_maps_fields() -> None:
update = {
"update_id": 1,
"message": {
"message_id": 10,
"text": "hello",
"chat": {"id": 123},
"from": {"id": 99},
"reply_to_message": {"message_id": 5, "text": "prev"},
},
}
msg = parse_incoming_update(update, chat_id=123)
assert msg is not None
assert msg.transport == "telegram"
assert msg.chat_id == 123
assert msg.message_id == 10
assert msg.text == "hello"
assert msg.reply_to_message_id == 5
assert msg.reply_to_text == "prev"
assert msg.sender_id == 99
assert msg.raw == update["message"]
def test_parse_incoming_update_filters_non_matching_chat() -> None:
update = {
"update_id": 1,
"message": {
"message_id": 10,
"text": "hello",
"chat": {"id": 123},
},
}
assert parse_incoming_update(update, chat_id=999) is None
def test_parse_incoming_update_filters_non_text() -> None:
update = {
"update_id": 1,
"message": {"message_id": 10, "chat": {"id": 123}},
}
assert parse_incoming_update(update, chat_id=123) is None
+56
View File
@@ -0,0 +1,56 @@
from pathlib import Path
from types import SimpleNamespace
import pytest
from takopi.config import ProjectConfig, ProjectsConfig
from takopi.context import RunContext
from takopi.worktrees import WorktreeError, ensure_worktree, resolve_run_cwd
def _projects_config(path: Path) -> ProjectsConfig:
return ProjectsConfig(
projects={
"z80": ProjectConfig(
alias="z80",
path=path,
worktrees_dir=Path(".worktrees"),
)
},
default_project=None,
)
def test_resolve_run_cwd_uses_project_root(tmp_path: Path) -> None:
projects = _projects_config(tmp_path)
ctx = RunContext(project="z80")
assert resolve_run_cwd(ctx, projects=projects) == tmp_path
def test_resolve_run_cwd_rejects_invalid_branch(tmp_path: Path) -> None:
projects = _projects_config(tmp_path)
ctx = RunContext(project="z80", branch="../oops")
with pytest.raises(WorktreeError, match="branch name"):
resolve_run_cwd(ctx, projects=projects)
def test_ensure_worktree_creates_from_base(monkeypatch, tmp_path: Path) -> None:
project = ProjectConfig(
alias="z80",
path=tmp_path,
worktrees_dir=Path(".worktrees"),
)
calls: list[list[str]] = []
monkeypatch.setattr("takopi.worktrees.git_ok", lambda *args, **kwargs: False)
monkeypatch.setattr("takopi.worktrees.resolve_default_base", lambda *_: "main")
def _fake_git_run(args, cwd):
calls.append(list(args))
return SimpleNamespace(returncode=0, stdout="", stderr="")
monkeypatch.setattr("takopi.worktrees.git_run", _fake_git_run)
worktree_path = ensure_worktree(project, "feat/name")
assert worktree_path == tmp_path / ".worktrees" / "feat" / "name"
assert calls == [["worktree", "add", "-b", "feat/name", str(worktree_path), "main"]]