From c06a0abc171f17f27abb9b7fa1558ac763642de9 Mon Sep 17 00:00:00 2001 From: banteg <4562643+banteg@users.noreply.github.com> Date: Sat, 10 Jan 2026 22:51:31 +0400 Subject: [PATCH] feat: telegram forum topics support (#80) --- docs/transports/telegram.md | 35 ++ docs/user-guide.md | 410 ++++++++++++++++++ readme.md | 7 + src/takopi/cli.py | 64 +++ src/takopi/config.py | 2 +- src/takopi/logging.py | 6 +- src/takopi/runner.py | 8 +- src/takopi/runner_bridge.py | 8 +- src/takopi/scheduler.py | 3 + src/takopi/settings.py | 58 +-- src/takopi/telegram/backend.py | 27 +- src/takopi/telegram/bridge.py | 627 ++++++++++++++++++++++++++- src/takopi/telegram/client.py | 113 +++++ src/takopi/telegram/onboarding.py | 35 ++ src/takopi/telegram/topic_state.py | 307 +++++++++++++ src/takopi/telegram/types.py | 4 + src/takopi/transport.py | 1 + src/takopi/transport_runtime.py | 98 +++-- tests/test_cli_chat_id.py | 70 +++ tests/test_onboarding_interactive.py | 48 ++ tests/test_runner_utils.py | 416 ++++++++++++++++++ tests/test_telegram_bridge.py | 271 ++++++++++-- tests/test_telegram_incoming.py | 26 +- tests/test_telegram_queue.py | 48 +- tests/test_telegram_topic_state.py | 49 +++ tests/test_transport_runtime.py | 90 ++++ 26 files changed, 2718 insertions(+), 113 deletions(-) create mode 100644 docs/user-guide.md create mode 100644 src/takopi/telegram/topic_state.py create mode 100644 tests/test_cli_chat_id.py create mode 100644 tests/test_runner_utils.py create mode 100644 tests/test_telegram_topic_state.py diff --git a/docs/transports/telegram.md b/docs/transports/telegram.md index 95537e6..2048cbf 100644 --- a/docs/transports/telegram.md +++ b/docs/transports/telegram.md @@ -34,6 +34,41 @@ Set `OPENAI_API_KEY` in the environment. If transcription is enabled but the API is missing or the audio download fails, takopi replies with a short error and skips the run. +## Forum topics (optional) + +Takopi can bind Telegram forum topics to a project/branch and persist resume tokens +per topic, so replies keep the right context even after restarts. + +Configuration (under `[transports.telegram]`): + +```toml +[transports.telegram.topics] +enabled = true +mode = "multi_project_chat" # or "per_project_chat" +``` + +Requirements: + +- `multi_project_chat`: `chat_id` must be a forum-enabled supergroup (topics enabled). +- `per_project_chat`: each `projects..chat_id` must point to a forum-enabled + supergroup for that project. +- The bot needs the **Manage Topics** permission in the relevant chat(s). + +Commands: + +- `multi_project_chat`: `/topic @branch` creates a topic in the main chat + and binds it. +- `per_project_chat`: `/topic @branch` creates a topic in the project chat and binds it. +- `/ctx` inside a topic shows the bound context and stored session engines. + `/ctx set ...` and `/ctx clear` update the binding. +- `/new` inside a topic clears stored resume tokens for that topic. + +State is stored in `telegram_topics_state.json` alongside the config file. +Delete it to reset all topic bindings and stored sessions. + +Note: `multi_project_chat` does not assume a default project; topics must be bound +before running without directives. + ## Outbox model - Single worker processes one op at a time. diff --git a/docs/user-guide.md b/docs/user-guide.md new file mode 100644 index 0000000..e66c23d --- /dev/null +++ b/docs/user-guide.md @@ -0,0 +1,410 @@ +# Takopi User Guide + +Takopi is a command-line tool that lets you control coding agents—like Codex, Claude, and others—through Telegram. Send a message, and takopi runs the agent in your repo, streaming progress back to your chat. It supports multi-repo workflows, git worktrees, and per-project routing. + +This guide starts simple and layers on features as you go. Jump to any section or read straight through. + +## Prerequisites + +Before you begin, make sure you have: + +- A Telegram account +- Python 3.14+ and `uv` installed +- At least one supported agent CLI installed and on your `PATH` (codex, claude, opencode, pi) +- Basic familiarity with git (especially if you plan to use worktrees) + +## Key concepts + +A few terms you'll see throughout: + +| Term | Meaning | +|------|---------| +| **Engine** | A coding agent backend (Codex, Claude, opencode, pi) | +| **Project** | A registered git repository with an alias | +| **Worktree** | A git feature that lets you check out multiple branches simultaneously in separate directories | +| **Topic** | A Telegram forum thread bound to a specific project/branch context | +| **Resume token** | State that allows an engine to continue from where it left off | + +--- + +## 1. Installation and setup + +Install takopi with: + +```sh +uv tool install -U takopi +``` + +Run it once to start the onboarding wizard: + +```sh +takopi +``` + +The wizard walks you through: + +1. Creating a Telegram bot token via [@BotFather](https://t.me/BotFather) +2. Capturing your `chat_id` (the wizard listens for a message from you) +3. Choosing a default engine + +To re-run onboarding later, use `takopi --onboard`. + +Your configuration is stored at `~/.takopi/takopi.toml`. + +### Minimal configuration + +After onboarding, your config looks something like this: + +```toml +default_engine = "codex" +transport = "telegram" + +[transports.telegram] +bot_token = "123456789:ABCdefGHIjklMNOpqrsTUVwxyz" +chat_id = 123456789 +``` + +--- + +## 2. Your first handoff + +The simplest workflow: + +1. `cd` into any git repository +2. Run `takopi` +3. Send a message to your bot + +Takopi streams progress in the chat and sends a final response when the agent finishes. + +### Basic controls + +- **Reply** to a bot message with more instructions to continue the conversation +- **Cancel** a run by clicking the cancel button or replying to the progress message with `/cancel` + +--- + +## 3. Switching engines + +Prefix your message with an engine directive to override the default: + +``` +/codex hard reset the timeline +/claude shrink and store artifacts forever +/opencode hide their paper until they reply +/pi render a diorama of this timeline +``` + +Directives are only parsed at the start of the first non-empty line. + +### Setting up engines + +Takopi shells out to the agent CLIs. Install them and make sure they're on your `PATH` +(codex, claude, opencode, pi). Authentication is handled by each CLI (login, +config files, or environment variables). + +--- + +## 4. Projects + +For repos you work with often, register them as projects: + +```sh +cd ~/dev/happy-gadgets +takopi init happy-gadgets +``` + +This adds a project entry to your config (for example): + +```toml +[projects.happy-gadgets] +path = "~/dev/happy-gadgets" +``` + +Now you can target it from anywhere using the `/project` directive: + +``` +/happy-gadgets pinky-link two threads +``` + +If you expect to add or edit projects while takopi is running, enable config +watching so changes are picked up automatically: + +```toml +watch_config = true +``` + +### Project-specific settings + +Projects can override global defaults: + +```toml +[projects.happy-gadgets] +path = "~/dev/happy-gadgets" +default_engine = "claude" +worktrees_dir = ".worktrees" +worktree_base = "master" +``` + +### Setting a default project + +If you mostly work in one repo: + +```toml +default_project = "happy-gadgets" +``` + +--- + +## 5. Worktrees + +Worktrees let you work on multiple branches without switching back and forth. Use `@branch` to run a task in a dedicated worktree: + +``` +/happy-gadgets @feat/memory-box freeze artifacts forever +``` + +Takopi creates (or reuses) a worktree at: + +``` +/ +``` + +`worktrees_root` is `/` unless `worktrees_dir` is an +absolute path. If the branch matches the repo's current branch, Takopi runs in the +main repo instead of creating a new worktree. + +### Worktree configuration + +```toml +[projects.happy-gadgets] +path = "~/dev/happy-gadgets" +worktrees_dir = ".worktrees" # relative to project path +worktree_base = "master" # base branch for new worktrees +``` + +To avoid `.worktrees/` showing up as untracked, add it to your global gitignore: + +```sh +git config --global core.excludesfile ~/.config/git/ignore +echo ".worktrees/" >> ~/.config/git/ignore +``` + +### Context persistence + +Takopi adds a `ctx:` footer to messages with project and branch info. When you reply, this context carries forward—no need to repeat `/project @branch` each time. + +--- + +## 6. Per-project chat routing + +Give each project its own Telegram chat: + +```sh +takopi chat-id --project happy-gadgets +``` + +Send any message in the target chat. Takopi captures the `chat_id` and updates your config: + +```toml +[projects.happy-gadgets] +path = "~/dev/happy-gadgets" +chat_id = -1001234567890 +``` + +Messages from that chat automatically route to the project. + +### Rules for chat IDs + +- Each `projects.*.chat_id` must be unique +- Project chat IDs must not match `transports.telegram.chat_id` +- Telegram uses positive IDs for private chats and negative IDs for groups/supergroups + +### Capture a chat ID without saving + +To see a chat ID without writing to config: + +```sh +takopi chat-id +``` + +--- + +## 7. Topics + +Topics bind Telegram forum threads to specific project/branch contexts. They also preserve resume tokens, so agents can pick up where they left off. + +### Enabling topics + +```toml +[transports.telegram.topics] +enabled = true +mode = "multi_project_chat" # or "per_project_chat" +``` + +Your bot needs **Manage Topics** permission in the group. + +### Topic modes explained + +**`multi_project_chat`** — One forum-enabled supergroup for everything. Create topics per project/branch combination. + +``` +┌────────────────────────────┐ +│ takopi projects │ +├────────────────────────────┤ +│ takopi @master │ +│ takopi @feat/topics │ +│ happy-gadgets @master │ +│ happy-gadgets @feat/camera │ +└────────────────────────────┘ +``` + +**`per_project_chat`** — Each project has its own forum-enabled supergroup. Topics still include the project name for consistency, but the project is inferred from the chat. Regular messages in that chat also infer the project, so `/project` is usually optional. + +``` +┌────────────────────────────────┐ ┌───────────────────────────────────┐ +│ takopi │ │ happy-gadgets │ +├────────────────────────────────┤ ├───────────────────────────────────┤ +│ takopi @master │ │ happy-gadgets @master │ +│ takopi @feat/topics │ │ happy-gadgets @feat/happy-camera │ +│ takopi @feat/voice │ │ happy-gadgets @feat/memory-box │ +└────────────────────────────────┘ └───────────────────────────────────┘ +``` + +### Topic commands + +Run these inside a topic thread: + +| Command | Description | +|---------|-------------| +| `/topic @branch` | Create a new topic bound to context | +| `/ctx` | Show the current binding | +| `/ctx set @branch` | Update the binding | +| `/ctx clear` | Remove the binding | +| `/new` | Clear resume tokens for this topic | + +In `per_project_chat` mode, omit the project: `/topic @branch` or `/ctx set @branch`. + +### Configuration examples + +**Multi-project chat:** + +```toml +[transports.telegram] +chat_id = -1001234567890 + +[transports.telegram.topics] +enabled = true +mode = "multi_project_chat" +``` + +**Per-project chat:** + +```toml +[transports.telegram] +chat_id = 123456789 # main chat (private, for non-project messages) + +[transports.telegram.topics] +enabled = true +mode = "per_project_chat" + +[projects.takopi] +path = "~/dev/takopi" +chat_id = -1001111111111 # forum-enabled group +``` + +Topic state is stored in `telegram_topics_state.json` next to your config file. + +--- + +## 8. Voice notes + +Dictate tasks instead of typing: + +```toml +[transports.telegram] +voice_transcription = true +``` + +Set `OPENAI_API_KEY` in your environment (uses OpenAI's transcription API with the +`gpt-4o-mini-transcribe` model). + +When you send a voice note, takopi transcribes it and runs the result as a normal text message. If transcription fails, you'll get an error message and the run is skipped. + +--- + +## 9. Configuration reference + +Full example with all options: + +```toml +# Global defaults +default_engine = "codex" +default_project = "takopi" +transport = "telegram" +watch_config = true # hot-reload on config changes (except transport) + +[transports.telegram] +bot_token = "123456789:ABCdefGHIjklMNOpqrsTUVwxyz" +chat_id = 123456789 +voice_transcription = true + +[transports.telegram.topics] +enabled = true +mode = "multi_project_chat" + +# Project definitions +[projects.takopi] +path = "~/dev/takopi" +default_engine = "codex" +worktrees_dir = ".worktrees" +worktree_base = "master" +# chat_id = -1001234567890 # optional: dedicated chat + +[projects.happy-planet] +path = "~/dev/happy-planet" +default_engine = "claude" +worktrees_dir = "~/.takopi/worktrees/happy-planet" +worktree_base = "develop" +``` + +--- + +## 10. Command cheatsheet + +### Message directives + +| Directive | Example | Description | +|-----------|---------|-------------| +| `/engine` | `/codex make threads resolve their differences` | Use a specific engine | +| `/project` | `/happy-gadgets add escape-pod` | Target a project | +| `@branch` | `@feat/happy-camera rewind to checkpoint` | Run in a worktree | +| Combined | `/happy-gadgets @feat/flower-pin observe unseen` | Project + branch | + +### In-chat commands + +| Command | Description | +|---------|-------------| +| `/cancel` | Reply to the progress message to stop the current run | +| `/topic @branch` | Create/bind a topic | +| `/ctx` | Show current context | +| `/ctx set @branch` | Update context binding | +| `/ctx clear` | Remove context binding | +| `/new` | Clear resume tokens | + +### CLI commands + +| Command | Description | +|---------|-------------| +| `takopi` | Start the bot (runs onboarding if first time) | +| `takopi --onboard` | Re-run onboarding wizard | +| `takopi init ` | Register current directory as a project | +| `takopi chat-id` | Capture a chat ID | +| `takopi chat-id --project ` | Set a project's chat ID | +| `takopi --debug` | Write debug logs to `debug.log` | + +--- + +## 11. Troubleshooting + +If something isn't working, rerun with `takopi --debug` and check `debug.log` +for errors. Include it when reporting issues. diff --git a/readme.md b/readme.md index 5bae650..33a5018 100644 --- a/readme.md +++ b/readme.md @@ -20,6 +20,8 @@ parallel runs across threads, per thread queue support. optional voice note transcription for telegram (routes transcript like typed text). +telegram forum topics: bind a topic to a project/branch and keep per-topic session resumes. + per-project chat routing: assign different telegram chats to different projects. ## requirements @@ -67,6 +69,11 @@ bot_token = "123456789:ABCdefGHIjklMNOpqrsTUVwxyz" chat_id = 123456789 voice_transcription = true +[transports.telegram.topics] +enabled = true +mode = "multi_project_chat" # or "per_project_chat" +# per_project_chat uses projects..chat_id to infer the project + [codex] # optional: profile from ~/.codex/config.toml profile = "takopi" diff --git a/src/takopi/cli.py b/src/takopi/cli.py index 1131398..bedd617 100644 --- a/src/takopi/cli.py +++ b/src/takopi/cli.py @@ -38,6 +38,7 @@ from .plugins import ( from .transports import SetupResult, get_transport from .transport_runtime import TransportRuntime from .utils.git import resolve_default_base, resolve_main_worktree_root +from .telegram import onboarding logger = get_logger(__name__) @@ -271,6 +272,9 @@ def _run_auto_router( debug: bool, onboard: bool, ) -> None: + if debug: + os.environ.setdefault("TAKOPI_LOG_FILE", "debug.log") + os.environ.setdefault("TAKOPI_LOG_FORMAT", "json") setup_logging(debug=debug) lock_handle: LockHandle | None = None try: @@ -514,6 +518,65 @@ def init( typer.echo(f"saved project {alias!r} to {_config_path_display(config_path)}") +def chat_id( + token: str | None = typer.Option( + None, + "--token", + help="Telegram bot token (defaults to config if available).", + ), + project: str | None = typer.Option( + None, + "--project", + help="Project alias to print a chat_id snippet for.", + ), +) -> None: + """Capture a Telegram chat id and exit.""" + setup_logging(debug=False, cache_logger_on_first_use=False) + if token is None: + settings, _ = _load_settings_optional() + if settings is not None: + tg = settings.transports.telegram + if tg.bot_token is not None: + token = tg.bot_token.get_secret_value().strip() or None + chat = onboarding.capture_chat_id(token=token) + if chat is None: + raise typer.Exit(code=1) + if project: + project = project.strip() + if not project: + raise ConfigError("Invalid `--project`; expected a non-empty string.") + + config, config_path = load_or_init_config() + if config_path.exists(): + applied = migrate_config(config, config_path=config_path) + if applied: + write_config(config, config_path) + + projects = _ensure_projects_table(config, config_path) + entry = projects.get(project) + if entry is None: + lowered = project.lower() + for key, value in projects.items(): + if isinstance(key, str) and key.lower() == lowered: + entry = value + project = key + break + if entry is None: + raise ConfigError( + f"Unknown project {project!r}; run `takopi init {project}` first." + ) + if not isinstance(entry, dict): + raise ConfigError( + f"Invalid `projects.{project}` in {config_path}; expected a table." + ) + entry["chat_id"] = chat.chat_id + write_config(config, config_path) + typer.echo(f"updated projects.{project}.chat_id = {chat.chat_id}") + return + + typer.echo(f"chat_id = {chat.chat_id}") + + def _print_entrypoints( label: str, entrypoints: list[EntryPoint], *, allowlist: set[str] | None ) -> None: @@ -704,6 +767,7 @@ def create_app() -> typer.Typer: help="Run takopi with auto-router (subcommands override the default engine).", ) app.command(name="init")(init) + app.command(name="chat-id")(chat_id) app.command(name="plugins")(plugins_cmd) app.callback()(app_main) for engine_id in _engine_ids_for_cli(): diff --git a/src/takopi/config.py b/src/takopi/config.py index b882c81..cc9cc15 100644 --- a/src/takopi/config.py +++ b/src/takopi/config.py @@ -158,7 +158,7 @@ def parse_projects_config( raise ConfigError( f"Invalid `worktrees_dir` for project {alias!r} in {config_path}." ) - worktrees_dir = Path(worktrees_dir_raw.strip()) + worktrees_dir = Path(worktrees_dir_raw.strip()).expanduser() default_engine_raw = raw_entry.get("default_engine") default_engine = None diff --git a/src/takopi/logging.py b/src/takopi/logging.py index e72308e..ec2fc28 100644 --- a/src/takopi/logging.py +++ b/src/takopi/logging.py @@ -206,7 +206,9 @@ class SafeWriter(io.TextIOBase): pass -def setup_logging(*, debug: bool = False) -> None: +def setup_logging( + *, debug: bool = False, cache_logger_on_first_use: bool = True +) -> None: global _MIN_LEVEL, _PIPELINE_LEVEL_NAME global _log_file_handle @@ -261,7 +263,7 @@ def setup_logging(*, debug: bool = False) -> None: structlog.configure( processors=processors, logger_factory=structlog.PrintLoggerFactory(file=safe_stream), - cache_logger_on_first_use=True, + cache_logger_on_first_use=cache_logger_on_first_use, ) diff --git a/src/takopi/runner.py b/src/takopi/runner.py index 265b067..ef4f15a 100644 --- a/src/takopi/runner.py +++ b/src/takopi/runner.py @@ -54,9 +54,9 @@ class ResumeTokenMixin: class SessionLockMixin: engine: EngineId - session_locks: WeakValueDictionary[str, anyio.Lock] | None = None + session_locks: WeakValueDictionary[str, anyio.Semaphore] | None = None - def lock_for(self, token: ResumeToken) -> anyio.Lock: + def lock_for(self, token: ResumeToken) -> anyio.Semaphore: locks = self.session_locks if locks is None: locks = WeakValueDictionary() @@ -64,7 +64,7 @@ class SessionLockMixin: key = f"{token.engine}:{token.value}" lock = locks.get(key) if lock is None: - lock = anyio.Lock() + lock = anyio.Semaphore(1) locks[key] = lock return lock @@ -105,7 +105,7 @@ class BaseRunner(SessionLockMixin): yield evt return - lock: anyio.Lock | None = None + lock: anyio.Semaphore | None = None acquired = False try: async for evt in self.run_impl(prompt, None): diff --git a/src/takopi/runner_bridge.py b/src/takopi/runner_bridge.py index e065eec..27f4cd8 100644 --- a/src/takopi/runner_bridge.py +++ b/src/takopi/runner_bridge.py @@ -302,9 +302,11 @@ async def run_runner_with_cancel( bind_run_context(resume=evt.resume.value) if running_task is not None and running_task.resume is None: running_task.resume = evt.resume - running_task.resume_ready.set() - if on_thread_known is not None: - await on_thread_known(evt.resume, running_task.done) + try: + if on_thread_known is not None: + await on_thread_known(evt.resume, running_task.done) + finally: + running_task.resume_ready.set() elif isinstance(evt, CompletedEvent): outcome.resume = evt.resume or outcome.resume outcome.completed = evt diff --git a/src/takopi/scheduler.py b/src/takopi/scheduler.py index ead25d7..13b495b 100644 --- a/src/takopi/scheduler.py +++ b/src/takopi/scheduler.py @@ -17,6 +17,7 @@ class ThreadJob: text: str resume_token: ResumeToken context: RunContext | None = None + thread_id: int | None = None RunJob = Callable[[ThreadJob], Awaitable[None]] @@ -69,6 +70,7 @@ class ThreadScheduler: text: str, resume_token: ResumeToken, context: RunContext | None = None, + thread_id: int | None = None, ) -> None: await self.enqueue( ThreadJob( @@ -77,6 +79,7 @@ class ThreadScheduler: text=text, resume_token=resume_token, context=context, + thread_id=thread_id, ) ) diff --git a/src/takopi/settings.py b/src/takopi/settings.py index bfb2c17..6f4042b 100644 --- a/src/takopi/settings.py +++ b/src/takopi/settings.py @@ -16,16 +16,43 @@ from pydantic import ( from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings.sources import TomlConfigSettingsSource -from .config import ConfigError, ProjectConfig, ProjectsConfig, HOME_CONFIG_PATH +from .config import ( + ConfigError, + HOME_CONFIG_PATH, + ProjectConfig, + ProjectsConfig, + _normalize_engine_id, + _normalize_project_path, +) from .config_migrations import migrate_config_file +class TelegramTopicsSettings(BaseModel): + model_config = ConfigDict(extra="forbid") + + enabled: bool = False + mode: str = "multi_project_chat" + + @field_validator("mode", mode="before") + @classmethod + def _validate_mode(cls, value: Any) -> str: + if not isinstance(value, str): + raise ValueError("topics.mode must be a string") + cleaned = value.strip() + if cleaned not in {"per_project_chat", "multi_project_chat"}: + raise ValueError( + "topics.mode must be 'per_project_chat' or 'multi_project_chat'" + ) + return cleaned + + class TelegramTransportSettings(BaseModel): model_config = ConfigDict(extra="forbid") bot_token: SecretStr | None = None chat_id: int | None = None voice_transcription: bool = False + topics: TelegramTopicsSettings = Field(default_factory=TelegramTopicsSettings) @field_validator("bot_token", mode="before") @classmethod @@ -241,7 +268,7 @@ class TakopiSettings(BaseSettings): raise ConfigError( f"Invalid `worktrees_dir` for project {alias!r} in {config_path}." ) - worktrees_dir = Path(worktrees_dir_raw.strip()) + worktrees_dir = Path(worktrees_dir_raw.strip()).expanduser() default_engine_raw = entry.default_engine default_engine = None @@ -401,30 +428,3 @@ def _load_settings_from_path(cfg_path: Path) -> TakopiSettings: raise ConfigError(f"Invalid config in {cfg_path}: {exc}") from exc except Exception as exc: # pragma: no cover - safety net raise ConfigError(f"Failed to load config {cfg_path}: {exc}") from exc - - -def _normalize_engine_id( - value: str, - *, - engine_ids: Iterable[str], - config_path: Path, - label: str, -) -> str: - engine_map = {engine.lower(): engine for engine in engine_ids} - cleaned = value.strip() - if not cleaned: - raise ConfigError(f"Invalid `{label}` in {config_path}; expected a string.") - engine = engine_map.get(cleaned.lower()) - if engine is None: - available = ", ".join(sorted(engine_map.values())) - raise ConfigError( - f"Unknown `{label}` {cleaned!r} in {config_path}. Available: {available}." - ) - return engine - - -def _normalize_project_path(value: str, *, config_path: Path) -> Path: - path = Path(value).expanduser() - if not path.is_absolute(): - path = config_path.parent / path - return path diff --git a/src/takopi/telegram/backend.py b/src/takopi/telegram/backend.py index 8688c45..c720ce1 100644 --- a/src/takopi/telegram/backend.py +++ b/src/takopi/telegram/backend.py @@ -9,13 +9,16 @@ from ..backends import EngineBackend from ..runner_bridge import ExecBridgeConfig from ..config import ConfigError from ..logging import get_logger -from ..settings import load_settings, require_telegram_config +from pydantic import ValidationError + +from ..settings import TelegramTopicsSettings, load_settings, require_telegram_config from ..transports import SetupResult, TransportBackend from ..transport_runtime import TransportRuntime from .bridge import ( TelegramBridgeConfig, TelegramPresenter, TelegramTransport, + TelegramTopicsConfig, TelegramVoiceTranscriptionConfig, run_main_loop, ) @@ -56,6 +59,26 @@ def _build_voice_transcription_config( ) +def _build_topics_config( + transport_config: dict[str, object], + *, + config_path: Path, +) -> TelegramTopicsConfig: + raw = transport_config.get("topics") or {} + if not isinstance(raw, dict): + raise ConfigError( + f"Invalid `transports.telegram.topics` in {config_path}; expected a table." + ) + try: + settings = TelegramTopicsSettings.model_validate(raw) + except ValidationError as exc: + raise ConfigError(f"Invalid topics config in {config_path}: {exc}") from exc + return TelegramTopicsConfig( + enabled=settings.enabled, + mode=settings.mode, + ) + + class TelegramBackend(TransportBackend): id = "telegram" description = "Telegram bot" @@ -111,6 +134,7 @@ class TelegramBackend(TransportBackend): final_notify=final_notify, ) voice_transcription = _build_voice_transcription_config(transport_config) + topics = _build_topics_config(transport_config, config_path=config_path) cfg = TelegramBridgeConfig( bot=bot, runtime=runtime, @@ -118,6 +142,7 @@ class TelegramBackend(TransportBackend): startup_msg=startup_msg, exec_cfg=exec_cfg, voice_transcription=voice_transcription, + topics=topics, ) async def run_loop() -> None: diff --git a/src/takopi/telegram/bridge.py b/src/takopi/telegram/bridge.py index 4b8eca3..4cbc5d2 100644 --- a/src/takopi/telegram/bridge.py +++ b/src/takopi/telegram/bridge.py @@ -47,6 +47,7 @@ from .types import ( TelegramIncomingUpdate, ) from .render import prepare_telegram +from .topic_state import TopicStateStore, TopicThreadSnapshot, resolve_state_path from .transcribe import transcribe_audio logger = get_logger(__name__) @@ -91,6 +92,154 @@ def _parse_slash_command(text: str) -> tuple[str | None, str]: return command.lower(), args_text +_TOPICS_COMMANDS = {"ctx", "new", "topic"} + + +def _topics_chat_project(cfg: TelegramBridgeConfig, chat_id: int) -> str | None: + context = cfg.runtime.default_context_for_chat(chat_id) + return context.project if context is not None else None + + +def _topics_chat_allowed(cfg: TelegramBridgeConfig, chat_id: int) -> bool: + if cfg.topics.mode == "per_project_chat": + return _topics_chat_project(cfg, chat_id) is not None + return chat_id == cfg.chat_id + + +def _topics_command_error(cfg: TelegramBridgeConfig, chat_id: int) -> str | None: + if cfg.topics.mode == "per_project_chat": + if _topics_chat_project(cfg, chat_id) is None: + return "topics commands are only available in project chats." + elif chat_id != cfg.chat_id: + return "topics commands are only available in the main chat." + return None + + +def _merge_topic_context( + *, chat_project: str | None, bound: RunContext | None +) -> RunContext | None: + if chat_project is None: + return bound + if bound is None: + return RunContext(project=chat_project, branch=None) + if bound.project is None: + return RunContext(project=chat_project, branch=bound.branch) + return bound + + +def _topic_key( + msg: TelegramIncomingMessage, cfg: TelegramBridgeConfig +) -> tuple[int, int] | None: + if not cfg.topics.enabled: + return None + if not _topics_chat_allowed(cfg, msg.chat_id): + return None + if msg.thread_id is None: + return None + return (msg.chat_id, msg.thread_id) + + +def _format_context(runtime: TransportRuntime, context: RunContext | None) -> str: + if context is None or context.project is None: + return "none" + project = runtime.project_alias_for_key(context.project) + if context.branch: + return f"{project} @{context.branch}" + return project + + +def _usage_ctx_set(cfg: TelegramBridgeConfig) -> str: + if cfg.topics.mode == "per_project_chat": + return "usage: /ctx set [@branch]" + return "usage: /ctx set [@branch]" + + +def _usage_topic(cfg: TelegramBridgeConfig) -> str: + if cfg.topics.mode == "per_project_chat": + return "usage: /topic @branch" + return "usage: /topic @branch" + + +def _parse_project_branch_args( + args_text: str, + *, + runtime: TransportRuntime, + cfg: TelegramBridgeConfig, + require_branch: bool, + chat_project: str | None, +) -> tuple[RunContext | None, str | None]: + tokens = _split_command_args(args_text) + if not tokens: + return None, _usage_topic(cfg) if require_branch else _usage_ctx_set(cfg) + if len(tokens) > 2: + return None, "too many arguments" + project_token: str | None = None + branch: str | None = None + first = tokens[0] + if first.startswith("@"): + branch = first[1:] or None + else: + project_token = first + if len(tokens) == 2: + second = tokens[1] + if not second.startswith("@"): + return None, "branch must be prefixed with @" + branch = second[1:] or None + + project_key: str | None = None + if cfg.topics.mode == "per_project_chat": + if chat_project is None: + return None, "topics are only available in project chats" + if project_token is None: + project_key = chat_project + else: + normalized = runtime.normalize_project_key(project_token) + if normalized is None: + return None, f"unknown project {project_token!r}" + if normalized != chat_project: + expected = runtime.project_alias_for_key(chat_project) + return None, (f"project mismatch for this chat; expected {expected!r}.") + project_key = normalized + else: + if project_token is None: + return None, "project is required in multi_project_chat mode" + project_key = runtime.normalize_project_key(project_token) + if project_key is None: + return None, f"unknown project {project_token!r}" + + if require_branch and not branch: + return None, "branch is required" + + return RunContext(project=project_key, branch=branch), None + + +def _format_ctx_status( + *, + cfg: TelegramBridgeConfig, + runtime: TransportRuntime, + bound: RunContext | None, + resolved: RunContext | None, + context_source: str, + snapshot: TopicThreadSnapshot | None, +) -> str: + lines = [ + f"topics: enabled ({cfg.topics.mode})", + f"bound ctx: {_format_context(runtime, bound)}", + f"resolved ctx: {_format_context(runtime, resolved)} (source: {context_source})", + ] + if cfg.topics.mode == "multi_project_chat" and bound is None: + topic_usage = _usage_topic(cfg).removeprefix("usage: ").strip() + ctx_usage = _usage_ctx_set(cfg).removeprefix("usage: ").strip() + lines.append( + f"note: unbound topic — bind with `{topic_usage}` or `{ctx_usage}`" + ) + sessions = None + if snapshot is not None and snapshot.sessions: + sessions = ", ".join(sorted(snapshot.sessions)) + lines.append(f"sessions: {sessions or 'none'}") + return "\n".join(lines) + + def _build_bot_commands(runtime: TransportRuntime) -> list[dict[str, str]]: commands: list[dict[str, str]] = [] seen: set[str] = set() @@ -263,6 +412,12 @@ class TelegramVoiceTranscriptionConfig: enabled: bool = False +@dataclass(frozen=True) +class TelegramTopicsConfig: + enabled: bool = False + mode: str = "multi_project_chat" + + def _as_int(value: int | str, *, label: str) -> int: if isinstance(value, bool) or not isinstance(value, int): raise TypeError(f"Telegram {label} must be int") @@ -286,6 +441,7 @@ class TelegramTransport: chat_id = _as_int(channel_id, label="chat_id") reply_to_message_id: int | None = None replace_message_id: int | None = None + message_thread_id: int | None = None disable_notification = None if options is not None: disable_notification = not options.notify @@ -297,6 +453,10 @@ class TelegramTransport: replace_message_id = _as_int( options.replace.message_id, label="replace_message_id" ) + if options.thread_id is not None: + message_thread_id = _as_int( + options.thread_id, label="message_thread_id" + ) entities = message.extra.get("entities") parse_mode = message.extra.get("parse_mode") reply_markup = message.extra.get("reply_markup") @@ -305,6 +465,7 @@ class TelegramTransport: text=message.text, reply_to_message_id=reply_to_message_id, disable_notification=disable_notification, + message_thread_id=message_thread_id, entities=entities, parse_mode=parse_mode, reply_markup=reply_markup, @@ -363,6 +524,7 @@ class TelegramBridgeConfig: exec_cfg: ExecBridgeConfig voice_transcription: TelegramVoiceTranscriptionConfig | None = None chat_ids: tuple[int, ...] | None = None + topics: TelegramTopicsConfig = TelegramTopicsConfig() def _allowed_chat_ids(cfg: TelegramBridgeConfig) -> set[int]: @@ -401,6 +563,62 @@ async def _send_startup(cfg: TelegramBridgeConfig) -> None: logger.info("startup.sent", chat_id=cfg.chat_id) +async def _validate_topics_setup(cfg: TelegramBridgeConfig) -> None: + if not cfg.topics.enabled: + return + me = await cfg.bot.get_me() + bot_id = me.get("id") if isinstance(me, dict) else None + if not isinstance(bot_id, int): + raise ConfigError("Failed to fetch bot id for topics validation.") + if cfg.topics.mode == "per_project_chat": + chat_ids = cfg.runtime.project_chat_ids() + if not chat_ids: + raise ConfigError( + "Topics enabled but no project chats are configured; " + "set projects..chat_id for forum chats." + ) + else: + chat_ids = (cfg.chat_id,) + + for chat_id in chat_ids: + chat = await cfg.bot.get_chat(chat_id) + if not isinstance(chat, dict): + raise ConfigError( + f"Failed to fetch chat info for topics validation ({chat_id})." + ) + chat_type = chat.get("type") + is_forum = chat.get("is_forum") + if chat_type != "supergroup": + raise ConfigError( + "Topics enabled but chat is not a supergroup; convert the group " + "and enable Topics." + ) + if is_forum is not True: + raise ConfigError( + "Topics enabled but chat does not have Topics enabled; " + "turn on Topics in group settings." + ) + member = await cfg.bot.get_chat_member(chat_id, bot_id) + if not isinstance(member, dict): + raise ConfigError( + "Failed to fetch bot permissions; promote the bot to admin with " + "Manage Topics." + ) + status = member.get("status") + if status == "creator": + continue + if status != "administrator": + raise ConfigError( + "Topics enabled but bot is not an admin; promote it and grant " + "Manage Topics." + ) + if member.get("can_manage_topics") is not True: + raise ConfigError( + "Topics enabled but bot lacks Manage Topics permission; " + "grant can_manage_topics." + ) + + async def _drain_backlog(cfg: TelegramBridgeConfig, offset: int | None) -> int | None: drained = 0 while True: @@ -555,6 +773,279 @@ async def _transcribe_voice( return transcript +def _topic_title( + *, cfg: TelegramBridgeConfig, runtime: TransportRuntime, context: RunContext +) -> str: + project = ( + runtime.project_alias_for_key(context.project) + if context.project is not None + else "" + ) + if context.branch: + if project: + return f"{project} @{context.branch}" + return f"@{context.branch}" + return project or "topic" + + +async def _maybe_rename_topic( + cfg: TelegramBridgeConfig, + store: TopicStateStore, + *, + chat_id: int, + thread_id: int, + context: RunContext, + snapshot: TopicThreadSnapshot | None = None, +) -> None: + title = _topic_title(cfg=cfg, runtime=cfg.runtime, context=context) + if snapshot is None: + snapshot = await store.get_thread(chat_id, thread_id) + if snapshot is not None and snapshot.topic_title == title: + return + updated = await cfg.bot.edit_forum_topic( + chat_id=chat_id, + message_thread_id=thread_id, + name=title, + ) + if not updated: + logger.warning( + "topics.rename.failed", + chat_id=chat_id, + thread_id=thread_id, + title=title, + ) + return + await store.set_context(chat_id, thread_id, context, topic_title=title) + + +async def _handle_ctx_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + store: TopicStateStore, +) -> None: + error = _topics_command_error(cfg, msg.chat_id) + if error is not None: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text=error, + ) + return + chat_project = ( + _topics_chat_project(cfg, msg.chat_id) + if cfg.topics.mode == "per_project_chat" + else None + ) + tkey = _topic_key(msg, cfg) + if tkey is None: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text="this command only works inside a topic.", + ) + return + tokens = _split_command_args(args_text) + action = tokens[0].lower() if tokens else "show" + if action in {"show", ""}: + snapshot = await store.get_thread(*tkey) + bound = snapshot.context if snapshot is not None else None + ambient = _merge_topic_context(chat_project=chat_project, bound=bound) + resolved = cfg.runtime.resolve_message( + text="", + reply_text=msg.reply_to_text, + chat_id=msg.chat_id, + ambient_context=ambient, + ) + text = _format_ctx_status( + cfg=cfg, + runtime=cfg.runtime, + bound=bound, + resolved=resolved.context, + context_source=resolved.context_source, + snapshot=snapshot, + ) + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text=text, + ) + return + if action == "set": + rest = " ".join(tokens[1:]) + context, error = _parse_project_branch_args( + rest, + runtime=cfg.runtime, + cfg=cfg, + require_branch=False, + chat_project=chat_project, + ) + if error is not None: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text=f"error:\n{error}\n{_usage_ctx_set(cfg)}", + ) + return + if context is None: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text=f"error:\n{_usage_ctx_set(cfg)}", + ) + return + await store.set_context(*tkey, context) + await _maybe_rename_topic( + cfg, + store, + chat_id=tkey[0], + thread_id=tkey[1], + context=context, + ) + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text=f"topic bound to {_format_context(cfg.runtime, context)}", + ) + return + if action == "clear": + await store.clear_context(*tkey) + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text="topic binding cleared.", + ) + return + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text="unknown /ctx command. use /ctx, /ctx set, or /ctx clear.", + ) + + +async def _handle_new_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + store: TopicStateStore, +) -> None: + error = _topics_command_error(cfg, msg.chat_id) + if error is not None: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text=error, + ) + return + tkey = _topic_key(msg, cfg) + if tkey is None: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text="this command only works inside a topic.", + ) + return + await store.clear_sessions(*tkey) + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text="cleared stored sessions for this topic.", + ) + + +async def _handle_topic_command( + cfg: TelegramBridgeConfig, + msg: TelegramIncomingMessage, + args_text: str, + store: TopicStateStore, +) -> None: + error = _topics_command_error(cfg, msg.chat_id) + if error is not None: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text=error, + ) + return + chat_project = ( + _topics_chat_project(cfg, msg.chat_id) + if cfg.topics.mode == "per_project_chat" + else None + ) + context, error = _parse_project_branch_args( + args_text, + runtime=cfg.runtime, + cfg=cfg, + require_branch=True, + chat_project=chat_project, + ) + if error is not None or context is None: + usage = _usage_topic(cfg) + text = f"error:\n{error}\n{usage}" if error else usage + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text=text, + ) + return + target_chat_id = ( + msg.chat_id if cfg.topics.mode == "per_project_chat" else cfg.chat_id + ) + existing = await store.find_thread_for_context(target_chat_id, context) + if existing is not None: + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text=f"topic already exists for {_format_context(cfg.runtime, context)} " + "in this chat.", + ) + return + title = _topic_title(cfg=cfg, runtime=cfg.runtime, context=context) + created = await cfg.bot.create_forum_topic(target_chat_id, title) + thread_id = created.get("message_thread_id") if isinstance(created, dict) else None + if isinstance(thread_id, bool) or not isinstance(thread_id, int): + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text="failed to create topic.", + ) + return + await store.set_context( + target_chat_id, + thread_id, + context, + topic_title=title, + created_by_bot=True, + ) + await _send_plain( + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + text=f"created topic {title!r}.", + ) + await cfg.exec_cfg.transport.send( + channel_id=target_chat_id, + message=RenderedMessage( + text=f"topic bound to {_format_context(cfg.runtime, context)}" + ), + options=SendOptions(thread_id=thread_id), + ) + + async def _handle_cancel( cfg: TelegramBridgeConfig, msg: TelegramIncomingMessage, @@ -650,10 +1141,13 @@ async def _wait_for_resume(running_task: RunningTask) -> ResumeToken | None: async def _send_with_resume( cfg: TelegramBridgeConfig, - enqueue: Callable[[int, int, str, ResumeToken, RunContext | None], Awaitable[None]], + enqueue: Callable[ + [int, int, str, ResumeToken, RunContext | None, int | None], Awaitable[None] + ], running_task: RunningTask, chat_id: int, user_msg_id: int, + thread_id: int | None, text: str, ) -> None: resume = await _wait_for_resume(running_task) @@ -666,7 +1160,14 @@ async def _send_with_resume( notify=False, ) return - await enqueue(chat_id, user_msg_id, text, resume, running_task.context) + await enqueue( + chat_id, + user_msg_id, + text, + resume, + running_task.context, + thread_id, + ) async def _send_runner_unavailable( @@ -1031,8 +1532,22 @@ async def run_main_loop( transport_snapshot = ( dict(transport_config) if transport_config is not None else None ) + topic_store: TopicStateStore | None = None try: + if cfg.topics.enabled: + config_path = cfg.runtime.config_path + if config_path is None: + raise ConfigError( + "Topics enabled but config path is not set; cannot locate state file." + ) + topic_store = TopicStateStore(resolve_state_path(config_path)) + await _validate_topics_setup(cfg) + logger.info( + "topics.enabled", + mode=cfg.topics.mode, + state_path=str(resolve_state_path(config_path)), + ) await _set_command_menu(cfg) async with anyio.create_task_group() as tg: config_path = cfg.runtime.config_path @@ -1077,17 +1592,42 @@ async def run_main_loop( tg.start_soon(run_config_watch) + def wrap_on_thread_known( + base_cb: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None, + topic_key: tuple[int, int] | None, + ) -> Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None: + if base_cb is None and topic_key is None: + return None + + async def _wrapped(token: ResumeToken, done: anyio.Event) -> None: + if base_cb is not None: + await base_cb(token, done) + if topic_store is not None and topic_key is not None: + await topic_store.set_session_resume( + topic_key[0], topic_key[1], token + ) + + return _wrapped + async def run_job( chat_id: int, user_msg_id: int, text: str, resume_token: ResumeToken | None, context: RunContext | None, + thread_id: int | None = None, reply_ref: MessageRef | None = None, on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None = None, engine_override: EngineId | None = None, ) -> None: + topic_key = ( + (chat_id, thread_id) + if topic_store is not None + and thread_id is not None + and _topics_chat_allowed(cfg, chat_id) + else None + ) await _run_engine( exec_cfg=cfg.exec_cfg, runtime=cfg.runtime, @@ -1098,7 +1638,7 @@ async def run_main_loop( resume_token=resume_token, context=context, reply_ref=reply_ref, - on_thread_known=on_thread_known, + on_thread_known=wrap_on_thread_known(on_thread_known, topic_key), engine_override=engine_override, ) @@ -1109,7 +1649,9 @@ async def run_main_loop( job.text, job.resume_token, job.context, + job.thread_id, None, + scheduler.note_thread_known, ) scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job) @@ -1137,12 +1679,42 @@ async def run_main_loop( if reply_id is not None else None ) + topic_key = _topic_key(msg, cfg) if topic_store is not None else None + chat_project = ( + _topics_chat_project(cfg, chat_id) + if cfg.topics.enabled and cfg.topics.mode == "per_project_chat" + else None + ) + bound_context = ( + await topic_store.get_context(*topic_key) + if topic_store is not None and topic_key is not None + else None + ) + ambient_context = _merge_topic_context( + chat_project=chat_project, bound=bound_context + ) if _is_cancel_command(text): tg.start_soon(_handle_cancel, cfg, msg, running_tasks) continue command_id, args_text = _parse_slash_command(text) + if ( + cfg.topics.enabled + and topic_store is not None + and command_id in _TOPICS_COMMANDS + ): + if command_id == "ctx": + tg.start_soon( + _handle_ctx_command, cfg, msg, args_text, topic_store + ) + elif command_id == "new": + tg.start_soon(_handle_new_command, cfg, msg, topic_store) + else: + tg.start_soon( + _handle_topic_command, cfg, msg, args_text, topic_store + ) + continue if ( command_id is not None and command_id not in command_cache.reserved_commands @@ -1167,6 +1739,7 @@ async def run_main_loop( resolved = cfg.runtime.resolve_message( text=text, reply_text=reply_text, + ambient_context=ambient_context, chat_id=chat_id, ) except DirectiveError as exc: @@ -1182,6 +1755,37 @@ async def run_main_loop( resume_token = resolved.resume_token engine_override = resolved.engine_override context = resolved.context + if ( + topic_store is not None + and topic_key is not None + and resolved.context is not None + and resolved.context_source == "directives" + ): + await topic_store.set_context(*topic_key, resolved.context) + await _maybe_rename_topic( + cfg, + topic_store, + chat_id=topic_key[0], + thread_id=topic_key[1], + context=resolved.context, + ) + ambient_context = resolved.context + if ( + topic_store is not None + and topic_key is not None + and ambient_context is None + and resolved.context_source not in {"directives", "reply_ctx"} + ): + await _send_plain( + cfg.exec_cfg.transport, + chat_id=chat_id, + user_msg_id=user_msg_id, + text=( + "this topic isn't bound to a project yet.\n" + f"{_usage_ctx_set(cfg)} or {_usage_topic(cfg)}" + ), + ) + continue if resume_token is None and reply_id is not None: running_task = running_tasks.get( MessageRef(channel_id=chat_id, message_id=reply_id) @@ -1194,9 +1798,24 @@ async def run_main_loop( running_task, chat_id, user_msg_id, + msg.thread_id, text, ) continue + if ( + resume_token is None + and topic_store is not None + and topic_key is not None + ): + engine_for_session = cfg.runtime.resolve_engine( + engine_override=engine_override, + context=context, + ) + stored = await topic_store.get_session_resume( + topic_key[0], topic_key[1], engine_for_session + ) + if stored is not None: + resume_token = stored if resume_token is None: tg.start_soon( @@ -1206,6 +1825,7 @@ async def run_main_loop( text, None, context, + msg.thread_id, reply_ref, scheduler.note_thread_known, engine_override, @@ -1217,6 +1837,7 @@ async def run_main_loop( text, resume_token, context, + msg.thread_id, ) finally: await cfg.exec_cfg.transport.close() diff --git a/src/takopi/telegram/client.py b/src/takopi/telegram/client.py index 2bcbc3b..39dfc21 100644 --- a/src/takopi/telegram/client.py +++ b/src/takopi/telegram/client.py @@ -105,6 +105,10 @@ def _parse_incoming_message( msg_chat_id = chat.get("id") if not isinstance(msg_chat_id, int): return None + chat_type = chat.get("type") if isinstance(chat.get("type"), str) else None + is_forum = chat.get("is_forum") + if not isinstance(is_forum, bool): + is_forum = None allowed = chat_ids if allowed is None and chat_id is not None: allowed = {chat_id} @@ -131,6 +135,12 @@ def _parse_incoming_message( if isinstance(sender, dict) and isinstance(sender.get("id"), int) else None ) + thread_id = msg.get("message_thread_id") + if isinstance(thread_id, bool) or not isinstance(thread_id, int): + thread_id = None + is_topic_message = msg.get("is_topic_message") + if not isinstance(is_topic_message, bool): + is_topic_message = None return TelegramIncomingMessage( transport="telegram", chat_id=msg_chat_id, @@ -139,6 +149,10 @@ def _parse_incoming_message( reply_to_message_id=reply_to_message_id, reply_to_text=reply_to_text, sender_id=sender_id, + thread_id=thread_id, + is_topic_message=is_topic_message, + chat_type=chat_type, + is_forum=is_forum, voice=voice_payload, raw=msg, ) @@ -237,6 +251,7 @@ class BotClient(Protocol): text: str, reply_to_message_id: int | None = None, disable_notification: bool | None = False, + message_thread_id: int | None = None, entities: list[dict] | None = None, parse_mode: str | None = None, reply_markup: dict[str, Any] | None = None, @@ -279,6 +294,23 @@ class BotClient(Protocol): show_alert: bool | None = None, ) -> bool: ... + async def get_chat(self, chat_id: int) -> dict | None: ... + + async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None: ... + + async def create_forum_topic( + self, + chat_id: int, + name: str, + ) -> dict | None: ... + + async def edit_forum_topic( + self, + chat_id: int, + message_thread_id: int, + name: str, + ) -> bool: ... + if TYPE_CHECKING: from anyio.abc import TaskGroup @@ -721,6 +753,7 @@ class TelegramClient: text: str, reply_to_message_id: int | None = None, disable_notification: bool | None = False, + message_thread_id: int | None = None, entities: list[dict] | None = None, parse_mode: str | None = None, reply_markup: dict[str, Any] | None = None, @@ -734,6 +767,7 @@ class TelegramClient: text=text, reply_to_message_id=reply_to_message_id, disable_notification=disable_notification, + message_thread_id=message_thread_id, entities=entities, parse_mode=parse_mode, reply_markup=reply_markup, @@ -744,6 +778,8 @@ class TelegramClient: params["disable_notification"] = disable_notification if reply_to_message_id is not None: params["reply_to_message_id"] = reply_to_message_id + if message_thread_id is not None: + params["message_thread_id"] = message_thread_id if entities is not None: params["entities"] = entities if parse_mode is not None: @@ -921,3 +957,80 @@ class TelegramClient: chat_id=None, ) ) + + async def get_chat(self, chat_id: int) -> dict | None: + async def execute() -> dict | None: + if self._client_override is not None: + return await self._client_override.get_chat(chat_id) + result = await self._post("getChat", {"chat_id": chat_id}) + return result if isinstance(result, dict) else None + + return await self.enqueue_op( + key=self.unique_key("get_chat"), + label="get_chat", + execute=execute, + priority=SEND_PRIORITY, + chat_id=chat_id, + ) + + async def get_chat_member(self, chat_id: int, user_id: int) -> dict | None: + async def execute() -> dict | None: + if self._client_override is not None: + return await self._client_override.get_chat_member(chat_id, user_id) + result = await self._post( + "getChatMember", {"chat_id": chat_id, "user_id": user_id} + ) + return result if isinstance(result, dict) else None + + return await self.enqueue_op( + key=self.unique_key("get_chat_member"), + label="get_chat_member", + execute=execute, + priority=SEND_PRIORITY, + chat_id=chat_id, + ) + + async def create_forum_topic(self, chat_id: int, name: str) -> dict | None: + async def execute() -> dict | None: + if self._client_override is not None: + return await self._client_override.create_forum_topic(chat_id, name) + result = await self._post( + "createForumTopic", {"chat_id": chat_id, "name": name} + ) + return result if isinstance(result, dict) else None + + return await self.enqueue_op( + key=self.unique_key("create_forum_topic"), + label="create_forum_topic", + execute=execute, + priority=SEND_PRIORITY, + chat_id=chat_id, + ) + + async def edit_forum_topic( + self, chat_id: int, message_thread_id: int, name: str + ) -> bool: + async def execute() -> bool: + if self._client_override is not None: + return await self._client_override.edit_forum_topic( + chat_id, message_thread_id, name + ) + result = await self._post( + "editForumTopic", + { + "chat_id": chat_id, + "message_thread_id": message_thread_id, + "name": name, + }, + ) + return bool(result) + + return bool( + await self.enqueue_op( + key=self.unique_key("edit_forum_topic"), + label="edit_forum_topic", + execute=execute, + priority=SEND_PRIORITY, + chat_id=chat_id, + ) + ) diff --git a/src/takopi/telegram/onboarding.py b/src/takopi/telegram/onboarding.py index b74f3ce..aa3145a 100644 --- a/src/takopi/telegram/onboarding.py +++ b/src/takopi/telegram/onboarding.py @@ -344,6 +344,41 @@ def _prompt_token(console: Console) -> tuple[str, dict[str, Any]] | None: return None +def capture_chat_id(*, token: str | None = None) -> ChatInfo | None: + console = Console() + with _suppress_logging(): + if token is not None: + token = token.strip() + if not token: + console.print(" token cannot be empty") + return None + console.print(" validating...") + info = anyio.run(_get_bot_info, token) + if not info: + console.print(" failed to connect, check the token and try again") + return None + else: + token_info = _prompt_token(console) + if token_info is None: + return None + token, info = token_info + + bot_ref = f"@{info['username']}" + console.print("") + console.print(f" send /start to {bot_ref} (works in groups too)") + console.print(" waiting...") + try: + chat = anyio.run(_wait_for_chat, token) + except KeyboardInterrupt: + console.print(" cancelled") + return None + if chat is None: + console.print(" cancelled") + return None + console.print(f" got chat_id {chat.chat_id} from {chat.display}") + return chat + + def interactive_setup(*, force: bool) -> bool: console = Console() config_path = HOME_CONFIG_PATH diff --git a/src/takopi/telegram/topic_state.py b/src/takopi/telegram/topic_state.py new file mode 100644 index 0000000..3ed140b --- /dev/null +++ b/src/takopi/telegram/topic_state.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +import json +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, cast + +import anyio + +from ..context import RunContext +from ..logging import get_logger +from ..model import ResumeToken + +logger = get_logger(__name__) + +STATE_VERSION = 1 +STATE_FILENAME = "telegram_topics_state.json" + + +@dataclass(frozen=True, slots=True) +class TopicThreadSnapshot: + chat_id: int + thread_id: int + context: RunContext | None + sessions: dict[str, str] + topic_title: str | None + created_by_bot: bool | None + updated_at: float | None + + +def resolve_state_path(config_path: Path) -> Path: + return config_path.with_name(STATE_FILENAME) + + +def _thread_key(chat_id: int, thread_id: int) -> str: + return f"{chat_id}:{thread_id}" + + +def _parse_context(raw: object) -> RunContext | None: + if not isinstance(raw, dict): + return None + payload = cast(dict[str, object], raw) + project = payload.get("project") + branch = payload.get("branch") + if project is not None and not isinstance(project, str): + project = None + if isinstance(project, str): + project = project.strip() or None + if branch is not None and not isinstance(branch, str): + branch = None + if isinstance(branch, str): + branch = branch.strip() or None + if project is None and branch is None: + return None + return RunContext(project=project, branch=branch) + + +def _dump_context(context: RunContext | None) -> dict[str, str] | None: + if context is None or (context.project is None and context.branch is None): + return None + payload: dict[str, str] = {} + if context.project is not None: + payload["project"] = context.project + if context.branch is not None: + payload["branch"] = context.branch + return payload or None + + +class TopicStateStore: + def __init__(self, path: Path) -> None: + self._path = path + self._lock = anyio.Lock() + self._loaded = False + self._mtime_ns: int | None = None + self._data: dict[str, Any] = { + "version": STATE_VERSION, + "threads": {}, + } + + async def get_thread( + self, chat_id: int, thread_id: int + ) -> TopicThreadSnapshot | None: + async with self._lock: + self._reload_locked_if_needed() + thread = self._get_thread_locked(chat_id, thread_id) + if thread is None: + return None + return self._snapshot_locked(thread, chat_id, thread_id) + + async def get_context(self, chat_id: int, thread_id: int) -> RunContext | None: + async with self._lock: + self._reload_locked_if_needed() + thread = self._get_thread_locked(chat_id, thread_id) + if thread is None: + return None + return _parse_context(thread.get("context")) + + async def set_context( + self, + chat_id: int, + thread_id: int, + context: RunContext, + *, + topic_title: str | None = None, + created_by_bot: bool | None = None, + ) -> None: + async with self._lock: + self._reload_locked_if_needed() + thread = self._ensure_thread_locked(chat_id, thread_id) + thread["context"] = _dump_context(context) + if topic_title is not None: + thread["topic_title"] = topic_title + if created_by_bot is not None: + thread["created_by_bot"] = created_by_bot + thread["updated_at"] = time.time() + self._save_locked() + + async def clear_context(self, chat_id: int, thread_id: int) -> None: + async with self._lock: + self._reload_locked_if_needed() + thread = self._get_thread_locked(chat_id, thread_id) + if thread is None: + return + thread.pop("context", None) + thread["updated_at"] = time.time() + self._save_locked() + + async def get_session_resume( + self, chat_id: int, thread_id: int, engine: str + ) -> ResumeToken | None: + async with self._lock: + self._reload_locked_if_needed() + thread = self._get_thread_locked(chat_id, thread_id) + if thread is None: + return None + sessions = thread.get("sessions") + if not isinstance(sessions, dict): + return None + entry = sessions.get(engine) + if not isinstance(entry, dict): + return None + value = entry.get("resume") + if not isinstance(value, str) or not value: + return None + return ResumeToken(engine=engine, value=value) + + async def set_session_resume( + self, chat_id: int, thread_id: int, token: ResumeToken + ) -> None: + async with self._lock: + self._reload_locked_if_needed() + thread = self._ensure_thread_locked(chat_id, thread_id) + sessions = thread.get("sessions") + if not isinstance(sessions, dict): + sessions = {} + thread["sessions"] = sessions + sessions[token.engine] = { + "resume": token.value, + "updated_at": time.time(), + } + thread["updated_at"] = time.time() + self._save_locked() + + async def clear_sessions(self, chat_id: int, thread_id: int) -> None: + async with self._lock: + self._reload_locked_if_needed() + thread = self._get_thread_locked(chat_id, thread_id) + if thread is None: + return + thread.pop("sessions", None) + thread["updated_at"] = time.time() + self._save_locked() + + async def find_thread_for_context( + self, chat_id: int, context: RunContext + ) -> int | None: + async with self._lock: + self._reload_locked_if_needed() + threads = self._data.get("threads") + if not isinstance(threads, dict): + return None + for raw_key, payload in threads.items(): + if not isinstance(raw_key, str) or not isinstance(payload, dict): + continue + parsed = _parse_context(payload.get("context")) + if parsed is None: + continue + if parsed.project != context.project or parsed.branch != context.branch: + continue + if not raw_key.startswith(f"{chat_id}:"): + continue + try: + _, thread_str = raw_key.split(":", 1) + return int(thread_str) + except (ValueError, TypeError): + continue + return None + + def _snapshot_locked( + self, thread: dict[str, Any], chat_id: int, thread_id: int + ) -> TopicThreadSnapshot: + sessions: dict[str, str] = {} + raw_sessions = thread.get("sessions") + if isinstance(raw_sessions, dict): + for engine, entry in raw_sessions.items(): + if not isinstance(engine, str) or not isinstance(entry, dict): + continue + value = entry.get("resume") + if isinstance(value, str) and value: + sessions[engine] = value + updated_at = thread.get("updated_at") + if not isinstance(updated_at, (int, float)): + updated_at = None + topic_title = thread.get("topic_title") + if not isinstance(topic_title, str): + topic_title = None + created_by_bot = thread.get("created_by_bot") + if not isinstance(created_by_bot, bool): + created_by_bot = None + return TopicThreadSnapshot( + chat_id=chat_id, + thread_id=thread_id, + context=_parse_context(thread.get("context")), + sessions=sessions, + topic_title=topic_title, + created_by_bot=created_by_bot, + updated_at=updated_at, + ) + + def _stat_mtime_ns(self) -> int | None: + try: + return self._path.stat().st_mtime_ns + except FileNotFoundError: + return None + + def _reload_locked_if_needed(self) -> None: + current = self._stat_mtime_ns() + if self._loaded and current == self._mtime_ns: + return + self._load_locked() + + def _load_locked(self) -> None: + self._loaded = True + self._mtime_ns = self._stat_mtime_ns() + if self._mtime_ns is None: + self._data = {"version": STATE_VERSION, "threads": {}} + return + try: + payload = json.loads(self._path.read_text()) + except Exception as exc: + logger.warning( + "telegram.topic_state.load_failed", + path=str(self._path), + error=str(exc), + error_type=exc.__class__.__name__, + ) + self._data = {"version": STATE_VERSION, "threads": {}} + return + if not isinstance(payload, dict): + self._data = {"version": STATE_VERSION, "threads": {}} + return + version = payload.get("version") + if version != STATE_VERSION: + logger.warning( + "telegram.topic_state.version_mismatch", + path=str(self._path), + version=version, + expected=STATE_VERSION, + ) + self._data = {"version": STATE_VERSION, "threads": {}} + return + threads = payload.get("threads") + if not isinstance(threads, dict): + threads = {} + self._data = {"version": STATE_VERSION, "threads": threads} + + def _save_locked(self) -> None: + self._path.parent.mkdir(parents=True, exist_ok=True) + payload = {"version": STATE_VERSION, "threads": self._data.get("threads", {})} + tmp_path = self._path.with_suffix(f"{self._path.suffix}.tmp") + with open(tmp_path, "w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + os.replace(tmp_path, self._path) + self._mtime_ns = self._stat_mtime_ns() + + def _get_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any] | None: + threads = self._data.get("threads") + if not isinstance(threads, dict): + return None + entry = threads.get(_thread_key(chat_id, thread_id)) + return entry if isinstance(entry, dict) else None + + def _ensure_thread_locked(self, chat_id: int, thread_id: int) -> dict[str, Any]: + threads = self._data.get("threads") + if not isinstance(threads, dict): + threads = {} + self._data["threads"] = threads + key = _thread_key(chat_id, thread_id) + entry = threads.get(key) + if isinstance(entry, dict): + return entry + entry = {"chat_id": chat_id, "thread_id": thread_id} + threads[key] = entry + return entry diff --git a/src/takopi/telegram/types.py b/src/takopi/telegram/types.py index c8e7f7e..7ddc6f5 100644 --- a/src/takopi/telegram/types.py +++ b/src/takopi/telegram/types.py @@ -22,6 +22,10 @@ class TelegramIncomingMessage: reply_to_message_id: int | None reply_to_text: str | None sender_id: int | None + thread_id: int | None = None + is_topic_message: bool | None = None + chat_type: str | None = None + is_forum: bool | None = None voice: TelegramVoice | None = None raw: dict[str, Any] | None = None diff --git a/src/takopi/transport.py b/src/takopi/transport.py index b0a8789..a17d74c 100644 --- a/src/takopi/transport.py +++ b/src/takopi/transport.py @@ -25,6 +25,7 @@ class SendOptions: reply_to: MessageRef | None = None notify: bool = True replace: MessageRef | None = None + thread_id: int | None = None class Transport(Protocol): diff --git a/src/takopi/transport_runtime.py b/src/takopi/transport_runtime.py index 4b074b8..9189ff7 100644 --- a/src/takopi/transport_runtime.py +++ b/src/takopi/transport_runtime.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Iterable, Mapping from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, Literal from .config import ConfigError, ProjectsConfig from .context import RunContext @@ -21,6 +21,13 @@ class ResolvedMessage: resume_token: ResumeToken | None engine_override: EngineId | None context: RunContext | None + context_source: Literal[ + "reply_ctx", + "directives", + "ambient", + "default_project", + "none", + ] = "none" @dataclass(frozen=True, slots=True) @@ -130,6 +137,7 @@ class TransportRuntime: *, text: str, reply_text: str | None, + ambient_context: RunContext | None = None, chat_id: int | None = None, ) -> ResolvedMessage: directives = parse_directives( @@ -140,52 +148,76 @@ class TransportRuntime: reply_ctx = parse_context_line(reply_text, projects=self._projects) resume_token = self._router.resolve_resume(directives.prompt, reply_text) chat_project = self._projects.project_for_chat(chat_id) + default_project = chat_project or self._projects.default_project - if resume_token is not None: - context = reply_ctx - if context is None and chat_project is not None: - context = RunContext(project=chat_project, branch=None) - return ResolvedMessage( - prompt=directives.prompt, - resume_token=resume_token, - engine_override=None, - context=context, - ) + context_source: Literal[ + "reply_ctx", + "directives", + "ambient", + "default_project", + "none", + ] = "none" + context: RunContext | None = None if reply_ctx is not None: - engine_override = None - if reply_ctx.project is not None: - project = self._projects.projects.get(reply_ctx.project) - if project is not None and project.default_engine is not None: - engine_override = project.default_engine - return ResolvedMessage( - prompt=directives.prompt, - resume_token=None, - engine_override=engine_override, - context=reply_ctx, - ) - - project_key = directives.project - if project_key is None: - project_key = chat_project or self._projects.default_project - - context = None - if project_key is not None or directives.branch is not None: - context = RunContext(project=project_key, branch=directives.branch) + context = reply_ctx + context_source = "reply_ctx" + else: + project_key = directives.project + branch = directives.branch + if project_key is None: + if ambient_context is not None and ambient_context.project is not None: + project_key = ambient_context.project + else: + project_key = default_project + if branch is None: + if ( + ambient_context is not None + and ambient_context.branch is not None + and project_key == ambient_context.project + ): + branch = ambient_context.branch + if project_key is not None or branch is not None: + context = RunContext(project=project_key, branch=branch) + if directives.project is not None or directives.branch is not None: + context_source = "directives" + elif ambient_context is not None and ambient_context.project is not None: + context_source = "ambient" + elif default_project is not None: + context_source = "default_project" engine_override = directives.engine - if engine_override is None and project_key is not None: - project = self._projects.projects.get(project_key) + if engine_override is None and context is not None: + project = ( + self._projects.projects.get(context.project) + if context.project is not None + else None + ) if project is not None and project.default_engine is not None: engine_override = project.default_engine return ResolvedMessage( prompt=directives.prompt, - resume_token=None, + resume_token=resume_token, engine_override=engine_override, context=context, + context_source=context_source, ) + @property + def default_project(self) -> str | None: + return self._projects.default_project + + def normalize_project_key(self, value: str) -> str | None: + key = value.strip().lower() + if key in self._projects.projects: + return key + return None + + def project_alias_for_key(self, key: str) -> str: + project = self._projects.projects.get(key) + return project.alias if project is not None else key + def default_context_for_chat(self, chat_id: int | None) -> RunContext | None: project_key = self._projects.project_for_chat(chat_id) if project_key is None: diff --git a/tests/test_cli_chat_id.py b/tests/test_cli_chat_id.py new file mode 100644 index 0000000..65cbb11 --- /dev/null +++ b/tests/test_cli_chat_id.py @@ -0,0 +1,70 @@ +from pathlib import Path + +from typer.testing import CliRunner + +from takopi import cli +from takopi.settings import TakopiSettings +from takopi.telegram import onboarding + + +def test_chat_id_command_updates_project_chat_id(monkeypatch, tmp_path) -> None: + config_path = tmp_path / "takopi.toml" + config_path.write_text( + '[projects.z80]\npath = "/tmp/repo"\n', + encoding="utf-8", + ) + monkeypatch.setattr("takopi.config.HOME_CONFIG_PATH", config_path) + monkeypatch.setattr(cli, "_load_settings_optional", lambda: (None, None)) + + def _capture(*, token: str | None = None): + assert token == "token" + return onboarding.ChatInfo( + chat_id=123, + username=None, + title="takopi", + first_name=None, + last_name=None, + chat_type="supergroup", + ) + + monkeypatch.setattr(cli.onboarding, "capture_chat_id", _capture) + + runner = CliRunner() + result = runner.invoke( + cli.create_app(), + ["chat-id", "--token", "token", "--project", "z80"], + ) + + assert result.exit_code == 0 + saved = config_path.read_text(encoding="utf-8") + assert "chat_id = 123" in saved + assert "updated projects.z80.chat_id = 123" in result.output + + +def test_chat_id_command_uses_config_token(monkeypatch) -> None: + settings = TakopiSettings.model_validate( + { + "transport": "telegram", + "transports": {"telegram": {"bot_token": "config-token"}}, + } + ) + monkeypatch.setattr(cli, "_load_settings_optional", lambda: (settings, Path("x"))) + + def _capture(*, token: str | None = None): + assert token == "config-token" + return onboarding.ChatInfo( + chat_id=321, + username=None, + title="takopi", + first_name=None, + last_name=None, + chat_type="supergroup", + ) + + monkeypatch.setattr(cli.onboarding, "capture_chat_id", _capture) + + runner = CliRunner() + result = runner.invoke(cli.create_app(), ["chat-id"]) + + assert result.exit_code == 0 + assert "chat_id = 321" in result.output diff --git a/tests/test_onboarding_interactive.py b/tests/test_onboarding_interactive.py index 7eeab0c..c1c2e68 100644 --- a/tests/test_onboarding_interactive.py +++ b/tests/test_onboarding_interactive.py @@ -226,3 +226,51 @@ def test_interactive_setup_recovers_from_malformed_toml(monkeypatch, tmp_path) - saved = config_path.read_text(encoding="utf-8") assert "[transports.telegram]" in saved assert 'bot_token = "123456789:ABCdef"' in saved + + +def test_capture_chat_id_with_token(monkeypatch) -> None: + def _fake_run(func, *args, **kwargs): + if func is onboarding._get_bot_info: + return {"username": "my_bot"} + if func is onboarding._wait_for_chat: + return onboarding.ChatInfo( + chat_id=456, + username=None, + title="takopi", + first_name=None, + last_name=None, + chat_type="supergroup", + ) + raise AssertionError(f"unexpected anyio.run target: {func}") + + monkeypatch.setattr(onboarding.anyio, "run", _fake_run) + + chat = onboarding.capture_chat_id(token="123456789:ABCdef") + + assert chat is not None + assert chat.chat_id == 456 + + +def test_capture_chat_id_prompts_for_token(monkeypatch) -> None: + monkeypatch.setattr( + onboarding, "_prompt_token", lambda _console: ("token", {"username": "bot"}) + ) + + def _fake_run(func, *args, **kwargs): + if func is onboarding._wait_for_chat: + return onboarding.ChatInfo( + chat_id=789, + username="alice", + title=None, + first_name="Alice", + last_name=None, + chat_type="private", + ) + raise AssertionError(f"unexpected anyio.run target: {func}") + + monkeypatch.setattr(onboarding.anyio, "run", _fake_run) + + chat = onboarding.capture_chat_id() + + assert chat is not None + assert chat.chat_id == 789 diff --git a/tests/test_runner_utils.py b/tests/test_runner_utils.py new file mode 100644 index 0000000..f03d01b --- /dev/null +++ b/tests/test_runner_utils.py @@ -0,0 +1,416 @@ +import re +from collections.abc import AsyncIterator +from typing import Any + +import pytest + +import takopi.runner as runner_module +from takopi.model import ( + ActionEvent, + CompletedEvent, + EngineId, + ResumeToken, + StartedEvent, + TakopiEvent, +) +from takopi.runner import ( + BaseRunner, + JsonlRunState, + JsonlSubprocessRunner, + ResumeTokenMixin, +) + + +class _DummyRunner(ResumeTokenMixin, BaseRunner): + engine = EngineId("dummy") + resume_re = re.compile(r"(?im)^`?dummy resume (?P[^`\s]+)`?$") + + async def run_impl( + self, prompt: str, resume: ResumeToken | None + ) -> AsyncIterator[StartedEvent | CompletedEvent]: + token = resume or ResumeToken(engine=self.engine, value="token") + yield StartedEvent(engine=self.engine, resume=token, title="dummy") + yield CompletedEvent( + engine=self.engine, + ok=True, + answer=prompt, + resume=token, + ) + + +class _DummyJsonlRunner(JsonlSubprocessRunner): + engine = EngineId("dummy-jsonl") + + def command(self) -> str: + return "dummy" + + def build_args( + self, + prompt: str, + resume: ResumeToken | None, + *, + state: object, + ) -> list[str]: + _ = prompt, resume, state + return [] + + def translate( + self, + data: Any, + *, + state: Any, + resume: ResumeToken | None, + found_session: ResumeToken | None, + ) -> list[TakopiEvent]: + _ = data, state, resume, found_session + return [] + + +class _BareJsonlRunner(JsonlSubprocessRunner): + engine = EngineId("bare-jsonl") + + +class _RunJsonlRunner(_DummyJsonlRunner): + def stdin_payload( + self, + prompt: str, + resume: ResumeToken | None, + *, + state: Any, + ) -> bytes | None: + _ = prompt, resume, state + return None + + async def iter_json_lines(self, stream: Any) -> AsyncIterator[bytes]: + _ = stream + yield b'{"type": "started", "resume": "sid"}' + yield b'{"type": "completed", "resume": "sid"}' + + def translate( + self, + data: Any, + *, + state: Any, + resume: ResumeToken | None, + found_session: ResumeToken | None, + ) -> list[TakopiEvent]: + _ = state, resume, found_session + token_value = "sid" + if isinstance(data, dict) and isinstance(data.get("resume"), str): + token_value = data["resume"] + token = ResumeToken(engine=self.engine, value=token_value) + if isinstance(data, dict) and data.get("type") == "started": + return [StartedEvent(engine=self.engine, resume=token, title="t")] + if isinstance(data, dict) and data.get("type") == "completed": + return [ + CompletedEvent(engine=self.engine, ok=True, answer="done", resume=token) + ] + return [] + + +class _BranchingJsonlRunner(_DummyJsonlRunner): + def stdin_payload( + self, + prompt: str, + resume: ResumeToken | None, + *, + state: Any, + ) -> bytes | None: + _ = prompt, resume, state + return None + + async def iter_json_lines(self, stream: Any) -> AsyncIterator[bytes]: + _ = stream + yield b"raise" + yield b"" + yield b"invalid" + yield b'{"type": "translate_error"}' + yield b'{"type": "started", "resume": "sid"}' + yield b'{"type": "started", "resume": "sid"}' + yield b'{"type": "completed", "resume": "sid"}' + yield b'{"type": "after"}' + + def decode_jsonl(self, *, line: bytes) -> Any | None: + if line == b"raise": + raise ValueError("boom") + if line == b"invalid": + return None + return super().decode_jsonl(line=line) + + def translate( + self, + data: Any, + *, + state: Any, + resume: ResumeToken | None, + found_session: ResumeToken | None, + ) -> list[TakopiEvent]: + _ = state, resume, found_session + if isinstance(data, dict) and data.get("type") == "translate_error": + raise RuntimeError("nope") + token_value = "sid" + if isinstance(data, dict) and isinstance(data.get("resume"), str): + token_value = data["resume"] + token = ResumeToken(engine=self.engine, value=token_value) + if isinstance(data, dict) and data.get("type") == "started": + return [StartedEvent(engine=self.engine, resume=token, title="t")] + if isinstance(data, dict) and data.get("type") == "completed": + return [ + CompletedEvent(engine=self.engine, ok=True, answer="done", resume=token) + ] + return [] + + +@pytest.mark.anyio +async def test_base_runner_run_locked_handles_resume() -> None: + runner = _DummyRunner() + events = [evt async for evt in runner.run("hello", None)] + assert isinstance(events[0], StartedEvent) + assert isinstance(events[-1], CompletedEvent) + + resume = ResumeToken(engine=runner.engine, value="resume") + resumed = [evt async for evt in runner.run("again", resume)] + assert isinstance(resumed[0], StartedEvent) + assert resumed[0].resume == resume + + +@pytest.mark.anyio +async def test_base_runner_rejects_wrong_resume_engine() -> None: + runner = _DummyRunner() + bad_resume = ResumeToken(engine=EngineId("other"), value="oops") + with pytest.raises(RuntimeError): + _ = [evt async for evt in runner.run("hello", bad_resume)] + + +@pytest.mark.anyio +async def test_base_runner_run_impl_not_implemented() -> None: + class _BareRunner(BaseRunner): + engine = EngineId("bare") + + runner = _BareRunner() + with pytest.raises(NotImplementedError): + _ = [evt async for evt in runner.run_impl("hello", None)] + + +def test_resume_token_format_and_extract() -> None: + runner = _DummyRunner() + token = ResumeToken(engine=runner.engine, value="abc") + assert runner.format_resume(token) == "`dummy resume abc`" + assert runner.is_resume_line("`dummy resume abc`") is True + text = "`dummy resume first`\n`dummy resume second`" + assert runner.extract_resume(text) == ResumeToken( + engine=runner.engine, value="second" + ) + assert runner.extract_resume(None) is None + + with pytest.raises(RuntimeError): + runner.format_resume(ResumeToken(engine=EngineId("other"), value="bad")) + + +def test_session_lock_reuse() -> None: + runner = _DummyRunner() + token = ResumeToken(engine=runner.engine, value="one") + lock1 = runner.lock_for(token) + lock2 = runner.lock_for(token) + other = runner.lock_for(ResumeToken(engine=runner.engine, value="two")) + assert lock1 is lock2 + assert other is not lock1 + + +@pytest.mark.anyio +async def test_run_with_resume_lock_passthrough() -> None: + runner = _DummyRunner() + events = [ + evt async for evt in runner.run_with_resume_lock("hello", None, runner.run_impl) + ] + assert events + + +def test_jsonl_helpers() -> None: + runner = _DummyJsonlRunner() + state = JsonlRunState() + + note1 = runner.next_note_id(state) + note2 = runner.next_note_id(state) + assert note1.endswith(".1") + assert note2.endswith(".2") + + event = runner.note_event("warn", state=state) + assert isinstance(event, ActionEvent) + assert event.action.detail == {} + + invalid = runner.invalid_json_events(raw="x", line="{}", state=state) + invalid_event = invalid[0] + assert isinstance(invalid_event, ActionEvent) + assert invalid_event.action.detail["line"] == "{}" + + assert runner.decode_jsonl(line=b'{"a": 1}') == {"a": 1} + assert runner.decode_jsonl(line=b"{") is None + + err_events = runner.decode_error_events( + raw="oops", line="{}", error=ValueError("nope"), state=state + ) + err_event = err_events[0] + assert isinstance(err_event, ActionEvent) + assert err_event.action.detail["error"] == "nope" + + translated = runner.translate_error_events( + data={"type": "foo", "item": {"type": "bar"}}, + error=ValueError("boom"), + state=state, + ) + translated_event = translated[0] + assert isinstance(translated_event, ActionEvent) + detail = translated_event.action.detail + assert detail["type"] == "foo" + assert detail["item_type"] == "bar" + + resume = ResumeToken(engine=runner.engine, value="sid") + processed = runner.process_error_events( + 2, resume=resume, found_session=None, state=state + ) + processed_event = processed[-1] + assert isinstance(processed_event, CompletedEvent) + assert processed_event.ok is False + assert processed_event.resume == resume + + stream_end = runner.stream_end_events( + resume=None, found_session=resume, state=state + ) + stream_event = stream_end[-1] + assert isinstance(stream_event, CompletedEvent) + assert stream_event.resume == resume + + started = StartedEvent(engine=runner.engine, resume=resume, title="t") + found, emit = runner.handle_started_event( + started, expected_session=None, found_session=None + ) + assert found == resume + assert emit is True + + found, emit = runner.handle_started_event( + started, expected_session=None, found_session=resume + ) + assert found == resume + assert emit is False + + mismatch = StartedEvent(engine=EngineId("other"), resume=resume, title="t") + with pytest.raises(RuntimeError): + runner.handle_started_event(mismatch, expected_session=None, found_session=None) + + other_resume = ResumeToken(engine=runner.engine, value="other") + with pytest.raises(RuntimeError): + runner.handle_started_event( + StartedEvent(engine=runner.engine, resume=other_resume, title="t"), + expected_session=resume, + found_session=None, + ) + + with pytest.raises(RuntimeError): + runner.handle_started_event( + StartedEvent(engine=runner.engine, resume=other_resume, title="t"), + expected_session=None, + found_session=resume, + ) + + +def test_next_note_id_requires_state_field() -> None: + runner = _DummyJsonlRunner() + with pytest.raises(RuntimeError): + runner.next_note_id(object()) + + +def test_jsonl_base_methods_raise_and_defaults() -> None: + runner = _BareJsonlRunner() + with pytest.raises(NotImplementedError): + runner.command() + with pytest.raises(NotImplementedError): + runner.build_args("hi", None, state=None) + with pytest.raises(NotImplementedError): + runner.translate(data={}, state=None, resume=None, found_session=None) + assert runner.pipes_error_message().startswith("bare-jsonl") + state = runner.new_state("hi", None) + assert isinstance(state, JsonlRunState) + assert runner.start_run("hi", None, state=state) is None + assert runner.env(state=state) is None + assert runner.stdin_payload("hi", None, state=state) == b"hi" + + +@pytest.mark.anyio +async def test_jsonl_run_impl_smoke(monkeypatch: pytest.MonkeyPatch) -> None: + class _FakeProc: + def __init__(self) -> None: + self.stdout = object() + self.stderr = object() + self.stdin = None + self.pid = 123 + + async def wait(self) -> int: + return 0 + + class _FakeManager: + def __init__(self, proc: _FakeProc) -> None: + self._proc = proc + + async def __aenter__(self) -> _FakeProc: + return self._proc + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + proc = _FakeProc() + + def fake_manage_subprocess(*args: Any, **kwargs: Any) -> _FakeManager: + _ = args, kwargs + return _FakeManager(proc) + + async def fake_drain_stderr(*args: Any, **kwargs: Any) -> None: + _ = args, kwargs + return None + + monkeypatch.setattr(runner_module, "manage_subprocess", fake_manage_subprocess) + monkeypatch.setattr(runner_module, "drain_stderr", fake_drain_stderr) + + runner = _RunJsonlRunner() + events = [evt async for evt in runner.run_impl("hello", None)] + assert any(isinstance(evt, CompletedEvent) for evt in events) + + +@pytest.mark.anyio +async def test_jsonl_run_impl_branches(monkeypatch: pytest.MonkeyPatch) -> None: + class _FakeProc: + def __init__(self) -> None: + self.stdout = object() + self.stderr = object() + self.stdin = None + self.pid = 456 + + async def wait(self) -> int: + return 0 + + class _FakeManager: + def __init__(self, proc: _FakeProc) -> None: + self._proc = proc + + async def __aenter__(self) -> _FakeProc: + return self._proc + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + proc = _FakeProc() + + def fake_manage_subprocess(*args: Any, **kwargs: Any) -> _FakeManager: + _ = args, kwargs + return _FakeManager(proc) + + async def fake_drain_stderr(*args: Any, **kwargs: Any) -> None: + _ = args, kwargs + return None + + monkeypatch.setattr(runner_module, "manage_subprocess", fake_manage_subprocess) + monkeypatch.setattr(runner_module, "drain_stderr", fake_drain_stderr) + + runner = _BranchingJsonlRunner() + events = [evt async for evt in runner.run_impl("hello", None)] + assert any(isinstance(evt, CompletedEvent) for evt in events) diff --git a/tests/test_telegram_bridge.py b/tests/test_telegram_bridge.py index 96faa40..17671c7 100644 --- a/tests/test_telegram_bridge.py +++ b/tests/test_telegram_bridge.py @@ -1,5 +1,6 @@ +from dataclasses import replace from pathlib import Path -from typing import cast +from typing import Any, cast import anyio import pytest @@ -18,6 +19,8 @@ from takopi.telegram.bridge import ( _send_with_resume, run_main_loop, ) +from takopi.telegram.client import BotClient +from takopi.telegram.topic_state import TopicStateStore, resolve_state_path from takopi.context import RunContext from takopi.config import ProjectConfig, ProjectsConfig, empty_projects_config from takopi.runner_bridge import ExecBridgeConfig, RunningTask @@ -92,12 +95,13 @@ class _FakeTransport: return None -class _FakeBot: +class _FakeBot(BotClient): def __init__(self) -> None: self.command_calls: list[dict] = [] self.callback_calls: list[dict] = [] self.send_calls: list[dict] = [] self.edit_calls: list[dict] = [] + self.edit_topic_calls: list[dict[str, Any]] = [] self.delete_calls: list[dict] = [] async def get_updates( @@ -105,13 +109,13 @@ class _FakeBot: offset: int | None, timeout_s: int = 50, allowed_updates: list[str] | None = None, - ) -> list[dict] | None: + ) -> list[dict[str, Any]] | None: _ = offset _ = timeout_s _ = allowed_updates return [] - async def get_file(self, file_id: str) -> dict | None: + async def get_file(self, file_id: str) -> dict[str, Any] | None: _ = file_id return None @@ -125,18 +129,20 @@ class _FakeBot: text: str, reply_to_message_id: int | None = None, disable_notification: bool | None = False, - entities: list[dict] | None = None, + message_thread_id: int | None = None, + entities: list[dict[str, Any]] | None = None, parse_mode: str | None = None, reply_markup: dict | None = None, *, replace_message_id: int | None = None, - ) -> dict: + ) -> dict[str, Any]: self.send_calls.append( { "chat_id": chat_id, "text": text, "reply_to_message_id": reply_to_message_id, "disable_notification": disable_notification, + "message_thread_id": message_thread_id, "entities": entities, "parse_mode": parse_mode, "reply_markup": reply_markup, @@ -150,12 +156,12 @@ class _FakeBot: chat_id: int, message_id: int, text: str, - entities: list[dict] | None = None, + entities: list[dict[str, Any]] | None = None, parse_mode: str | None = None, reply_markup: dict | None = None, *, wait: bool = True, - ) -> dict: + ) -> dict[str, Any]: self.edit_calls.append( { "chat_id": chat_id, @@ -175,9 +181,9 @@ class _FakeBot: async def set_my_commands( self, - commands: list[dict], + commands: list[dict[str, Any]], *, - scope: dict | None = None, + scope: dict[str, Any] | None = None, language_code: str | None = None, ) -> bool: self.command_calls.append( @@ -189,9 +195,39 @@ class _FakeBot: ) return True - async def get_me(self) -> dict | None: + async def get_me(self) -> dict[str, Any] | None: return {"id": 1} + async def get_chat(self, chat_id: int) -> dict[str, Any] | None: + _ = chat_id + return {"id": chat_id, "type": "supergroup", "is_forum": True} + + async def get_chat_member( + self, chat_id: int, user_id: int + ) -> dict[str, Any] | None: + _ = chat_id + _ = user_id + return {"status": "administrator", "can_manage_topics": True} + + async def create_forum_topic( + self, chat_id: int, name: str + ) -> dict[str, Any] | None: + _ = chat_id + _ = name + return {"message_thread_id": 1} + + async def edit_forum_topic( + self, chat_id: int, message_thread_id: int, name: str + ) -> bool: + self.edit_topic_calls.append( + { + "chat_id": chat_id, + "message_thread_id": message_thread_id, + "name": name, + } + ) + return True + async def close(self) -> None: return None @@ -457,19 +493,19 @@ async def test_telegram_transport_passes_reply_markup() -> None: @pytest.mark.anyio async def test_telegram_transport_edit_wait_false_returns_ref() -> None: - class _OutboxBot: + class _OutboxBot(BotClient): def __init__(self) -> None: - self.edit_calls: list[dict[str, object]] = [] + self.edit_calls: list[dict[str, Any]] = [] async def get_updates( self, offset: int | None, timeout_s: int = 50, allowed_updates: list[str] | None = None, - ) -> list[dict] | None: + ) -> list[dict[str, Any]] | None: return None - async def get_file(self, file_id: str) -> dict | None: + async def get_file(self, file_id: str) -> dict[str, Any] | None: _ = file_id return None @@ -483,7 +519,8 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: text: str, reply_to_message_id: int | None = None, disable_notification: bool | None = False, - entities: list[dict] | None = None, + message_thread_id: int | None = None, + entities: list[dict[str, Any]] | None = None, parse_mode: str | None = None, reply_markup: dict | None = None, *, @@ -497,7 +534,7 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: chat_id: int, message_id: int, text: str, - entities: list[dict] | None = None, + entities: list[dict[str, Any]] | None = None, parse_mode: str | None = None, reply_markup: dict | None = None, *, @@ -527,14 +564,14 @@ async def test_telegram_transport_edit_wait_false_returns_ref() -> None: async def set_my_commands( self, - commands: list[dict[str, object]], + commands: list[dict[str, Any]], *, - scope: dict[str, object] | None = None, + scope: dict[str, Any] | None = None, language_code: str | None = None, ) -> bool: return False - async def get_me(self) -> dict | None: + async def get_me(self) -> dict[str, Any] | None: return None async def close(self) -> None: @@ -755,11 +792,115 @@ def test_resolve_message_accepts_backticked_ctx_line() -> None: assert resolved.context == RunContext(project="takopi", branch="feat/api") +def test_topic_title_matches_command_syntax() -> None: + transport = _FakeTransport() + cfg = _make_cfg(transport) + + title = bridge._topic_title( + cfg=cfg, + runtime=cfg.runtime, + context=RunContext(project="takopi", branch="master"), + ) + + assert title == "takopi @master" + + title = bridge._topic_title( + cfg=cfg, + runtime=cfg.runtime, + context=RunContext(project="takopi", branch=None), + ) + + assert title == "takopi" + + title = bridge._topic_title( + cfg=cfg, + runtime=cfg.runtime, + context=RunContext(project=None, branch="main"), + ) + + assert title == "@main" + + +def test_topic_title_per_project_chat_includes_project() -> None: + transport = _FakeTransport() + cfg = replace( + _make_cfg(transport), + topics=bridge.TelegramTopicsConfig( + enabled=True, + mode="per_project_chat", + ), + ) + + title = bridge._topic_title( + cfg=cfg, + runtime=cfg.runtime, + context=RunContext(project="takopi", branch="master"), + ) + + assert title == "takopi @master" + + +@pytest.mark.anyio +async def test_maybe_rename_topic_updates_title(tmp_path: Path) -> None: + transport = _FakeTransport() + cfg = _make_cfg(transport) + store = TopicStateStore(tmp_path / "telegram_topics_state.json") + + await store.set_context( + 123, + 77, + RunContext(project="takopi", branch="old"), + topic_title="takopi @old", + ) + + await bridge._maybe_rename_topic( + cfg, + store, + chat_id=123, + thread_id=77, + context=RunContext(project="takopi", branch="new"), + ) + + bot = cast(_FakeBot, cfg.bot) + assert bot.edit_topic_calls + assert bot.edit_topic_calls[-1]["name"] == "takopi @new" + snapshot = await store.get_thread(123, 77) + assert snapshot is not None + assert snapshot.topic_title == "takopi @new" + + +@pytest.mark.anyio +async def test_maybe_rename_topic_skips_when_title_matches(tmp_path: Path) -> None: + transport = _FakeTransport() + cfg = _make_cfg(transport) + store = TopicStateStore(tmp_path / "telegram_topics_state.json") + + await store.set_context( + 123, + 77, + RunContext(project="takopi", branch="main"), + topic_title="takopi @main", + ) + snapshot = await store.get_thread(123, 77) + + await bridge._maybe_rename_topic( + cfg, + store, + chat_id=123, + thread_id=77, + context=RunContext(project="takopi", branch="main"), + snapshot=snapshot, + ) + + bot = cast(_FakeBot, cfg.bot) + assert bot.edit_topic_calls == [] + + @pytest.mark.anyio async def test_send_with_resume_waits_for_token() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) - sent: list[tuple[int, int, str, ResumeToken, RunContext | None]] = [] + sent: list[tuple[int, int, str, ResumeToken, RunContext | None, int | None]] = [] async def enqueue( chat_id: int, @@ -767,8 +908,9 @@ async def test_send_with_resume_waits_for_token() -> None: text: str, resume: ResumeToken, context: RunContext | None, + thread_id: int | None, ) -> None: - sent.append((chat_id, user_msg_id, text, resume, context)) + sent.append((chat_id, user_msg_id, text, resume, context, thread_id)) running_task = RunningTask() @@ -785,11 +927,19 @@ async def test_send_with_resume_waits_for_token() -> None: running_task, 123, 10, + None, "hello", ) assert sent == [ - (123, 10, "hello", ResumeToken(engine=CODEX_ENGINE, value="abc123"), None) + ( + 123, + 10, + "hello", + ResumeToken(engine=CODEX_ENGINE, value="abc123"), + None, + None, + ) ] assert transport.send_calls == [] @@ -798,7 +948,7 @@ async def test_send_with_resume_waits_for_token() -> None: async def test_send_with_resume_reports_when_missing() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) - sent: list[tuple[int, int, str, ResumeToken, RunContext | None]] = [] + sent: list[tuple[int, int, str, ResumeToken, RunContext | None, int | None]] = [] async def enqueue( chat_id: int, @@ -806,8 +956,9 @@ async def test_send_with_resume_reports_when_missing() -> None: text: str, resume: ResumeToken, context: RunContext | None, + thread_id: int | None, ) -> None: - sent.append((chat_id, user_msg_id, text, resume, context)) + sent.append((chat_id, user_msg_id, text, resume, context, thread_id)) running_task = RunningTask() running_task.done.set() @@ -818,6 +969,7 @@ async def test_send_with_resume_reports_when_missing() -> None: running_task, 123, 10, + None, "hello", ) @@ -903,6 +1055,75 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None: tg.cancel_scope.cancel() +@pytest.mark.anyio +async def test_run_main_loop_persists_topic_sessions_in_per_project_chat( + tmp_path: Path, +) -> None: + project_chat_id = -100 + resume_value = "resume-123" + + transport = _FakeTransport() + bot = _FakeBot() + runner = ScriptRunner( + [Return(answer="ok")], + engine=CODEX_ENGINE, + resume_value=resume_value, + ) + exec_cfg = ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ) + projects = ProjectsConfig( + projects={ + "takopi": ProjectConfig( + alias="takopi", + path=Path("."), + worktrees_dir=Path(".worktrees"), + chat_id=project_chat_id, + ) + }, + default_project=None, + chat_map={project_chat_id: "takopi"}, + ) + runtime = TransportRuntime( + router=_make_router(runner), + projects=projects, + config_path=tmp_path / "takopi.toml", + ) + cfg = TelegramBridgeConfig( + bot=bot, + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=exec_cfg, + topics=bridge.TelegramTopicsConfig( + enabled=True, + mode="per_project_chat", + ), + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=project_chat_id, + message_id=1, + text="hello", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + thread_id=77, + ) + + with anyio.fail_after(2): + await run_main_loop(cfg, poller) + + state_path = resolve_state_path(runtime.config_path or tmp_path / "takopi.toml") + store = TopicStateStore(state_path) + stored = await store.get_session_resume(project_chat_id, 77, CODEX_ENGINE) + assert stored == ResumeToken(engine=CODEX_ENGINE, value=resume_value) + + @pytest.mark.anyio async def test_run_main_loop_handles_command_plugins(monkeypatch) -> None: class _Command: diff --git a/tests/test_telegram_incoming.py b/tests/test_telegram_incoming.py index 5f7c8bb..6b08a10 100644 --- a/tests/test_telegram_incoming.py +++ b/tests/test_telegram_incoming.py @@ -11,7 +11,7 @@ def test_parse_incoming_update_maps_fields() -> None: "message": { "message_id": 10, "text": "hello", - "chat": {"id": 123}, + "chat": {"id": 123, "type": "supergroup", "is_forum": True}, "from": {"id": 99}, "reply_to_message": {"message_id": 5, "text": "prev"}, }, @@ -27,6 +27,10 @@ def test_parse_incoming_update_maps_fields() -> None: assert msg.reply_to_message_id == 5 assert msg.reply_to_text == "prev" assert msg.sender_id == 99 + assert msg.thread_id is None + assert msg.is_topic_message is None + assert msg.chat_type == "supergroup" + assert msg.is_forum is True assert msg.voice is None assert msg.raw == update["message"] @@ -102,3 +106,23 @@ def test_parse_incoming_update_callback_query() -> None: assert msg.callback_query_id == "cbq-1" assert msg.data == "takopi:cancel" assert msg.sender_id == 321 + + +def test_parse_incoming_update_topic_fields() -> None: + update = { + "update_id": 1, + "message": { + "message_id": 10, + "text": "hello", + "message_thread_id": 77, + "is_topic_message": True, + "chat": {"id": -100, "type": "supergroup", "is_forum": True}, + }, + } + + msg = parse_incoming_update(update, chat_id=-100) + assert isinstance(msg, TelegramIncomingMessage) + assert msg.thread_id == 77 + assert msg.is_topic_message is True + assert msg.chat_type == "supergroup" + assert msg.is_forum is True diff --git a/tests/test_telegram_queue.py b/tests/test_telegram_queue.py index e0ae77a..d401712 100644 --- a/tests/test_telegram_queue.py +++ b/tests/test_telegram_queue.py @@ -1,14 +1,17 @@ +from typing import Any + import anyio import pytest -from takopi.telegram.client import TelegramClient, TelegramRetryAfter +from takopi.telegram.client import BotClient, TelegramClient, TelegramRetryAfter -class _FakeBot: +class _FakeBot(BotClient): def __init__(self) -> None: self.calls: list[str] = [] self.edit_calls: list[str] = [] self.delete_calls: list[tuple[int, int]] = [] + self.topic_calls: list[tuple[int, int, str]] = [] self._edit_attempts = 0 self._updates_attempts = 0 self.retry_after: float | None = None @@ -20,14 +23,16 @@ class _FakeBot: text: str, reply_to_message_id: int | None = None, disable_notification: bool | None = False, - entities: list[dict] | None = None, + message_thread_id: int | None = None, + entities: list[dict[str, Any]] | None = None, parse_mode: str | None = None, reply_markup: dict | None = None, *, replace_message_id: int | None = None, - ) -> dict: + ) -> dict[str, Any]: _ = reply_to_message_id _ = disable_notification + _ = message_thread_id _ = entities _ = parse_mode _ = reply_markup @@ -40,12 +45,12 @@ class _FakeBot: chat_id: int, message_id: int, text: str, - entities: list[dict] | None = None, + entities: list[dict[str, Any]] | None = None, parse_mode: str | None = None, reply_markup: dict | None = None, *, wait: bool = True, - ) -> dict: + ) -> dict[str, Any]: _ = chat_id _ = message_id _ = entities @@ -71,9 +76,9 @@ class _FakeBot: async def set_my_commands( self, - commands: list[dict], + commands: list[dict[str, Any]], *, - scope: dict | None = None, + scope: dict[str, Any] | None = None, language_code: str | None = None, ) -> bool: _ = commands @@ -86,7 +91,7 @@ class _FakeBot: offset: int | None, timeout_s: int = 50, allowed_updates: list[str] | None = None, - ) -> list[dict] | None: + ) -> list[dict[str, Any]] | None: _ = offset _ = timeout_s _ = allowed_updates @@ -96,7 +101,7 @@ class _FakeBot: self._updates_attempts += 1 return [] - async def get_file(self, file_id: str) -> dict | None: + async def get_file(self, file_id: str) -> dict[str, Any] | None: _ = file_id return None @@ -107,7 +112,7 @@ class _FakeBot: async def close(self) -> None: return None - async def get_me(self) -> dict | None: + async def get_me(self) -> dict[str, Any] | None: return {"id": 1} async def answer_callback_query( @@ -119,6 +124,27 @@ class _FakeBot: _ = callback_query_id, text, show_alert return True + async def edit_forum_topic( + self, chat_id: int, message_thread_id: int, name: str + ) -> bool: + self.calls.append("edit_forum_topic") + self.topic_calls.append((chat_id, message_thread_id, name)) + return True + + +@pytest.mark.anyio +async def test_edit_forum_topic_uses_outbox() -> None: + bot = _FakeBot() + client = TelegramClient(client=bot, private_chat_rps=0.0, group_chat_rps=0.0) + + result = await client.edit_forum_topic( + chat_id=7, message_thread_id=42, name="takopi @main" + ) + + assert result is True + assert bot.calls == ["edit_forum_topic"] + assert bot.topic_calls == [(7, 42, "takopi @main")] + @pytest.mark.anyio async def test_edits_coalesce_latest() -> None: diff --git a/tests/test_telegram_topic_state.py b/tests/test_telegram_topic_state.py new file mode 100644 index 0000000..3c839d7 --- /dev/null +++ b/tests/test_telegram_topic_state.py @@ -0,0 +1,49 @@ +import pytest + +from takopi.context import RunContext +from takopi.model import ResumeToken +from takopi.telegram.topic_state import TopicStateStore + + +@pytest.mark.anyio +async def test_topic_state_store_roundtrip(tmp_path) -> None: + path = tmp_path / "telegram_topics_state.json" + store = TopicStateStore(path) + context = RunContext(project="proj", branch="feat/topic") + await store.set_context(1, 10, context) + await store.set_session_resume(1, 10, ResumeToken(engine="codex", value="abc123")) + + snapshot = await store.get_thread(1, 10) + assert snapshot is not None + assert snapshot.context == context + assert snapshot.sessions == {"codex": "abc123"} + + store2 = TopicStateStore(path) + snapshot2 = await store2.get_thread(1, 10) + assert snapshot2 is not None + assert snapshot2.context == context + assert snapshot2.sessions == {"codex": "abc123"} + + +@pytest.mark.anyio +async def test_topic_state_store_clear_and_find(tmp_path) -> None: + path = tmp_path / "telegram_topics_state.json" + store = TopicStateStore(path) + context = RunContext(project="proj", branch="main") + await store.set_context(2, 20, context) + await store.set_session_resume( + 2, 20, ResumeToken(engine="claude", value="resume-token") + ) + + found = await store.find_thread_for_context(2, context) + assert found == 20 + + await store.clear_sessions(2, 20) + snapshot = await store.get_thread(2, 20) + assert snapshot is not None + assert snapshot.sessions == {} + + await store.clear_context(2, 20) + snapshot = await store.get_thread(2, 20) + assert snapshot is not None + assert snapshot.context is None diff --git a/tests/test_transport_runtime.py b/tests/test_transport_runtime.py index 8cd6db6..0bad688 100644 --- a/tests/test_transport_runtime.py +++ b/tests/test_transport_runtime.py @@ -71,3 +71,93 @@ def test_resolve_message_defaults_to_chat_project() -> None: ) assert resolved.context == RunContext(project="proj", branch=None) + + +def test_resolve_message_uses_ambient_context() -> None: + runtime = _make_runtime() + ambient = RunContext(project="proj", branch="feat/ambient") + + resolved = runtime.resolve_message( + text="hello", + reply_text=None, + ambient_context=ambient, + ) + + assert resolved.context == ambient + assert resolved.context_source == "ambient" + + +def test_resolve_message_reply_ctx_overrides_ambient() -> None: + runtime = _make_runtime() + ambient = RunContext(project="proj", branch="feat/ambient") + + resolved = runtime.resolve_message( + text="hello", + reply_text="`ctx: proj @ reply`", + ambient_context=ambient, + ) + + assert resolved.context == RunContext(project="proj", branch="reply") + assert resolved.context_source == "reply_ctx" + + +def test_resolve_message_directives_override_ambient() -> None: + runtime = _make_runtime() + ambient = RunContext(project="proj", branch="feat/ambient") + + resolved = runtime.resolve_message( + text="/proj @main do it", + reply_text=None, + ambient_context=ambient, + ) + + assert resolved.context == RunContext(project="proj", branch="main") + assert resolved.context_source == "directives" + + +def test_resolve_message_branch_directive_merges_with_ambient_project() -> None: + runtime = _make_runtime() + ambient = RunContext(project="proj", branch="feat/ambient") + + resolved = runtime.resolve_message( + text="@hotfix do it", + reply_text=None, + ambient_context=ambient, + ) + + assert resolved.context == RunContext(project="proj", branch="hotfix") + assert resolved.context_source == "directives" + + +def test_resolve_message_project_directive_clears_ambient_branch() -> None: + codex = ScriptRunner([Return(answer="ok")], engine="codex") + router = AutoRouter( + entries=[RunnerEntry(engine=codex.engine, runner=codex)], + default_engine=codex.engine, + ) + projects = ProjectsConfig( + projects={ + "proj": ProjectConfig( + alias="proj", + path=Path("."), + worktrees_dir=Path(".worktrees"), + ), + "other": ProjectConfig( + alias="other", + path=Path("."), + worktrees_dir=Path(".worktrees"), + ), + }, + default_project=None, + ) + runtime = TransportRuntime(router=router, projects=projects) + ambient = RunContext(project="proj", branch="feat/ambient") + + resolved = runtime.resolve_message( + text="/other do it", + reply_text=None, + ambient_context=ambient, + ) + + assert resolved.context == RunContext(project="other", branch=None) + assert resolved.context_source == "directives"