fix(telegram): align prompt upload resume flow (#105)

This commit is contained in:
banteg
2026-01-12 21:41:18 +04:00
committed by GitHub
parent 8b2903ffa3
commit 88591216f9
3 changed files with 416 additions and 17 deletions
+31 -4
View File
@@ -763,7 +763,12 @@ async def _handle_media_group(
messages: Sequence[TelegramIncomingMessage],
topic_store: TopicStateStore | None,
run_prompt: Callable[
[TelegramIncomingMessage, str, RunContext | None], Awaitable[None]
[TelegramIncomingMessage, str, ResolvedMessage], Awaitable[None]
]
| None = None,
resolve_prompt: Callable[
[TelegramIncomingMessage, str, RunContext | None],
Awaitable[ResolvedMessage | None],
]
| None = None,
) -> None:
@@ -810,12 +815,29 @@ async def _handle_media_group(
if cfg.files.enabled and cfg.files.auto_put:
caption_text = command_msg.text.strip()
if cfg.files.auto_put_mode == "prompt" and caption_text:
if resolve_prompt is None:
try:
resolved = cfg.runtime.resolve_message(
text=caption_text,
reply_text=command_msg.reply_to_text,
ambient_context=ambient_context,
chat_id=command_msg.chat_id,
)
except DirectiveError as exc:
await reply(text=f"error:\n{exc}")
return
else:
resolved = await resolve_prompt(
command_msg, caption_text, ambient_context
)
if resolved is None:
return
saved_group = await _save_file_put_group(
cfg,
command_msg,
"",
ordered,
ambient_context,
resolved.context,
topic_store,
)
if saved_group is None:
@@ -840,8 +862,13 @@ async def _handle_media_group(
if item.rel_path is not None
]
files_text = "\n".join(f"- {path}" for path in paths)
prompt = f"{caption_text}\n\n[uploaded files]\n{files_text}"
await run_prompt(command_msg, prompt, saved_group.context)
prompt_base = resolved.prompt
annotation = f"[uploaded files]\n{files_text}"
if prompt_base and prompt_base.strip():
prompt = f"{prompt_base}\n\n{annotation}"
else:
prompt = annotation
await run_prompt(command_msg, prompt, resolved)
return
if not caption_text:
await _handle_file_put_group(
+160 -13
View File
@@ -17,6 +17,7 @@ from ..model import EngineId, ResumeToken
from ..scheduler import ThreadJob, ThreadScheduler
from ..settings import TelegramTransportSettings
from ..transport import MessageRef
from ..transport_runtime import ResolvedMessage
from ..context import RunContext
from .bridge import CANCEL_CALLBACK_DATA, TelegramBridgeConfig, send_plain
from .commands import (
@@ -484,29 +485,168 @@ async def run_main_loop(
scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
def _build_upload_prompt(base: str, annotation: str) -> str:
if base and base.strip():
return f"{base}\n\n{annotation}"
return annotation
async def resolve_prompt_message(
msg: TelegramIncomingMessage,
text: str,
ambient_context: RunContext | None,
) -> ResolvedMessage | None:
reply = partial(
send_plain,
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
thread_id=msg.thread_id,
)
try:
resolved = cfg.runtime.resolve_message(
text=text,
reply_text=msg.reply_to_text,
ambient_context=ambient_context,
chat_id=msg.chat_id,
)
except DirectiveError as exc:
await reply(text=f"error:\n{exc}")
return None
topic_key = (
_topic_key(msg, cfg, scope_chat_ids=topics_chat_ids)
if topic_store is not None
else None
)
effective_context = ambient_context
if (
topic_store is not None
and topic_key is not None
and resolved.context is not None
and resolved.context_source == "directives"
):
await topic_store.set_context(*topic_key, resolved.context)
await _maybe_rename_topic(
cfg,
topic_store,
chat_id=topic_key[0],
thread_id=topic_key[1],
context=resolved.context,
)
effective_context = resolved.context
if (
topic_store is not None
and topic_key is not None
and effective_context is None
and resolved.context_source not in {"directives", "reply_ctx"}
):
chat_project = (
_topics_chat_project(cfg, msg.chat_id)
if cfg.topics.enabled
else None
)
await reply(
text="this topic isn't bound to a project yet.\n"
f"{_usage_ctx_set(chat_project=chat_project)} or "
f"{_usage_topic(chat_project=chat_project)}",
)
return None
return resolved
async def run_prompt_from_upload(
msg: TelegramIncomingMessage,
prompt_text: str,
context: RunContext | None,
resolved: ResolvedMessage,
) -> None:
chat_id = msg.chat_id
user_msg_id = msg.message_id
reply_id = msg.reply_to_message_id
reply_ref = (
MessageRef(
channel_id=msg.chat_id,
message_id=msg.reply_to_message_id,
thread_id=msg.thread_id,
)
if msg.reply_to_message_id is not None
else None
)
await run_job(
msg.chat_id,
msg.message_id,
resume_token = resolved.resume_token
engine_override = resolved.engine_override
context = resolved.context
chat_session_key = _chat_session_key(msg, store=chat_session_store)
topic_key = (
_topic_key(msg, cfg, scope_chat_ids=topics_chat_ids)
if topic_store is not None
else None
)
if resume_token is None and reply_id is not None:
running_task = running_tasks.get(
MessageRef(channel_id=chat_id, message_id=reply_id)
)
if running_task is not None:
tg.start_soon(
send_with_resume,
cfg,
scheduler.enqueue_resume,
running_task,
chat_id,
user_msg_id,
msg.thread_id,
chat_session_key,
prompt_text,
)
return
if (
resume_token is None
and topic_store is not None
and topic_key is not None
):
engine_for_session = cfg.runtime.resolve_engine(
engine_override=engine_override,
context=context,
)
stored = await topic_store.get_session_resume(
topic_key[0], topic_key[1], engine_for_session
)
if stored is not None:
resume_token = stored
if (
resume_token is None
and chat_session_store is not None
and chat_session_key is not None
):
engine_for_session = cfg.runtime.resolve_engine(
engine_override=engine_override,
context=context,
)
stored = await chat_session_store.get_session_resume(
chat_session_key[0],
chat_session_key[1],
engine_for_session,
)
if stored is not None:
resume_token = stored
if resume_token is None:
await run_job(
chat_id,
user_msg_id,
prompt_text,
None,
context,
msg.thread_id,
chat_session_key,
reply_ref,
scheduler.note_thread_known,
engine_override,
)
return
await scheduler.enqueue_resume(
chat_id,
user_msg_id,
prompt_text,
None,
resume_token,
context,
msg.thread_id,
None,
reply_ref,
scheduler.note_thread_known,
chat_session_key,
)
async def handle_prompt_upload(
@@ -515,19 +655,25 @@ async def run_main_loop(
ambient_context: RunContext | None,
topic_store: TopicStateStore | None,
) -> None:
resolved = await resolve_prompt_message(
msg,
caption_text,
ambient_context,
)
if resolved is None:
return
saved = await _save_file_put(
cfg,
msg,
"",
ambient_context,
resolved.context,
topic_store,
)
if saved is None:
return
prompt = (
f"{caption_text}\n\n[uploaded file: {saved.rel_path.as_posix()}]"
)
await run_prompt_from_upload(msg, prompt, saved.context)
annotation = f"[uploaded file: {saved.rel_path.as_posix()}]"
prompt = _build_upload_prompt(resolved.prompt, annotation)
await run_prompt_from_upload(msg, prompt, resolved)
async def flush_media_group(key: tuple[int, str]) -> None:
while True:
@@ -548,6 +694,7 @@ async def run_main_loop(
messages,
topic_store,
run_prompt_from_upload,
resolve_prompt_message,
)
return
+225
View File
@@ -1536,6 +1536,231 @@ async def test_run_main_loop_auto_resumes_chat_sessions(tmp_path: Path) -> None:
assert runner2.calls[0][1] == ResumeToken(engine=CODEX_ENGINE, value=resume_value)
@pytest.mark.anyio
async def test_run_main_loop_prompt_upload_uses_caption_directives(
tmp_path: Path,
) -> None:
payload = b"hello"
proj_dir = tmp_path / "proj"
other_dir = tmp_path / "other"
proj_dir.mkdir()
other_dir.mkdir()
class _UploadBot(_FakeBot):
async def get_file(self, file_id: str) -> File | None:
_ = file_id
return File(file_path="files/hello.txt")
async def download_file(self, file_path: str) -> bytes | None:
_ = file_path
return payload
transport = _FakeTransport()
bot = _UploadBot()
runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
exec_cfg = ExecBridgeConfig(
transport=transport,
presenter=MarkdownPresenter(),
final_notify=True,
)
projects = ProjectsConfig(
projects={
"proj": ProjectConfig(
alias="proj",
path=proj_dir,
worktrees_dir=Path(".worktrees"),
),
"other": ProjectConfig(
alias="other",
path=other_dir,
worktrees_dir=Path(".worktrees"),
),
},
default_project="proj",
)
runtime = TransportRuntime(router=_make_router(runner), projects=projects)
cfg = TelegramBridgeConfig(
bot=bot,
runtime=runtime,
chat_id=123,
startup_msg="",
exec_cfg=exec_cfg,
files=TelegramFilesSettings(
enabled=True,
auto_put=True,
auto_put_mode="prompt",
),
)
async def poller(_cfg: TelegramBridgeConfig):
yield TelegramIncomingMessage(
transport="telegram",
chat_id=123,
message_id=1,
text="/other do thing",
reply_to_message_id=None,
reply_to_text=None,
sender_id=123,
chat_type="private",
document=TelegramDocument(
file_id="doc-1",
file_name="hello.txt",
mime_type="text/plain",
file_size=len(payload),
raw={"file_id": "doc-1"},
),
)
await run_main_loop(cfg, poller)
saved_path = other_dir / "incoming" / "hello.txt"
assert saved_path.read_bytes() == payload
assert runner.calls
prompt_text, _ = runner.calls[0]
assert prompt_text.startswith("do thing")
assert "/other" not in prompt_text
assert "[uploaded file: incoming/hello.txt]" in prompt_text
@pytest.mark.anyio
async def test_run_main_loop_prompt_upload_auto_resumes_chat_sessions(
tmp_path: Path,
) -> None:
payload = b"hello"
resume_value = "resume-123"
state_path = tmp_path / "takopi.toml"
project_dir = tmp_path / "proj"
project_dir.mkdir()
class _UploadBot(_FakeBot):
async def get_file(self, file_id: str) -> File | None:
_ = file_id
return File(file_path="files/hello.txt")
async def download_file(self, file_path: str) -> bytes | None:
_ = file_path
return payload
projects = ProjectsConfig(
projects={
"proj": ProjectConfig(
alias="proj",
path=project_dir,
worktrees_dir=Path(".worktrees"),
)
},
default_project="proj",
)
bot = _UploadBot()
transport = _FakeTransport()
runner = ScriptRunner(
[Return(answer="ok")],
engine=CODEX_ENGINE,
resume_value=resume_value,
)
exec_cfg = ExecBridgeConfig(
transport=transport,
presenter=MarkdownPresenter(),
final_notify=True,
)
runtime = TransportRuntime(
router=_make_router(runner),
projects=projects,
config_path=state_path,
)
cfg = TelegramBridgeConfig(
bot=bot,
runtime=runtime,
chat_id=123,
startup_msg="",
exec_cfg=exec_cfg,
session_mode="chat",
files=TelegramFilesSettings(
enabled=True,
auto_put=True,
auto_put_mode="prompt",
),
)
async def poller(_cfg: TelegramBridgeConfig):
yield TelegramIncomingMessage(
transport="telegram",
chat_id=123,
message_id=1,
text="hello",
reply_to_message_id=None,
reply_to_text=None,
sender_id=123,
chat_type="private",
document=TelegramDocument(
file_id="doc-1",
file_name="hello.txt",
mime_type="text/plain",
file_size=len(payload),
raw={"file_id": "doc-1"},
),
)
await run_main_loop(cfg, poller)
store = ChatSessionStore(resolve_sessions_path(state_path))
stored = await store.get_session_resume(123, None, CODEX_ENGINE)
assert stored == ResumeToken(engine=CODEX_ENGINE, value=resume_value)
transport2 = _FakeTransport()
runner2 = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
exec_cfg2 = ExecBridgeConfig(
transport=transport2,
presenter=MarkdownPresenter(),
final_notify=True,
)
runtime2 = TransportRuntime(
router=_make_router(runner2),
projects=projects,
config_path=state_path,
)
cfg2 = TelegramBridgeConfig(
bot=bot,
runtime=runtime2,
chat_id=123,
startup_msg="",
exec_cfg=exec_cfg2,
session_mode="chat",
files=TelegramFilesSettings(
enabled=True,
auto_put=True,
auto_put_mode="prompt",
),
)
async def poller2(_cfg: TelegramBridgeConfig):
yield TelegramIncomingMessage(
transport="telegram",
chat_id=123,
message_id=2,
text="followup",
reply_to_message_id=None,
reply_to_text=None,
sender_id=123,
chat_type="private",
document=TelegramDocument(
file_id="doc-2",
file_name="hello2.txt",
mime_type="text/plain",
file_size=len(payload),
raw={"file_id": "doc-2"},
),
)
await run_main_loop(cfg2, poller2)
assert runner2.calls[0][1] == ResumeToken(
engine=CODEX_ENGINE,
value=resume_value,
)
@pytest.mark.anyio
async def test_run_main_loop_hides_resume_line_when_disabled(
tmp_path: Path,