diff --git a/src/takopi/telegram/commands.py b/src/takopi/telegram/commands.py index ad125d0..602cbd3 100644 --- a/src/takopi/telegram/commands.py +++ b/src/takopi/telegram/commands.py @@ -763,7 +763,12 @@ async def _handle_media_group( messages: Sequence[TelegramIncomingMessage], topic_store: TopicStateStore | None, run_prompt: Callable[ - [TelegramIncomingMessage, str, RunContext | None], Awaitable[None] + [TelegramIncomingMessage, str, ResolvedMessage], Awaitable[None] + ] + | None = None, + resolve_prompt: Callable[ + [TelegramIncomingMessage, str, RunContext | None], + Awaitable[ResolvedMessage | None], ] | None = None, ) -> None: @@ -810,12 +815,29 @@ async def _handle_media_group( if cfg.files.enabled and cfg.files.auto_put: caption_text = command_msg.text.strip() if cfg.files.auto_put_mode == "prompt" and caption_text: + if resolve_prompt is None: + try: + resolved = cfg.runtime.resolve_message( + text=caption_text, + reply_text=command_msg.reply_to_text, + ambient_context=ambient_context, + chat_id=command_msg.chat_id, + ) + except DirectiveError as exc: + await reply(text=f"error:\n{exc}") + return + else: + resolved = await resolve_prompt( + command_msg, caption_text, ambient_context + ) + if resolved is None: + return saved_group = await _save_file_put_group( cfg, command_msg, "", ordered, - ambient_context, + resolved.context, topic_store, ) if saved_group is None: @@ -840,8 +862,13 @@ async def _handle_media_group( if item.rel_path is not None ] files_text = "\n".join(f"- {path}" for path in paths) - prompt = f"{caption_text}\n\n[uploaded files]\n{files_text}" - await run_prompt(command_msg, prompt, saved_group.context) + prompt_base = resolved.prompt + annotation = f"[uploaded files]\n{files_text}" + if prompt_base and prompt_base.strip(): + prompt = f"{prompt_base}\n\n{annotation}" + else: + prompt = annotation + await run_prompt(command_msg, prompt, resolved) return if not caption_text: await _handle_file_put_group( diff --git a/src/takopi/telegram/loop.py b/src/takopi/telegram/loop.py index 1c5194b..2ed780f 100644 --- a/src/takopi/telegram/loop.py +++ b/src/takopi/telegram/loop.py @@ -17,6 +17,7 @@ from ..model import EngineId, ResumeToken from ..scheduler import ThreadJob, ThreadScheduler from ..settings import TelegramTransportSettings from ..transport import MessageRef +from ..transport_runtime import ResolvedMessage from ..context import RunContext from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain from .commands import ( @@ -484,29 +485,168 @@ async def run_main_loop( scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job) + def _build_upload_prompt(base: str, annotation: str) -> str: + if base and base.strip(): + return f"{base}\n\n{annotation}" + return annotation + + async def resolve_prompt_message( + msg: TelegramIncomingMessage, + text: str, + ambient_context: RunContext | None, + ) -> ResolvedMessage | None: + reply = partial( + send_plain, + cfg.exec_cfg.transport, + chat_id=msg.chat_id, + user_msg_id=msg.message_id, + thread_id=msg.thread_id, + ) + try: + resolved = cfg.runtime.resolve_message( + text=text, + reply_text=msg.reply_to_text, + ambient_context=ambient_context, + chat_id=msg.chat_id, + ) + except DirectiveError as exc: + await reply(text=f"error:\n{exc}") + return None + topic_key = ( + _topic_key(msg, cfg, scope_chat_ids=topics_chat_ids) + if topic_store is not None + else None + ) + effective_context = ambient_context + if ( + topic_store is not None + and topic_key is not None + and resolved.context is not None + and resolved.context_source == "directives" + ): + await topic_store.set_context(*topic_key, resolved.context) + await _maybe_rename_topic( + cfg, + topic_store, + chat_id=topic_key[0], + thread_id=topic_key[1], + context=resolved.context, + ) + effective_context = resolved.context + if ( + topic_store is not None + and topic_key is not None + and effective_context is None + and resolved.context_source not in {"directives", "reply_ctx"} + ): + chat_project = ( + _topics_chat_project(cfg, msg.chat_id) + if cfg.topics.enabled + else None + ) + await reply( + text="this topic isn't bound to a project yet.\n" + f"{_usage_ctx_set(chat_project=chat_project)} or " + f"{_usage_topic(chat_project=chat_project)}", + ) + return None + return resolved + async def run_prompt_from_upload( msg: TelegramIncomingMessage, prompt_text: str, - context: RunContext | None, + resolved: ResolvedMessage, ) -> None: + chat_id = msg.chat_id + user_msg_id = msg.message_id + reply_id = msg.reply_to_message_id reply_ref = ( MessageRef( channel_id=msg.chat_id, message_id=msg.reply_to_message_id, + thread_id=msg.thread_id, ) if msg.reply_to_message_id is not None else None ) - await run_job( - msg.chat_id, - msg.message_id, + resume_token = resolved.resume_token + engine_override = resolved.engine_override + context = resolved.context + chat_session_key = _chat_session_key(msg, store=chat_session_store) + topic_key = ( + _topic_key(msg, cfg, scope_chat_ids=topics_chat_ids) + if topic_store is not None + else None + ) + if resume_token is None and reply_id is not None: + running_task = running_tasks.get( + MessageRef(channel_id=chat_id, message_id=reply_id) + ) + if running_task is not None: + tg.start_soon( + send_with_resume, + cfg, + scheduler.enqueue_resume, + running_task, + chat_id, + user_msg_id, + msg.thread_id, + chat_session_key, + prompt_text, + ) + return + if ( + resume_token is None + and topic_store is not None + and topic_key is not None + ): + engine_for_session = cfg.runtime.resolve_engine( + engine_override=engine_override, + context=context, + ) + stored = await topic_store.get_session_resume( + topic_key[0], topic_key[1], engine_for_session + ) + if stored is not None: + resume_token = stored + if ( + resume_token is None + and chat_session_store is not None + and chat_session_key is not None + ): + engine_for_session = cfg.runtime.resolve_engine( + engine_override=engine_override, + context=context, + ) + stored = await chat_session_store.get_session_resume( + chat_session_key[0], + chat_session_key[1], + engine_for_session, + ) + if stored is not None: + resume_token = stored + if resume_token is None: + await run_job( + chat_id, + user_msg_id, + prompt_text, + None, + context, + msg.thread_id, + chat_session_key, + reply_ref, + scheduler.note_thread_known, + engine_override, + ) + return + await scheduler.enqueue_resume( + chat_id, + user_msg_id, prompt_text, - None, + resume_token, context, msg.thread_id, - None, - reply_ref, - scheduler.note_thread_known, + chat_session_key, ) async def handle_prompt_upload( @@ -515,19 +655,25 @@ async def run_main_loop( ambient_context: RunContext | None, topic_store: TopicStateStore | None, ) -> None: + resolved = await resolve_prompt_message( + msg, + caption_text, + ambient_context, + ) + if resolved is None: + return saved = await _save_file_put( cfg, msg, "", - ambient_context, + resolved.context, topic_store, ) if saved is None: return - prompt = ( - f"{caption_text}\n\n[uploaded file: {saved.rel_path.as_posix()}]" - ) - await run_prompt_from_upload(msg, prompt, saved.context) + annotation = f"[uploaded file: {saved.rel_path.as_posix()}]" + prompt = _build_upload_prompt(resolved.prompt, annotation) + await run_prompt_from_upload(msg, prompt, resolved) async def flush_media_group(key: tuple[int, str]) -> None: while True: @@ -548,6 +694,7 @@ async def run_main_loop( messages, topic_store, run_prompt_from_upload, + resolve_prompt_message, ) return diff --git a/tests/test_telegram_bridge.py b/tests/test_telegram_bridge.py index c3e65e4..0e96758 100644 --- a/tests/test_telegram_bridge.py +++ b/tests/test_telegram_bridge.py @@ -1536,6 +1536,231 @@ async def test_run_main_loop_auto_resumes_chat_sessions(tmp_path: Path) -> None: assert runner2.calls[0][1] == ResumeToken(engine=CODEX_ENGINE, value=resume_value) +@pytest.mark.anyio +async def test_run_main_loop_prompt_upload_uses_caption_directives( + tmp_path: Path, +) -> None: + payload = b"hello" + proj_dir = tmp_path / "proj" + other_dir = tmp_path / "other" + proj_dir.mkdir() + other_dir.mkdir() + + class _UploadBot(_FakeBot): + async def get_file(self, file_id: str) -> File | None: + _ = file_id + return File(file_path="files/hello.txt") + + async def download_file(self, file_path: str) -> bytes | None: + _ = file_path + return payload + + transport = _FakeTransport() + bot = _UploadBot() + runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) + exec_cfg = ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ) + projects = ProjectsConfig( + projects={ + "proj": ProjectConfig( + alias="proj", + path=proj_dir, + worktrees_dir=Path(".worktrees"), + ), + "other": ProjectConfig( + alias="other", + path=other_dir, + worktrees_dir=Path(".worktrees"), + ), + }, + default_project="proj", + ) + runtime = TransportRuntime(router=_make_router(runner), projects=projects) + cfg = TelegramBridgeConfig( + bot=bot, + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=exec_cfg, + files=TelegramFilesSettings( + enabled=True, + auto_put=True, + auto_put_mode="prompt", + ), + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=1, + text="/other do thing", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + chat_type="private", + document=TelegramDocument( + file_id="doc-1", + file_name="hello.txt", + mime_type="text/plain", + file_size=len(payload), + raw={"file_id": "doc-1"}, + ), + ) + + await run_main_loop(cfg, poller) + + saved_path = other_dir / "incoming" / "hello.txt" + assert saved_path.read_bytes() == payload + assert runner.calls + prompt_text, _ = runner.calls[0] + assert prompt_text.startswith("do thing") + assert "/other" not in prompt_text + assert "[uploaded file: incoming/hello.txt]" in prompt_text + + +@pytest.mark.anyio +async def test_run_main_loop_prompt_upload_auto_resumes_chat_sessions( + tmp_path: Path, +) -> None: + payload = b"hello" + resume_value = "resume-123" + state_path = tmp_path / "takopi.toml" + project_dir = tmp_path / "proj" + project_dir.mkdir() + + class _UploadBot(_FakeBot): + async def get_file(self, file_id: str) -> File | None: + _ = file_id + return File(file_path="files/hello.txt") + + async def download_file(self, file_path: str) -> bytes | None: + _ = file_path + return payload + + projects = ProjectsConfig( + projects={ + "proj": ProjectConfig( + alias="proj", + path=project_dir, + worktrees_dir=Path(".worktrees"), + ) + }, + default_project="proj", + ) + bot = _UploadBot() + + transport = _FakeTransport() + runner = ScriptRunner( + [Return(answer="ok")], + engine=CODEX_ENGINE, + resume_value=resume_value, + ) + exec_cfg = ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ) + runtime = TransportRuntime( + router=_make_router(runner), + projects=projects, + config_path=state_path, + ) + cfg = TelegramBridgeConfig( + bot=bot, + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=exec_cfg, + session_mode="chat", + files=TelegramFilesSettings( + enabled=True, + auto_put=True, + auto_put_mode="prompt", + ), + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=1, + text="hello", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + chat_type="private", + document=TelegramDocument( + file_id="doc-1", + file_name="hello.txt", + mime_type="text/plain", + file_size=len(payload), + raw={"file_id": "doc-1"}, + ), + ) + + await run_main_loop(cfg, poller) + + store = ChatSessionStore(resolve_sessions_path(state_path)) + stored = await store.get_session_resume(123, None, CODEX_ENGINE) + assert stored == ResumeToken(engine=CODEX_ENGINE, value=resume_value) + + transport2 = _FakeTransport() + runner2 = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) + exec_cfg2 = ExecBridgeConfig( + transport=transport2, + presenter=MarkdownPresenter(), + final_notify=True, + ) + runtime2 = TransportRuntime( + router=_make_router(runner2), + projects=projects, + config_path=state_path, + ) + cfg2 = TelegramBridgeConfig( + bot=bot, + runtime=runtime2, + chat_id=123, + startup_msg="", + exec_cfg=exec_cfg2, + session_mode="chat", + files=TelegramFilesSettings( + enabled=True, + auto_put=True, + auto_put_mode="prompt", + ), + ) + + async def poller2(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=2, + text="followup", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + chat_type="private", + document=TelegramDocument( + file_id="doc-2", + file_name="hello2.txt", + mime_type="text/plain", + file_size=len(payload), + raw={"file_id": "doc-2"}, + ), + ) + + await run_main_loop(cfg2, poller2) + + assert runner2.calls[0][1] == ResumeToken( + engine=CODEX_ENGINE, + value=resume_value, + ) + + @pytest.mark.anyio async def test_run_main_loop_hides_resume_line_when_disabled( tmp_path: Path,