diff --git a/docs/reference/config.md b/docs/reference/config.md index fd1b6a3..db9be52 100644 --- a/docs/reference/config.md +++ b/docs/reference/config.md @@ -30,6 +30,7 @@ chat_id = 123 | `bot_token` | string | (required) | Telegram bot token from @BotFather. | | `chat_id` | int | (required) | Default chat id. | | `message_overflow` | `"trim"`\|`"split"` | `"trim"` | How to handle long final responses. | +| `forward_coalesce_s` | float | `1.0` | Quiet window for combining a prompt with immediately-following forwarded messages; set `0` to disable. | | `voice_transcription` | bool | `false` | Enable voice note transcription. | | `voice_max_bytes` | int | `10485760` | Max voice note size (bytes). | | `voice_transcription_model` | string | `"gpt-4o-mini-transcribe"` | OpenAI transcription model name. | @@ -106,4 +107,3 @@ model = "..." ``` The shape is engine-defined. - diff --git a/docs/reference/transports/telegram.md b/docs/reference/transports/telegram.md index 11dcfee..13bcaf1 100644 --- a/docs/reference/transports/telegram.md +++ b/docs/reference/transports/telegram.md @@ -72,6 +72,26 @@ In group chats, changing trigger mode requires the sender to be an admin. State is stored in `telegram_chat_prefs_state.json` (chat default) and `telegram_topics_state.json` (topic overrides) alongside the config file. +### Forwarded message coalescing + +Telegram sends a "comment + forwards" burst as separate messages, with the comment +arriving first. Takopi waits briefly so it can attach the forwarded messages and +run once. + +Behavior: + +- When a prompt candidate arrives, Takopi waits for `forward_coalesce_s` seconds + of quiet for that sender + chat/topic. +- Forwarded messages arriving during the window are appended to the prompt + (separated by blank lines) and do not start their own runs. +- Forwarded messages by themselves do not start runs. + +Configuration (under `[transports.telegram]`): + +```toml +forward_coalesce_s = 1.0 # set 0 to disable the delay +``` + ## Chat sessions (optional) If you chose the **handoff** workflow during onboarding, Takopi uses stateless mode diff --git a/src/takopi/settings.py b/src/takopi/settings.py index cbce6ba..678b71e 100644 --- a/src/takopi/settings.py +++ b/src/takopi/settings.py @@ -100,6 +100,7 @@ class TelegramTransportSettings(BaseModel): voice_transcription_model: NonEmptyStr = "gpt-4o-mini-transcribe" session_mode: Literal["stateless", "chat"] = "stateless" show_resume_line: bool = True + forward_coalesce_s: float = Field(default=1.0, ge=0) topics: TelegramTopicsSettings = Field(default_factory=TelegramTopicsSettings) files: TelegramFilesSettings = Field(default_factory=TelegramFilesSettings) diff --git a/src/takopi/telegram/backend.py b/src/takopi/telegram/backend.py index 9956c6f..574ae0e 100644 --- a/src/takopi/telegram/backend.py +++ b/src/takopi/telegram/backend.py @@ -138,6 +138,7 @@ class TelegramBackend(TransportBackend): voice_transcription=settings.voice_transcription, voice_max_bytes=int(settings.voice_max_bytes), voice_transcription_model=settings.voice_transcription_model, + forward_coalesce_s=settings.forward_coalesce_s, topics=settings.topics, files=settings.files, ) diff --git a/src/takopi/telegram/bridge.py b/src/takopi/telegram/bridge.py index bfe234b..8af9576 100644 --- a/src/takopi/telegram/bridge.py +++ b/src/takopi/telegram/bridge.py @@ -124,6 +124,7 @@ class TelegramBridgeConfig: voice_transcription: bool = False voice_max_bytes: int = 10 * 1024 * 1024 voice_transcription_model: str = "gpt-4o-mini-transcribe" + forward_coalesce_s: float = 1.0 files: TelegramFilesSettings = field(default_factory=TelegramFilesSettings) chat_ids: tuple[int, ...] | None = None topics: TelegramTopicsSettings = field(default_factory=TelegramTopicsSettings) diff --git a/src/takopi/telegram/loop.py b/src/takopi/telegram/loop.py index a712269..51ca0d6 100644 --- a/src/takopi/telegram/loop.py +++ b/src/takopi/telegram/loop.py @@ -71,6 +71,8 @@ __all__ = ["poll_updates", "run_main_loop", "send_with_resume"] _MEDIA_GROUP_DEBOUNCE_S = 1.0 +ForwardKey = tuple[int, int, int] + def _chat_session_key( msg: TelegramIncomingMessage, *, store: ChatSessionStore | None @@ -246,6 +248,59 @@ class _MediaGroupState: token: int = 0 +@dataclass(slots=True) +class _PendingPrompt: + msg: TelegramIncomingMessage + text: str + ambient_context: RunContext | None + chat_project: str | None + topic_key: tuple[int, int] | None + chat_session_key: tuple[int, int | None] | None + reply_ref: MessageRef | None + reply_id: int | None + is_voice_transcribed: bool + forwards: list[tuple[int, str]] + cancel_scope: anyio.CancelScope | None = None + + +_FORWARD_FIELDS = ( + "forward_origin", + "forward_from", + "forward_from_chat", + "forward_from_message_id", + "forward_sender_name", + "forward_signature", + "forward_date", + "is_automatic_forward", +) + + +def _forward_key(msg: TelegramIncomingMessage) -> ForwardKey: + return (msg.chat_id, msg.thread_id or 0, msg.sender_id or 0) + + +def _is_forwarded(raw: dict[str, object] | None) -> bool: + if not isinstance(raw, dict): + return False + return any(raw.get(field) is not None for field in _FORWARD_FIELDS) + + +def _forward_fields_present(raw: dict[str, object] | None) -> list[str]: + if not isinstance(raw, dict): + return [] + return [field for field in _FORWARD_FIELDS if raw.get(field) is not None] + + +def _format_forwarded_prompt(forwarded: list[str], prompt: str) -> str: + if not forwarded: + return prompt + separator = "\n\n" + forward_block = separator.join(forwarded) + if prompt.strip(): + return f"{prompt}{separator}{forward_block}" + return forward_block + + def _diff_keys(old: dict[str, object], new: dict[str, object]) -> list[str]: keys = set(old) | set(new) return sorted(key for key in keys if old.get(key) != new.get(key)) @@ -387,9 +442,11 @@ async def run_main_loop( chat_session_store: ChatSessionStore | None = None chat_prefs: ChatPrefsStore | None = None media_groups: dict[tuple[int, str], _MediaGroupState] = {} + pending_prompts: dict[ForwardKey, _PendingPrompt] = {} resolved_topics_scope: str | None = None topics_chat_ids: frozenset[int] = frozenset() bot_username: str | None = None + forward_coalesce_s = max(0.0, float(cfg.forward_coalesce_s)) def refresh_topics_scope() -> None: nonlocal resolved_topics_scope, topics_chat_ids @@ -768,6 +825,305 @@ async def run_main_loop( progress_ref, ) + async def _dispatch_pending_prompt(pending: _PendingPrompt) -> None: + msg = pending.msg + chat_id = msg.chat_id + user_msg_id = msg.message_id + reply = make_reply(cfg, msg) + try: + resolved = cfg.runtime.resolve_message( + text=pending.text, + reply_text=msg.reply_to_text, + ambient_context=pending.ambient_context, + chat_id=chat_id, + ) + except DirectiveError as exc: + await reply(text=f"error:\n{exc}") + return + if pending.is_voice_transcribed: + resolved = ResolvedMessage( + prompt=f"(voice transcribed) {resolved.prompt}", + resume_token=resolved.resume_token, + engine_override=resolved.engine_override, + context=resolved.context, + context_source=resolved.context_source, + ) + + prompt_text = resolved.prompt + if pending.forwards: + forwarded = [ + text + for _, text in sorted( + pending.forwards, + key=lambda item: item[0], + ) + ] + prompt_text = _format_forwarded_prompt( + forwarded, + prompt_text, + ) + + resume_token = resolved.resume_token + context = resolved.context + engine_resolution = await resolve_engine_defaults( + explicit_engine=resolved.engine_override, + context=context, + chat_id=chat_id, + topic_key=pending.topic_key, + ) + engine_override = engine_resolution.engine + effective_context = pending.ambient_context + if ( + topic_store is not None + and pending.topic_key is not None + and resolved.context is not None + and resolved.context_source == "directives" + ): + await topic_store.set_context(*pending.topic_key, resolved.context) + await _maybe_rename_topic( + cfg, + topic_store, + chat_id=pending.topic_key[0], + thread_id=pending.topic_key[1], + context=resolved.context, + ) + effective_context = resolved.context + if ( + topic_store is not None + and pending.topic_key is not None + and effective_context is None + and resolved.context_source not in {"directives", "reply_ctx"} + ): + await reply( + text="this topic isn't bound to a project yet.\n" + f"{_usage_ctx_set(chat_project=pending.chat_project)} or " + f"{_usage_topic(chat_project=pending.chat_project)}", + ) + return + if resume_token is None and pending.reply_id is not None: + running_task = running_tasks.get( + MessageRef(channel_id=chat_id, message_id=pending.reply_id) + ) + if running_task is not None: + tg.start_soon( + send_with_resume, + cfg, + scheduler.enqueue_resume, + running_task, + chat_id, + user_msg_id, + msg.thread_id, + pending.chat_session_key, + prompt_text, + ) + return + if ( + resume_token is None + and topic_store is not None + and pending.topic_key is not None + ): + engine_for_session = engine_resolution.engine + stored = await topic_store.get_session_resume( + pending.topic_key[0], + pending.topic_key[1], + engine_for_session, + ) + if stored is not None: + resume_token = stored + if ( + resume_token is None + and chat_session_store is not None + and pending.chat_session_key is not None + ): + engine_for_session = engine_resolution.engine + stored = await chat_session_store.get_session_resume( + pending.chat_session_key[0], + pending.chat_session_key[1], + engine_for_session, + ) + if stored is not None: + resume_token = stored + + if resume_token is None: + tg.start_soon( + run_job, + chat_id, + user_msg_id, + prompt_text, + None, + context, + msg.thread_id, + pending.chat_session_key, + pending.reply_ref, + scheduler.note_thread_known, + engine_override, + ) + return + progress_ref = await _send_queued_progress( + cfg, + chat_id=chat_id, + user_msg_id=user_msg_id, + thread_id=msg.thread_id, + resume_token=resume_token, + context=context, + ) + await scheduler.enqueue_resume( + chat_id, + user_msg_id, + prompt_text, + resume_token, + context, + msg.thread_id, + pending.chat_session_key, + progress_ref, + ) + + async def _debounce_prompt_run( + key: ForwardKey, pending: _PendingPrompt + ) -> None: + try: + with anyio.CancelScope() as scope: + pending.cancel_scope = scope + await anyio.sleep(forward_coalesce_s) + except anyio.get_cancelled_exc_class(): + return + if pending_prompts.get(key) is not pending: + return + pending_prompts.pop(key, None) + logger.debug( + "forward.prompt.run", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + message_id=pending.msg.message_id, + forward_count=len(pending.forwards), + debounce_s=forward_coalesce_s, + ) + await _dispatch_pending_prompt(pending) + + def _reschedule_prompt(key: ForwardKey, pending: _PendingPrompt) -> None: + if pending.cancel_scope is not None: + pending.cancel_scope.cancel() + pending.cancel_scope = None + tg.start_soon(_debounce_prompt_run, key, pending) + + def _cancel_pending_prompt(key: ForwardKey) -> None: + pending = pending_prompts.pop(key, None) + if pending is None: + return + if pending.cancel_scope is not None: + pending.cancel_scope.cancel() + logger.debug( + "forward.prompt.cancelled", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + message_id=pending.msg.message_id, + forward_count=len(pending.forwards), + ) + + def _schedule_prompt( + pending: _PendingPrompt, + ) -> None: + if pending.msg.sender_id is None: + logger.debug( + "forward.prompt.bypass", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + message_id=pending.msg.message_id, + reason="missing_sender", + ) + tg.start_soon(_dispatch_pending_prompt, pending) + return + if forward_coalesce_s <= 0: + logger.debug( + "forward.prompt.bypass", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + message_id=pending.msg.message_id, + reason="disabled", + ) + tg.start_soon(_dispatch_pending_prompt, pending) + return + key = _forward_key(pending.msg) + existing = pending_prompts.get(key) + if existing is not None: + if existing.cancel_scope is not None: + existing.cancel_scope.cancel() + if existing.forwards: + pending.forwards = list(existing.forwards) + logger.debug( + "forward.prompt.replace", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + old_message_id=existing.msg.message_id, + new_message_id=pending.msg.message_id, + forward_count=len(pending.forwards), + ) + pending_prompts[key] = pending + logger.debug( + "forward.prompt.schedule", + chat_id=pending.msg.chat_id, + thread_id=pending.msg.thread_id, + sender_id=pending.msg.sender_id, + message_id=pending.msg.message_id, + debounce_s=forward_coalesce_s, + ) + _reschedule_prompt(key, pending) + + def _attach_forward(msg: TelegramIncomingMessage) -> None: + if msg.sender_id is None: + logger.debug( + "forward.message.ignored", + chat_id=msg.chat_id, + thread_id=msg.thread_id, + sender_id=msg.sender_id, + message_id=msg.message_id, + reason="missing_sender", + ) + return + key = _forward_key(msg) + pending = pending_prompts.get(key) + if pending is None: + logger.debug( + "forward.message.ignored", + chat_id=msg.chat_id, + thread_id=msg.thread_id, + sender_id=msg.sender_id, + message_id=msg.message_id, + reason="no_pending_prompt", + ) + return + text = msg.text + if not text.strip(): + logger.debug( + "forward.message.ignored", + chat_id=msg.chat_id, + thread_id=msg.thread_id, + sender_id=msg.sender_id, + message_id=msg.message_id, + reason="empty_text", + ) + return + pending.forwards.append((msg.message_id, text)) + logger.debug( + "forward.message.attached", + chat_id=msg.chat_id, + thread_id=msg.thread_id, + sender_id=msg.sender_id, + message_id=msg.message_id, + prompt_message_id=pending.msg.message_id, + forward_count=len(pending.forwards), + forward_fields=_forward_fields_present(msg.raw), + forward_date=msg.raw.get("forward_date") if msg.raw else None, + message_date=msg.raw.get("date") if msg.raw else None, + text_len=len(text), + ) + _reschedule_prompt(key, pending) + async def handle_prompt_upload( msg: TelegramIncomingMessage, caption_text: str, @@ -848,7 +1204,6 @@ async def run_main_loop( msg.callback_query_id, ) continue - user_msg_id = msg.message_id chat_id = msg.chat_id reply_id = msg.reply_to_message_id reply_ref = ( @@ -859,6 +1214,10 @@ async def run_main_loop( reply = make_reply(cfg, msg) text = msg.text is_voice_transcribed = False + if _is_forwarded(msg.raw): + _attach_forward(msg) + continue + forward_key = _forward_key(msg) if ( cfg.files.enabled and msg.document is not None @@ -898,6 +1257,7 @@ async def run_main_loop( command_id, args_text = _parse_slash_command(text) if command_id == "new": + _cancel_pending_prompt(forward_key) if topic_store is not None and topic_key is not None: tg.start_soon( partial( @@ -1036,135 +1396,31 @@ async def run_main_loop( ) continue - reply_text = msg.reply_to_text - try: - resolved = cfg.runtime.resolve_message( - text=text, - reply_text=reply_text, - ambient_context=ambient_context, - chat_id=chat_id, - ) - except DirectiveError as exc: - await reply(text=f"error:\n{exc}") - continue - if is_voice_transcribed: - resolved = ResolvedMessage( - prompt=f"(voice transcribed) {resolved.prompt}", - resume_token=resolved.resume_token, - engine_override=resolved.engine_override, - context=resolved.context, - context_source=resolved.context_source, - ) - - text = resolved.prompt - resume_token = resolved.resume_token - context = resolved.context - engine_resolution = await resolve_engine_defaults( - explicit_engine=resolved.engine_override, - context=context, - chat_id=chat_id, + pending = _PendingPrompt( + msg=msg, + text=text, + ambient_context=ambient_context, + chat_project=chat_project, topic_key=topic_key, + chat_session_key=chat_session_key, + reply_ref=reply_ref, + reply_id=reply_id, + is_voice_transcribed=is_voice_transcribed, + forwards=[], ) - engine_override = engine_resolution.engine - if ( - topic_store is not None - and topic_key is not None - and resolved.context is not None - and resolved.context_source == "directives" + if reply_id is not None and running_tasks.get( + MessageRef(channel_id=chat_id, message_id=reply_id) ): - await topic_store.set_context(*topic_key, resolved.context) - await _maybe_rename_topic( - cfg, - topic_store, - chat_id=topic_key[0], - thread_id=topic_key[1], - context=resolved.context, - ) - ambient_context = resolved.context - if ( - topic_store is not None - and topic_key is not None - and ambient_context is None - and resolved.context_source not in {"directives", "reply_ctx"} - ): - await reply( - text="this topic isn't bound to a project yet.\n" - f"{_usage_ctx_set(chat_project=chat_project)} or " - f"{_usage_topic(chat_project=chat_project)}", - ) - continue - if resume_token is None and reply_id is not None: - running_task = running_tasks.get( - MessageRef(channel_id=chat_id, message_id=reply_id) - ) - if running_task is not None: - tg.start_soon( - send_with_resume, - cfg, - scheduler.enqueue_resume, - running_task, - chat_id, - user_msg_id, - msg.thread_id, - chat_session_key, - text, - ) - continue - if ( - resume_token is None - and topic_store is not None - and topic_key is not None - ): - engine_for_session = engine_resolution.engine - stored = await topic_store.get_session_resume( - topic_key[0], topic_key[1], engine_for_session - ) - if stored is not None: - resume_token = stored - if ( - resume_token is None - and chat_session_store is not None - and chat_session_key is not None - ): - engine_for_session = engine_resolution.engine - stored = await chat_session_store.get_session_resume( - chat_session_key[0], chat_session_key[1], engine_for_session - ) - if stored is not None: - resume_token = stored - - if resume_token is None: - tg.start_soon( - run_job, - chat_id, - user_msg_id, - text, - None, - context, - msg.thread_id, - chat_session_key, - reply_ref, - scheduler.note_thread_known, - engine_override, - ) - else: - progress_ref = await _send_queued_progress( - cfg, + logger.debug( + "forward.prompt.bypass", chat_id=chat_id, - user_msg_id=user_msg_id, thread_id=msg.thread_id, - resume_token=resume_token, - context=context, - ) - await scheduler.enqueue_resume( - chat_id, - user_msg_id, - text, - resume_token, - context, - msg.thread_id, - chat_session_key, - progress_ref, + sender_id=msg.sender_id, + message_id=msg.message_id, + reason="reply_resume", ) + tg.start_soon(_dispatch_pending_prompt, pending) + continue + _schedule_prompt(pending) finally: await cfg.exec_cfg.transport.close() diff --git a/tests/test_telegram_bridge.py b/tests/test_telegram_bridge.py index 087b5c7..404931e 100644 --- a/tests/test_telegram_bridge.py +++ b/tests/test_telegram_bridge.py @@ -1149,6 +1149,19 @@ def test_resolve_message_accepts_backticked_ctx_line() -> None: assert resolved.context == RunContext(project="takopi", branch="feat/api") +def test_is_forwarded_detects_forward_fields() -> None: + assert telegram_loop._is_forwarded({"forward_origin": {"type": "user"}}) + assert telegram_loop._is_forwarded({"forward_from": {"id": 1}}) + assert telegram_loop._is_forwarded({"forward_from_chat": {"id": 1}}) + assert telegram_loop._is_forwarded({"forward_from_message_id": 2}) + assert telegram_loop._is_forwarded({"forward_sender_name": "anon"}) + assert telegram_loop._is_forwarded({"forward_signature": "sig"}) + assert telegram_loop._is_forwarded({"forward_date": 123}) + assert telegram_loop._is_forwarded({"is_automatic_forward": True}) + assert not telegram_loop._is_forwarded({"text": "hello"}) + assert not telegram_loop._is_forwarded(None) + + def test_topic_title_matches_command_syntax() -> None: transport = _FakeTransport() cfg = _make_cfg(transport) @@ -1979,6 +1992,129 @@ async def test_run_main_loop_voice_transcript_preserves_directive( assert codex_runner.calls[0][0].startswith("(voice transcribed) do thing") +@pytest.mark.anyio +async def test_run_main_loop_debounces_forwarded_messages_preserves_directives() -> ( + None +): + codex_runner = ScriptRunner([Return(answer="codex")], engine=CODEX_ENGINE) + claude_runner = ScriptRunner([Return(answer="claude")], engine="claude") + router = AutoRouter( + entries=[ + RunnerEntry(engine=claude_runner.engine, runner=claude_runner), + RunnerEntry(engine=codex_runner.engine, runner=codex_runner), + ], + default_engine=claude_runner.engine, + ) + runtime = TransportRuntime(router=router, projects=_empty_projects()) + transport = _FakeTransport() + exec_cfg = ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ) + cfg = TelegramBridgeConfig( + bot=_FakeBot(), + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=exec_cfg, + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=1, + text="/codex summarize these", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + ) + await anyio.sleep(_cfg.forward_coalesce_s / 2) + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=2, + text="a", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + raw={"forward_origin": {"type": "user"}}, + ) + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=3, + text="b", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + raw={"forward_origin": {"type": "user"}}, + ) + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=4, + text="c", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + raw={"forward_origin": {"type": "user"}}, + ) + + await run_main_loop(cfg, poller) + + assert not claude_runner.calls + assert len(codex_runner.calls) == 1 + prompt_text, _ = codex_runner.calls[0] + assert prompt_text == "summarize these\n\na\n\nb\n\nc" + + +@pytest.mark.anyio +async def test_run_main_loop_ignores_forwarded_without_prompt() -> None: + runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) + runtime = TransportRuntime(router=_make_router(runner), projects=_empty_projects()) + transport = _FakeTransport() + exec_cfg = ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ) + cfg = TelegramBridgeConfig( + bot=_FakeBot(), + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=exec_cfg, + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=1, + text="a", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + raw={"forward_origin": {"type": "user"}}, + ) + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=2, + text="b", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + raw={"forward_origin": {"type": "user"}}, + ) + + await run_main_loop(cfg, poller) + + assert runner.calls == [] + + @pytest.mark.anyio async def test_run_main_loop_prompt_upload_auto_resumes_chat_sessions( tmp_path: Path,