diff --git a/src/takopi/telegram/loop.py b/src/takopi/telegram/loop.py index 8fa69e3..6c6ba12 100644 --- a/src/takopi/telegram/loop.py +++ b/src/takopi/telegram/loop.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import deque from collections.abc import AsyncIterator, Awaitable, Callable, Mapping from dataclasses import dataclass from functools import partial @@ -77,6 +78,9 @@ logger = get_logger(__name__) __all__ = ["poll_updates", "run_main_loop", "send_with_resume"] ForwardKey = tuple[int, int, int] +MessageKey = tuple[int, int] +_SEEN_MESSAGES_LIMIT = 2048 +_SEEN_UPDATES_LIMIT = 4096 _handle_file_put_default = handle_file_put_default @@ -360,6 +364,16 @@ class TelegramMsgContext: ambient_context: RunContext | None +@dataclass(frozen=True, slots=True) +class MessageClassification: + text: str + command_id: str | None + args_text: str + is_cancel: bool + is_forward_candidate: bool + is_media_group_document: bool + + @dataclass(frozen=True, slots=True) class TelegramCommandContext: cfg: TelegramBridgeConfig @@ -374,6 +388,30 @@ class TelegramCommandContext: task_group: TaskGroup +def _classify_message( + msg: TelegramIncomingMessage, *, files_enabled: bool +) -> MessageClassification: + text = msg.text + command_id, args_text = parse_slash_command(text) + is_forward_candidate = ( + _is_forwarded(msg.raw) + and msg.document is None + and msg.voice is None + and msg.media_group_id is None + ) + is_media_group_document = ( + files_enabled and msg.document is not None and msg.media_group_id is not None + ) + return MessageClassification( + text=text, + command_id=command_id, + args_text=args_text, + is_cancel=is_cancel_command(text), + is_forward_candidate=is_forward_candidate, + is_media_group_document=is_media_group_document, + ) + + @dataclass(slots=True) class TelegramLoopState: running_tasks: RunningTasks @@ -392,6 +430,10 @@ class TelegramLoopState: forward_coalesce_s: float media_group_debounce_s: float transport_id: str | None + seen_update_ids: set[int] + seen_update_order: deque[int] + seen_message_keys: set[MessageKey] + seen_messages_order: deque[MessageKey] if TYPE_CHECKING: @@ -931,6 +973,10 @@ async def run_main_loop( forward_coalesce_s=max(0.0, float(cfg.forward_coalesce_s)), media_group_debounce_s=max(0.0, float(cfg.media_group_debounce_s)), transport_id=transport_id, + seen_update_ids=set(), + seen_update_order=deque(), + seen_message_keys=set(), + seen_messages_order=deque(), ) def refresh_topics_scope() -> None: @@ -1199,6 +1245,47 @@ async def run_main_loop( await reply(text=f"error:\n{exc}") return None topic_key = resolve_topic_key(msg) + chat_project = ( + _topics_chat_project(cfg, msg.chat_id) + if cfg.topics.enabled + else None + ) + _, ok = await ensure_topic_context( + resolved=resolved, + ambient_context=ambient_context, + topic_key=topic_key, + chat_project=chat_project, + reply=reply, + ) + if not ok: + return None + return resolved + + async def resolve_engine_defaults( + *, + explicit_engine: EngineId | None, + context: RunContext | None, + chat_id: int, + topic_key: tuple[int, int] | None, + ): + return await resolve_engine_for_message( + runtime=cfg.runtime, + context=context, + explicit_engine=explicit_engine, + chat_id=chat_id, + topic_key=topic_key, + topic_store=state.topic_store, + chat_prefs=state.chat_prefs, + ) + + async def ensure_topic_context( + *, + resolved: ResolvedMessage, + ambient_context: RunContext | None, + topic_key: tuple[int, int] | None, + chat_project: str | None, + reply: Callable[..., Awaitable[None]], + ) -> tuple[RunContext | None, bool]: effective_context = ambient_context if ( state.topic_store is not None @@ -1221,35 +1308,13 @@ async def run_main_loop( and effective_context is None and resolved.context_source not in {"directives", "reply_ctx"} ): - chat_project = ( - _topics_chat_project(cfg, msg.chat_id) - if cfg.topics.enabled - else None - ) await reply( text="this topic isn't bound to a project yet.\n" f"{_usage_ctx_set(chat_project=chat_project)} or " f"{_usage_topic(chat_project=chat_project)}", ) - return None - return resolved - - async def resolve_engine_defaults( - *, - explicit_engine: EngineId | None, - context: RunContext | None, - chat_id: int, - topic_key: tuple[int, int] | None, - ): - return await resolve_engine_for_message( - runtime=cfg.runtime, - context=context, - explicit_engine=explicit_engine, - chat_id=chat_id, - topic_key=topic_key, - topic_store=state.topic_store, - chat_prefs=state.chat_prefs, - ) + return effective_context, False + return effective_context, True resume_resolver = ResumeResolver( cfg=cfg, @@ -1260,29 +1325,19 @@ async def run_main_loop( chat_session_store=state.chat_session_store, ) - async def run_prompt_from_upload( + async def dispatch_prompt_run( + *, msg: TelegramIncomingMessage, prompt_text: str, resolved: ResolvedMessage, + topic_key: tuple[int, int] | None, + chat_session_key: tuple[int, int | None] | None, + reply_ref: MessageRef | None, + reply_id: int | None, ) -> None: chat_id = msg.chat_id user_msg_id = msg.message_id - reply_id = msg.reply_to_message_id - reply_ref = ( - MessageRef( - channel_id=msg.chat_id, - message_id=msg.reply_to_message_id, - thread_id=msg.thread_id, - ) - if msg.reply_to_message_id is not None - else None - ) - resume_token = resolved.resume_token context = resolved.context - chat_session_key = _chat_session_key( - msg, store=state.chat_session_store - ) - topic_key = resolve_topic_key(msg) engine_resolution = await resolve_engine_defaults( explicit_engine=resolved.engine_override, context=context, @@ -1291,7 +1346,7 @@ async def run_main_loop( ) engine_override = engine_resolution.engine resume_decision = await resume_resolver.resolve( - resume_token=resume_token, + resume_token=resolved.resume_token, reply_id=reply_id, chat_id=chat_id, user_msg_id=user_msg_id, @@ -1337,17 +1392,44 @@ async def run_main_loop( progress_ref, ) + async def run_prompt_from_upload( + msg: TelegramIncomingMessage, + prompt_text: str, + resolved: ResolvedMessage, + ) -> None: + reply_id = msg.reply_to_message_id + reply_ref = ( + MessageRef( + channel_id=msg.chat_id, + message_id=msg.reply_to_message_id, + thread_id=msg.thread_id, + ) + if msg.reply_to_message_id is not None + else None + ) + chat_session_key = _chat_session_key( + msg, store=state.chat_session_store + ) + topic_key = resolve_topic_key(msg) + await dispatch_prompt_run( + msg=msg, + prompt_text=prompt_text, + resolved=resolved, + topic_key=topic_key, + chat_session_key=chat_session_key, + reply_ref=reply_ref, + reply_id=reply_id, + ) + async def _dispatch_pending_prompt(pending: _PendingPrompt) -> None: msg = pending.msg - chat_id = msg.chat_id - user_msg_id = msg.message_id reply = make_reply(cfg, msg) try: resolved = cfg.runtime.resolve_message( text=pending.text, reply_text=msg.reply_to_text, ambient_context=pending.ambient_context, - chat_id=chat_id, + chat_id=msg.chat_id, ) except DirectiveError as exc: await reply(text=f"error:\n{exc}") @@ -1375,92 +1457,23 @@ async def run_main_loop( prompt_text, ) - resume_token = resolved.resume_token - context = resolved.context - engine_resolution = await resolve_engine_defaults( - explicit_engine=resolved.engine_override, - context=context, - chat_id=chat_id, + _effective_context, ok = await ensure_topic_context( + resolved=resolved, + ambient_context=pending.ambient_context, topic_key=pending.topic_key, + chat_project=pending.chat_project, + reply=reply, ) - engine_override = engine_resolution.engine - effective_context = pending.ambient_context - if ( - state.topic_store is not None - and pending.topic_key is not None - and resolved.context is not None - and resolved.context_source == "directives" - ): - await state.topic_store.set_context( - *pending.topic_key, resolved.context - ) - await _maybe_rename_topic( - cfg, - state.topic_store, - chat_id=pending.topic_key[0], - thread_id=pending.topic_key[1], - context=resolved.context, - ) - effective_context = resolved.context - if ( - state.topic_store is not None - and pending.topic_key is not None - and effective_context is None - and resolved.context_source not in {"directives", "reply_ctx"} - ): - await reply( - text="this topic isn't bound to a project yet.\n" - f"{_usage_ctx_set(chat_project=pending.chat_project)} or " - f"{_usage_topic(chat_project=pending.chat_project)}", - ) + if not ok: return - resume_decision = await resume_resolver.resolve( - resume_token=resume_token, - reply_id=pending.reply_id, - chat_id=chat_id, - user_msg_id=user_msg_id, - thread_id=msg.thread_id, - chat_session_key=pending.chat_session_key, - topic_key=pending.topic_key, - engine_for_session=engine_resolution.engine, + await dispatch_prompt_run( + msg=msg, prompt_text=prompt_text, - ) - if resume_decision.handled_by_running_task: - return - resume_token = resume_decision.resume_token - - if resume_token is None: - tg.start_soon( - run_job, - chat_id, - user_msg_id, - prompt_text, - None, - context, - msg.thread_id, - pending.chat_session_key, - pending.reply_ref, - scheduler.note_thread_known, - engine_override, - ) - return - progress_ref = await _send_queued_progress( - cfg, - chat_id=chat_id, - user_msg_id=user_msg_id, - thread_id=msg.thread_id, - resume_token=resume_token, - context=context, - ) - await scheduler.enqueue_resume( - chat_id, - user_msg_id, - prompt_text, - resume_token, - context, - msg.thread_id, - pending.chat_session_key, - progress_ref, + resolved=resolved, + topic_key=pending.topic_key, + chat_session_key=pending.chat_session_key, + reply_ref=pending.reply_ref, + reply_id=pending.reply_id, ) forward_coalescer = ForwardCoalescer( @@ -1562,23 +1575,14 @@ async def run_main_loop( async def route_message(msg: TelegramIncomingMessage) -> None: reply = make_reply(cfg, msg) - text = msg.text + classification = _classify_message(msg, files_enabled=cfg.files.enabled) + text = classification.text is_voice_transcribed = False - is_forward_candidate = ( - _is_forwarded(msg.raw) - and msg.document is None - and msg.voice is None - and msg.media_group_id is None - ) - if is_forward_candidate: + if classification.is_forward_candidate: forward_coalescer.attach_forward(msg) return forward_key = _forward_key(msg) - if ( - cfg.files.enabled - and msg.document is not None - and msg.media_group_id is not None - ): + if classification.is_media_group_document: media_group_buffer.add(msg) return ctx = await build_message_context(msg) @@ -1591,13 +1595,14 @@ async def run_main_loop( chat_project = ctx.chat_project ambient_context = ctx.ambient_context - if is_cancel_command(text): + if classification.is_cancel: tg.start_soon( handle_cancel, cfg, msg, state.running_tasks, scheduler ) return - command_id, args_text = parse_slash_command(text) + command_id = classification.command_id + args_text = classification.args_text if command_id == "new": forward_coalescer.cancel(forward_key) if state.topic_store is not None and topic_key is not None: @@ -1793,6 +1798,38 @@ async def run_main_loop( sender_id=sender_id, ) return + if update.update_id is not None: + update_id = update.update_id + if update_id in state.seen_update_ids: + logger.debug( + "update.ignored", + reason="duplicate_update", + update_id=update_id, + chat_id=update.chat_id, + sender_id=update.sender_id, + ) + return + state.seen_update_ids.add(update_id) + state.seen_update_order.append(update_id) + if len(state.seen_update_order) > _SEEN_UPDATES_LIMIT: + oldest_update_id = state.seen_update_order.popleft() + state.seen_update_ids.discard(oldest_update_id) + elif isinstance(update, TelegramIncomingMessage): + key = (update.chat_id, update.message_id) + if key in state.seen_message_keys: + logger.debug( + "update.ignored", + reason="duplicate_message", + chat_id=update.chat_id, + message_id=update.message_id, + sender_id=update.sender_id, + ) + return + state.seen_message_keys.add(key) + state.seen_messages_order.append(key) + if len(state.seen_messages_order) > _SEEN_MESSAGES_LIMIT: + oldest = state.seen_messages_order.popleft() + state.seen_message_keys.discard(oldest) if isinstance(update, TelegramCallbackQuery): if update.data == CANCEL_CALLBACK_DATA: tg.start_soon( diff --git a/src/takopi/telegram/parsing.py b/src/takopi/telegram/parsing.py index e1cc487..6ab678e 100644 --- a/src/takopi/telegram/parsing.py +++ b/src/takopi/telegram/parsing.py @@ -36,12 +36,14 @@ def parse_incoming_update( if update.message is not None: return _parse_incoming_message( update.message, + update_id=update.update_id, chat_id=chat_id, chat_ids=chat_ids, ) if update.callback_query is not None: return _parse_callback_query( update.callback_query, + update_id=update.update_id, chat_id=chat_id, chat_ids=chat_ids, ) @@ -51,6 +53,7 @@ def parse_incoming_update( def _parse_incoming_message( msg: Message, *, + update_id: int | None = None, chat_id: int | None = None, chat_ids: set[int] | None = None, ) -> TelegramIncomingMessage | None: @@ -133,12 +136,14 @@ def _parse_incoming_message( voice=voice_payload, document=document_payload, raw=msgspec.to_builtins(msg), + update_id=update_id, ) def _parse_callback_query( query: CallbackQuery, *, + update_id: int | None = None, chat_id: int | None = None, chat_ids: set[int] | None = None, ) -> TelegramCallbackQuery | None: @@ -162,6 +167,7 @@ def _parse_callback_query( data=data, sender_id=sender_id, raw=msgspec.to_builtins(query), + update_id=update_id, ) diff --git a/src/takopi/telegram/types.py b/src/takopi/telegram/types.py index 193c3de..4706faf 100644 --- a/src/takopi/telegram/types.py +++ b/src/takopi/telegram/types.py @@ -41,6 +41,7 @@ class TelegramIncomingMessage: voice: TelegramVoice | None = None document: TelegramDocument | None = None raw: dict[str, Any] | None = None + update_id: int | None = None @property def is_private(self) -> bool: @@ -58,6 +59,7 @@ class TelegramCallbackQuery: data: str | None sender_id: int | None raw: dict[str, Any] | None = None + update_id: int | None = None TelegramIncomingUpdate = TelegramIncomingMessage | TelegramCallbackQuery diff --git a/tests/test_telegram_bridge.py b/tests/test_telegram_bridge.py index f41da99..869afd3 100644 --- a/tests/test_telegram_bridge.py +++ b/tests/test_telegram_bridge.py @@ -1649,6 +1649,115 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None: tg.cancel_scope.cancel() +@pytest.mark.anyio +async def test_run_main_loop_ignores_duplicate_message_id_for_replies() -> None: + transport = FakeTransport() + bot = FakeBot() + codex_runner = ScriptRunner([Return(answer="codex")], engine=CODEX_ENGINE) + claude_runner = ScriptRunner([Return(answer="claude")], engine="claude") + router = AutoRouter( + entries=[ + RunnerEntry(engine=codex_runner.engine, runner=codex_runner), + RunnerEntry(engine=claude_runner.engine, runner=claude_runner), + ], + default_engine=claude_runner.engine, + ) + runtime = TransportRuntime(router=router, projects=_empty_projects()) + cfg = TelegramBridgeConfig( + bot=bot, + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ), + forward_coalesce_s=FAST_FORWARD_COALESCE_S, + media_group_debounce_s=FAST_MEDIA_GROUP_DEBOUNCE_S, + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=42, + text="turn on logging in my lemon config for me", + reply_to_message_id=900, + reply_to_text="done\n`codex resume c-123`", + sender_id=123, + chat_type="private", + ) + # Telegram can occasionally redeliver the same message id with less reply metadata. + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=42, + text="turn on logging in my lemon config for me", + reply_to_message_id=900, + reply_to_text=None, + sender_id=123, + chat_type="private", + ) + + await run_main_loop(cfg, poller) + + assert len(codex_runner.calls) == 1 + assert codex_runner.calls[0][1] == ResumeToken(engine=CODEX_ENGINE, value="c-123") + assert claude_runner.calls == [] + + +@pytest.mark.anyio +async def test_run_main_loop_ignores_duplicate_update_id() -> None: + transport = FakeTransport() + bot = FakeBot() + runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE) + runtime = TransportRuntime(router=_make_router(runner), projects=_empty_projects()) + cfg = TelegramBridgeConfig( + bot=bot, + runtime=runtime, + chat_id=123, + startup_msg="", + exec_cfg=ExecBridgeConfig( + transport=transport, + presenter=MarkdownPresenter(), + final_notify=True, + ), + forward_coalesce_s=FAST_FORWARD_COALESCE_S, + media_group_debounce_s=FAST_MEDIA_GROUP_DEBOUNCE_S, + ) + + async def poller(_cfg: TelegramBridgeConfig): + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=1, + text="first", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + chat_type="private", + update_id=9001, + ) + # Same Telegram update id redelivered. + yield TelegramIncomingMessage( + transport="telegram", + chat_id=123, + message_id=2, + text="second", + reply_to_message_id=None, + reply_to_text=None, + sender_id=123, + chat_type="private", + update_id=9001, + ) + + await run_main_loop(cfg, poller) + + assert len(runner.calls) == 1 + assert runner.calls[0][0] == "first" + + @pytest.mark.anyio async def test_run_main_loop_persists_topic_sessions_in_project_scope( tmp_path: Path, diff --git a/tests/test_telegram_incoming.py b/tests/test_telegram_incoming.py index 382f512..ba3abc1 100644 --- a/tests/test_telegram_incoming.py +++ b/tests/test_telegram_incoming.py @@ -55,6 +55,7 @@ def test_parse_incoming_update_maps_fields() -> None: assert msg.document is None assert msg.raw assert msg.raw["message_id"] == 10 + assert msg.update_id == 1 def test_parse_incoming_update_ignores_implicit_topic_reply() -> None: @@ -83,6 +84,7 @@ def test_parse_incoming_update_ignores_implicit_topic_reply() -> None: assert msg.reply_to_text is None assert msg.reply_to_is_bot is None assert msg.reply_to_username is None + assert msg.update_id == 1 def test_parse_incoming_update_filters_non_matching_chat() -> None: @@ -295,6 +297,7 @@ def test_parse_incoming_update_callback_query() -> None: assert msg.callback_query_id == "cbq-1" assert msg.data == "takopi:cancel" assert msg.sender_id == 321 + assert msg.update_id == 1 def test_parse_incoming_update_topic_fields() -> None: