fix(telegram): align prompt upload resume flow (#105)

This commit is contained in:
banteg
2026-01-12 21:41:18 +04:00
committed by GitHub
parent 8b2903ffa3
commit 88591216f9
3 changed files with 416 additions and 17 deletions
+31 -4
View File
@@ -763,7 +763,12 @@ async def _handle_media_group(
messages: Sequence[TelegramIncomingMessage], messages: Sequence[TelegramIncomingMessage],
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
run_prompt: Callable[ run_prompt: Callable[
[TelegramIncomingMessage, str, RunContext | None], Awaitable[None] [TelegramIncomingMessage, str, ResolvedMessage], Awaitable[None]
]
| None = None,
resolve_prompt: Callable[
[TelegramIncomingMessage, str, RunContext | None],
Awaitable[ResolvedMessage | None],
] ]
| None = None, | None = None,
) -> None: ) -> None:
@@ -810,12 +815,29 @@ async def _handle_media_group(
if cfg.files.enabled and cfg.files.auto_put: if cfg.files.enabled and cfg.files.auto_put:
caption_text = command_msg.text.strip() caption_text = command_msg.text.strip()
if cfg.files.auto_put_mode == "prompt" and caption_text: if cfg.files.auto_put_mode == "prompt" and caption_text:
if resolve_prompt is None:
try:
resolved = cfg.runtime.resolve_message(
text=caption_text,
reply_text=command_msg.reply_to_text,
ambient_context=ambient_context,
chat_id=command_msg.chat_id,
)
except DirectiveError as exc:
await reply(text=f"error:\n{exc}")
return
else:
resolved = await resolve_prompt(
command_msg, caption_text, ambient_context
)
if resolved is None:
return
saved_group = await _save_file_put_group( saved_group = await _save_file_put_group(
cfg, cfg,
command_msg, command_msg,
"", "",
ordered, ordered,
ambient_context, resolved.context,
topic_store, topic_store,
) )
if saved_group is None: if saved_group is None:
@@ -840,8 +862,13 @@ async def _handle_media_group(
if item.rel_path is not None if item.rel_path is not None
] ]
files_text = "\n".join(f"- {path}" for path in paths) files_text = "\n".join(f"- {path}" for path in paths)
prompt = f"{caption_text}\n\n[uploaded files]\n{files_text}" prompt_base = resolved.prompt
await run_prompt(command_msg, prompt, saved_group.context) annotation = f"[uploaded files]\n{files_text}"
if prompt_base and prompt_base.strip():
prompt = f"{prompt_base}\n\n{annotation}"
else:
prompt = annotation
await run_prompt(command_msg, prompt, resolved)
return return
if not caption_text: if not caption_text:
await _handle_file_put_group( await _handle_file_put_group(
+160 -13
View File
@@ -17,6 +17,7 @@ from ..model import EngineId, ResumeToken
from ..scheduler import ThreadJob, ThreadScheduler from ..scheduler import ThreadJob, ThreadScheduler
from ..settings import TelegramTransportSettings from ..settings import TelegramTransportSettings
from ..transport import MessageRef from ..transport import MessageRef
from ..transport_runtime import ResolvedMessage
from ..context import RunContext from ..context import RunContext
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
from .commands import ( from .commands import (
@@ -484,29 +485,168 @@ async def run_main_loop(
scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job) scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
def _build_upload_prompt(base: str, annotation: str) -> str:
if base and base.strip():
return f"{base}\n\n{annotation}"
return annotation
async def resolve_prompt_message(
    msg: TelegramIncomingMessage,
    text: str,
    ambient_context: RunContext | None,
) -> ResolvedMessage | None:
    """Resolve an upload caption into a ``ResolvedMessage``.

    Runs ``cfg.runtime.resolve_message`` on ``text``. When a topic store and
    topic key are available, a context that came from directives is persisted
    to the topic (and the topic is possibly renamed); a topic with no usable
    context gets a usage hint instead.

    Returns ``None`` after replying to the user when the directives are
    invalid or when the topic is not bound to a project; otherwise returns
    the resolved message.
    """
    reply = partial(
        send_plain,
        cfg.exec_cfg.transport,
        chat_id=msg.chat_id,
        user_msg_id=msg.message_id,
        thread_id=msg.thread_id,
    )

    try:
        resolved = cfg.runtime.resolve_message(
            text=text,
            reply_text=msg.reply_to_text,
            ambient_context=ambient_context,
            chat_id=msg.chat_id,
        )
    except DirectiveError as exc:
        await reply(text=f"error:\n{exc}")
        return None

    # Everything below only applies when the message maps onto a topic.
    if topic_store is None:
        return resolved
    topic_key = _topic_key(msg, cfg, scope_chat_ids=topics_chat_ids)
    if topic_key is None:
        return resolved

    # Directives in the caption bind the topic to the resolved context.
    if resolved.context is not None and resolved.context_source == "directives":
        await topic_store.set_context(*topic_key, resolved.context)
        await _maybe_rename_topic(
            cfg,
            topic_store,
            chat_id=topic_key[0],
            thread_id=topic_key[1],
            context=resolved.context,
        )
        return resolved

    # No directive-provided context: an ambient context, or a context source
    # of "directives"/"reply_ctx", is enough to proceed.
    if ambient_context is not None or resolved.context_source in {
        "directives",
        "reply_ctx",
    }:
        return resolved

    # Otherwise tell the user how to bind the topic and abort.
    chat_project = None
    if cfg.topics.enabled:
        chat_project = _topics_chat_project(cfg, msg.chat_id)
    await reply(
        text="this topic isn't bound to a project yet.\n"
        f"{_usage_ctx_set(chat_project=chat_project)} or "
        f"{_usage_topic(chat_project=chat_project)}",
    )
    return None
async def run_prompt_from_upload( async def run_prompt_from_upload(
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
prompt_text: str, prompt_text: str,
context: RunContext | None, resolved: ResolvedMessage,
) -> None: ) -> None:
chat_id = msg.chat_id
user_msg_id = msg.message_id
reply_id = msg.reply_to_message_id
reply_ref = ( reply_ref = (
MessageRef( MessageRef(
channel_id=msg.chat_id, channel_id=msg.chat_id,
message_id=msg.reply_to_message_id, message_id=msg.reply_to_message_id,
thread_id=msg.thread_id,
) )
if msg.reply_to_message_id is not None if msg.reply_to_message_id is not None
else None else None
) )
await run_job( resume_token = resolved.resume_token
msg.chat_id, engine_override = resolved.engine_override
msg.message_id, context = resolved.context
chat_session_key = _chat_session_key(msg, store=chat_session_store)
topic_key = (
_topic_key(msg, cfg, scope_chat_ids=topics_chat_ids)
if topic_store is not None
else None
)
if resume_token is None and reply_id is not None:
running_task = running_tasks.get(
MessageRef(channel_id=chat_id, message_id=reply_id)
)
if running_task is not None:
tg.start_soon(
send_with_resume,
cfg,
scheduler.enqueue_resume,
running_task,
chat_id,
user_msg_id,
msg.thread_id,
chat_session_key,
prompt_text,
)
return
if (
resume_token is None
and topic_store is not None
and topic_key is not None
):
engine_for_session = cfg.runtime.resolve_engine(
engine_override=engine_override,
context=context,
)
stored = await topic_store.get_session_resume(
topic_key[0], topic_key[1], engine_for_session
)
if stored is not None:
resume_token = stored
if (
resume_token is None
and chat_session_store is not None
and chat_session_key is not None
):
engine_for_session = cfg.runtime.resolve_engine(
engine_override=engine_override,
context=context,
)
stored = await chat_session_store.get_session_resume(
chat_session_key[0],
chat_session_key[1],
engine_for_session,
)
if stored is not None:
resume_token = stored
if resume_token is None:
await run_job(
chat_id,
user_msg_id,
prompt_text,
None,
context,
msg.thread_id,
chat_session_key,
reply_ref,
scheduler.note_thread_known,
engine_override,
)
return
await scheduler.enqueue_resume(
chat_id,
user_msg_id,
prompt_text, prompt_text,
None, resume_token,
context, context,
msg.thread_id, msg.thread_id,
None, chat_session_key,
reply_ref,
scheduler.note_thread_known,
) )
async def handle_prompt_upload( async def handle_prompt_upload(
@@ -515,19 +655,25 @@ async def run_main_loop(
ambient_context: RunContext | None, ambient_context: RunContext | None,
topic_store: TopicStateStore | None, topic_store: TopicStateStore | None,
) -> None: ) -> None:
resolved = await resolve_prompt_message(
msg,
caption_text,
ambient_context,
)
if resolved is None:
return
saved = await _save_file_put( saved = await _save_file_put(
cfg, cfg,
msg, msg,
"", "",
ambient_context, resolved.context,
topic_store, topic_store,
) )
if saved is None: if saved is None:
return return
prompt = ( annotation = f"[uploaded file: {saved.rel_path.as_posix()}]"
f"{caption_text}\n\n[uploaded file: {saved.rel_path.as_posix()}]" prompt = _build_upload_prompt(resolved.prompt, annotation)
) await run_prompt_from_upload(msg, prompt, resolved)
await run_prompt_from_upload(msg, prompt, saved.context)
async def flush_media_group(key: tuple[int, str]) -> None: async def flush_media_group(key: tuple[int, str]) -> None:
while True: while True:
@@ -548,6 +694,7 @@ async def run_main_loop(
messages, messages,
topic_store, topic_store,
run_prompt_from_upload, run_prompt_from_upload,
resolve_prompt_message,
) )
return return
+225
View File
@@ -1536,6 +1536,231 @@ async def test_run_main_loop_auto_resumes_chat_sessions(tmp_path: Path) -> None:
assert runner2.calls[0][1] == ResumeToken(engine=CODEX_ENGINE, value=resume_value) assert runner2.calls[0][1] == ResumeToken(engine=CODEX_ENGINE, value=resume_value)
@pytest.mark.anyio
async def test_run_main_loop_prompt_upload_uses_caption_directives(
    tmp_path: Path,
) -> None:
    """A ``/other`` directive in the caption steers a prompt-mode upload.

    The document must be written under the ``other`` project's ``incoming``
    directory, and the prompt given to the runner must start with the caption
    text, omit the directive, and include the uploaded-file annotation.
    """
    file_bytes = b"hello"

    # Two projects, with "proj" as the default, so the directive has to
    # actively select the non-default one.
    project_dirs = {alias: tmp_path / alias for alias in ("proj", "other")}
    for directory in project_dirs.values():
        directory.mkdir()

    class _UploadBot(_FakeBot):
        async def get_file(self, file_id: str) -> File | None:
            del file_id
            return File(file_path="files/hello.txt")

        async def download_file(self, file_path: str) -> bytes | None:
            del file_path
            return file_bytes

    runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
    projects = ProjectsConfig(
        projects={
            alias: ProjectConfig(
                alias=alias,
                path=directory,
                worktrees_dir=Path(".worktrees"),
            )
            for alias, directory in project_dirs.items()
        },
        default_project="proj",
    )
    cfg = TelegramBridgeConfig(
        bot=_UploadBot(),
        runtime=TransportRuntime(router=_make_router(runner), projects=projects),
        chat_id=123,
        startup_msg="",
        exec_cfg=ExecBridgeConfig(
            transport=_FakeTransport(),
            presenter=MarkdownPresenter(),
            final_notify=True,
        ),
        files=TelegramFilesSettings(
            enabled=True,
            auto_put=True,
            auto_put_mode="prompt",
        ),
    )

    async def poller(_cfg: TelegramBridgeConfig):
        # A single document message whose caption carries the directive.
        yield TelegramIncomingMessage(
            transport="telegram",
            chat_id=123,
            message_id=1,
            text="/other do thing",
            reply_to_message_id=None,
            reply_to_text=None,
            sender_id=123,
            chat_type="private",
            document=TelegramDocument(
                file_id="doc-1",
                file_name="hello.txt",
                mime_type="text/plain",
                file_size=len(file_bytes),
                raw={"file_id": "doc-1"},
            ),
        )

    await run_main_loop(cfg, poller)

    stored_file = project_dirs["other"] / "incoming" / "hello.txt"
    assert stored_file.read_bytes() == file_bytes

    assert runner.calls
    sent_prompt, _ = runner.calls[0]
    assert sent_prompt.startswith("do thing")
    assert "/other" not in sent_prompt
    assert "[uploaded file: incoming/hello.txt]" in sent_prompt
@pytest.mark.anyio
async def test_run_main_loop_prompt_upload_auto_resumes_chat_sessions(
    tmp_path: Path,
) -> None:
    """Prompt-mode uploads take part in chat-session auto-resume.

    The first upload runs with a runner that reports ``resume_value``; the
    token must end up in the chat session store. A second upload, served by a
    fresh runtime that shares the same state file, must then be run with that
    stored token.
    """
    file_bytes = b"hello"
    resume_value = "resume-123"
    expected_token = ResumeToken(engine=CODEX_ENGINE, value=resume_value)

    state_path = tmp_path / "takopi.toml"
    project_dir = tmp_path / "proj"
    project_dir.mkdir()

    class _UploadBot(_FakeBot):
        async def get_file(self, file_id: str) -> File | None:
            del file_id
            return File(file_path="files/hello.txt")

        async def download_file(self, file_path: str) -> bytes | None:
            del file_path
            return file_bytes

    projects = ProjectsConfig(
        projects={
            "proj": ProjectConfig(
                alias="proj",
                path=project_dir,
                worktrees_dir=Path(".worktrees"),
            )
        },
        default_project="proj",
    )
    # Both rounds share the same bot instance.
    bot = _UploadBot()

    def make_cfg(runner: ScriptRunner) -> TelegramBridgeConfig:
        # Each round gets its own transport and runtime, all pointed at the
        # same state file.
        return TelegramBridgeConfig(
            bot=bot,
            runtime=TransportRuntime(
                router=_make_router(runner),
                projects=projects,
                config_path=state_path,
            ),
            chat_id=123,
            startup_msg="",
            exec_cfg=ExecBridgeConfig(
                transport=_FakeTransport(),
                presenter=MarkdownPresenter(),
                final_notify=True,
            ),
            session_mode="chat",
            files=TelegramFilesSettings(
                enabled=True,
                auto_put=True,
                auto_put_mode="prompt",
            ),
        )

    def make_poller(message_id: int, text: str, file_id: str, file_name: str):
        # Produces a poller that yields exactly one captioned document.
        async def poller(_cfg: TelegramBridgeConfig):
            yield TelegramIncomingMessage(
                transport="telegram",
                chat_id=123,
                message_id=message_id,
                text=text,
                reply_to_message_id=None,
                reply_to_text=None,
                sender_id=123,
                chat_type="private",
                document=TelegramDocument(
                    file_id=file_id,
                    file_name=file_name,
                    mime_type="text/plain",
                    file_size=len(file_bytes),
                    raw={"file_id": file_id},
                ),
            )

        return poller

    # Round 1: the upload run should record its resume token for the chat.
    first_runner = ScriptRunner(
        [Return(answer="ok")],
        engine=CODEX_ENGINE,
        resume_value=resume_value,
    )
    await run_main_loop(
        make_cfg(first_runner),
        make_poller(1, "hello", "doc-1", "hello.txt"),
    )

    store = ChatSessionStore(resolve_sessions_path(state_path))
    stored = await store.get_session_resume(123, None, CODEX_ENGINE)
    assert stored == expected_token

    # Round 2: a follow-up upload should be run with the stored token.
    second_runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
    await run_main_loop(
        make_cfg(second_runner),
        make_poller(2, "followup", "doc-2", "hello2.txt"),
    )

    assert second_runner.calls[0][1] == expected_token
@pytest.mark.anyio @pytest.mark.anyio
async def test_run_main_loop_hides_resume_line_when_disabled( async def test_run_main_loop_hides_resume_line_when_disabled(
tmp_path: Path, tmp_path: Path,