fix(telegram): dedupe duplicate incoming messages (#198)

This commit is contained in:
banteg
2026-02-08 19:10:34 +04:00
committed by GitHub
parent b293818195
commit fb51708473
5 changed files with 298 additions and 141 deletions
+178 -141
View File
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from collections import deque
from collections.abc import AsyncIterator, Awaitable, Callable, Mapping from collections.abc import AsyncIterator, Awaitable, Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
@@ -77,6 +78,9 @@ logger = get_logger(__name__)
__all__ = ["poll_updates", "run_main_loop", "send_with_resume"] __all__ = ["poll_updates", "run_main_loop", "send_with_resume"]
ForwardKey = tuple[int, int, int] ForwardKey = tuple[int, int, int]
MessageKey = tuple[int, int]
_SEEN_MESSAGES_LIMIT = 2048
_SEEN_UPDATES_LIMIT = 4096
_handle_file_put_default = handle_file_put_default _handle_file_put_default = handle_file_put_default
@@ -360,6 +364,16 @@ class TelegramMsgContext:
ambient_context: RunContext | None ambient_context: RunContext | None
@dataclass(frozen=True, slots=True)
class MessageClassification:
text: str
command_id: str | None
args_text: str
is_cancel: bool
is_forward_candidate: bool
is_media_group_document: bool
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
class TelegramCommandContext: class TelegramCommandContext:
cfg: TelegramBridgeConfig cfg: TelegramBridgeConfig
@@ -374,6 +388,30 @@ class TelegramCommandContext:
task_group: TaskGroup task_group: TaskGroup
def _classify_message(
msg: TelegramIncomingMessage, *, files_enabled: bool
) -> MessageClassification:
text = msg.text
command_id, args_text = parse_slash_command(text)
is_forward_candidate = (
_is_forwarded(msg.raw)
and msg.document is None
and msg.voice is None
and msg.media_group_id is None
)
is_media_group_document = (
files_enabled and msg.document is not None and msg.media_group_id is not None
)
return MessageClassification(
text=text,
command_id=command_id,
args_text=args_text,
is_cancel=is_cancel_command(text),
is_forward_candidate=is_forward_candidate,
is_media_group_document=is_media_group_document,
)
@dataclass(slots=True) @dataclass(slots=True)
class TelegramLoopState: class TelegramLoopState:
running_tasks: RunningTasks running_tasks: RunningTasks
@@ -392,6 +430,10 @@ class TelegramLoopState:
forward_coalesce_s: float forward_coalesce_s: float
media_group_debounce_s: float media_group_debounce_s: float
transport_id: str | None transport_id: str | None
seen_update_ids: set[int]
seen_update_order: deque[int]
seen_message_keys: set[MessageKey]
seen_messages_order: deque[MessageKey]
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -931,6 +973,10 @@ async def run_main_loop(
forward_coalesce_s=max(0.0, float(cfg.forward_coalesce_s)), forward_coalesce_s=max(0.0, float(cfg.forward_coalesce_s)),
media_group_debounce_s=max(0.0, float(cfg.media_group_debounce_s)), media_group_debounce_s=max(0.0, float(cfg.media_group_debounce_s)),
transport_id=transport_id, transport_id=transport_id,
seen_update_ids=set(),
seen_update_order=deque(),
seen_message_keys=set(),
seen_messages_order=deque(),
) )
def refresh_topics_scope() -> None: def refresh_topics_scope() -> None:
@@ -1199,6 +1245,47 @@ async def run_main_loop(
await reply(text=f"error:\n{exc}") await reply(text=f"error:\n{exc}")
return None return None
topic_key = resolve_topic_key(msg) topic_key = resolve_topic_key(msg)
chat_project = (
_topics_chat_project(cfg, msg.chat_id)
if cfg.topics.enabled
else None
)
_, ok = await ensure_topic_context(
resolved=resolved,
ambient_context=ambient_context,
topic_key=topic_key,
chat_project=chat_project,
reply=reply,
)
if not ok:
return None
return resolved
async def resolve_engine_defaults(
*,
explicit_engine: EngineId | None,
context: RunContext | None,
chat_id: int,
topic_key: tuple[int, int] | None,
):
return await resolve_engine_for_message(
runtime=cfg.runtime,
context=context,
explicit_engine=explicit_engine,
chat_id=chat_id,
topic_key=topic_key,
topic_store=state.topic_store,
chat_prefs=state.chat_prefs,
)
async def ensure_topic_context(
*,
resolved: ResolvedMessage,
ambient_context: RunContext | None,
topic_key: tuple[int, int] | None,
chat_project: str | None,
reply: Callable[..., Awaitable[None]],
) -> tuple[RunContext | None, bool]:
effective_context = ambient_context effective_context = ambient_context
if ( if (
state.topic_store is not None state.topic_store is not None
@@ -1221,35 +1308,13 @@ async def run_main_loop(
and effective_context is None and effective_context is None
and resolved.context_source not in {"directives", "reply_ctx"} and resolved.context_source not in {"directives", "reply_ctx"}
): ):
chat_project = (
_topics_chat_project(cfg, msg.chat_id)
if cfg.topics.enabled
else None
)
await reply( await reply(
text="this topic isn't bound to a project yet.\n" text="this topic isn't bound to a project yet.\n"
f"{_usage_ctx_set(chat_project=chat_project)} or " f"{_usage_ctx_set(chat_project=chat_project)} or "
f"{_usage_topic(chat_project=chat_project)}", f"{_usage_topic(chat_project=chat_project)}",
) )
return None return effective_context, False
return resolved return effective_context, True
async def resolve_engine_defaults(
*,
explicit_engine: EngineId | None,
context: RunContext | None,
chat_id: int,
topic_key: tuple[int, int] | None,
):
return await resolve_engine_for_message(
runtime=cfg.runtime,
context=context,
explicit_engine=explicit_engine,
chat_id=chat_id,
topic_key=topic_key,
topic_store=state.topic_store,
chat_prefs=state.chat_prefs,
)
resume_resolver = ResumeResolver( resume_resolver = ResumeResolver(
cfg=cfg, cfg=cfg,
@@ -1260,29 +1325,19 @@ async def run_main_loop(
chat_session_store=state.chat_session_store, chat_session_store=state.chat_session_store,
) )
async def run_prompt_from_upload( async def dispatch_prompt_run(
*,
msg: TelegramIncomingMessage, msg: TelegramIncomingMessage,
prompt_text: str, prompt_text: str,
resolved: ResolvedMessage, resolved: ResolvedMessage,
topic_key: tuple[int, int] | None,
chat_session_key: tuple[int, int | None] | None,
reply_ref: MessageRef | None,
reply_id: int | None,
) -> None: ) -> None:
chat_id = msg.chat_id chat_id = msg.chat_id
user_msg_id = msg.message_id user_msg_id = msg.message_id
reply_id = msg.reply_to_message_id
reply_ref = (
MessageRef(
channel_id=msg.chat_id,
message_id=msg.reply_to_message_id,
thread_id=msg.thread_id,
)
if msg.reply_to_message_id is not None
else None
)
resume_token = resolved.resume_token
context = resolved.context context = resolved.context
chat_session_key = _chat_session_key(
msg, store=state.chat_session_store
)
topic_key = resolve_topic_key(msg)
engine_resolution = await resolve_engine_defaults( engine_resolution = await resolve_engine_defaults(
explicit_engine=resolved.engine_override, explicit_engine=resolved.engine_override,
context=context, context=context,
@@ -1291,7 +1346,7 @@ async def run_main_loop(
) )
engine_override = engine_resolution.engine engine_override = engine_resolution.engine
resume_decision = await resume_resolver.resolve( resume_decision = await resume_resolver.resolve(
resume_token=resume_token, resume_token=resolved.resume_token,
reply_id=reply_id, reply_id=reply_id,
chat_id=chat_id, chat_id=chat_id,
user_msg_id=user_msg_id, user_msg_id=user_msg_id,
@@ -1337,17 +1392,44 @@ async def run_main_loop(
progress_ref, progress_ref,
) )
async def run_prompt_from_upload(
msg: TelegramIncomingMessage,
prompt_text: str,
resolved: ResolvedMessage,
) -> None:
reply_id = msg.reply_to_message_id
reply_ref = (
MessageRef(
channel_id=msg.chat_id,
message_id=msg.reply_to_message_id,
thread_id=msg.thread_id,
)
if msg.reply_to_message_id is not None
else None
)
chat_session_key = _chat_session_key(
msg, store=state.chat_session_store
)
topic_key = resolve_topic_key(msg)
await dispatch_prompt_run(
msg=msg,
prompt_text=prompt_text,
resolved=resolved,
topic_key=topic_key,
chat_session_key=chat_session_key,
reply_ref=reply_ref,
reply_id=reply_id,
)
async def _dispatch_pending_prompt(pending: _PendingPrompt) -> None: async def _dispatch_pending_prompt(pending: _PendingPrompt) -> None:
msg = pending.msg msg = pending.msg
chat_id = msg.chat_id
user_msg_id = msg.message_id
reply = make_reply(cfg, msg) reply = make_reply(cfg, msg)
try: try:
resolved = cfg.runtime.resolve_message( resolved = cfg.runtime.resolve_message(
text=pending.text, text=pending.text,
reply_text=msg.reply_to_text, reply_text=msg.reply_to_text,
ambient_context=pending.ambient_context, ambient_context=pending.ambient_context,
chat_id=chat_id, chat_id=msg.chat_id,
) )
except DirectiveError as exc: except DirectiveError as exc:
await reply(text=f"error:\n{exc}") await reply(text=f"error:\n{exc}")
@@ -1375,92 +1457,23 @@ async def run_main_loop(
prompt_text, prompt_text,
) )
resume_token = resolved.resume_token _effective_context, ok = await ensure_topic_context(
context = resolved.context resolved=resolved,
engine_resolution = await resolve_engine_defaults( ambient_context=pending.ambient_context,
explicit_engine=resolved.engine_override,
context=context,
chat_id=chat_id,
topic_key=pending.topic_key, topic_key=pending.topic_key,
chat_project=pending.chat_project,
reply=reply,
) )
engine_override = engine_resolution.engine if not ok:
effective_context = pending.ambient_context
if (
state.topic_store is not None
and pending.topic_key is not None
and resolved.context is not None
and resolved.context_source == "directives"
):
await state.topic_store.set_context(
*pending.topic_key, resolved.context
)
await _maybe_rename_topic(
cfg,
state.topic_store,
chat_id=pending.topic_key[0],
thread_id=pending.topic_key[1],
context=resolved.context,
)
effective_context = resolved.context
if (
state.topic_store is not None
and pending.topic_key is not None
and effective_context is None
and resolved.context_source not in {"directives", "reply_ctx"}
):
await reply(
text="this topic isn't bound to a project yet.\n"
f"{_usage_ctx_set(chat_project=pending.chat_project)} or "
f"{_usage_topic(chat_project=pending.chat_project)}",
)
return return
resume_decision = await resume_resolver.resolve( await dispatch_prompt_run(
resume_token=resume_token, msg=msg,
reply_id=pending.reply_id,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=msg.thread_id,
chat_session_key=pending.chat_session_key,
topic_key=pending.topic_key,
engine_for_session=engine_resolution.engine,
prompt_text=prompt_text, prompt_text=prompt_text,
) resolved=resolved,
if resume_decision.handled_by_running_task: topic_key=pending.topic_key,
return chat_session_key=pending.chat_session_key,
resume_token = resume_decision.resume_token reply_ref=pending.reply_ref,
reply_id=pending.reply_id,
if resume_token is None:
tg.start_soon(
run_job,
chat_id,
user_msg_id,
prompt_text,
None,
context,
msg.thread_id,
pending.chat_session_key,
pending.reply_ref,
scheduler.note_thread_known,
engine_override,
)
return
progress_ref = await _send_queued_progress(
cfg,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=msg.thread_id,
resume_token=resume_token,
context=context,
)
await scheduler.enqueue_resume(
chat_id,
user_msg_id,
prompt_text,
resume_token,
context,
msg.thread_id,
pending.chat_session_key,
progress_ref,
) )
forward_coalescer = ForwardCoalescer( forward_coalescer = ForwardCoalescer(
@@ -1562,23 +1575,14 @@ async def run_main_loop(
async def route_message(msg: TelegramIncomingMessage) -> None: async def route_message(msg: TelegramIncomingMessage) -> None:
reply = make_reply(cfg, msg) reply = make_reply(cfg, msg)
text = msg.text classification = _classify_message(msg, files_enabled=cfg.files.enabled)
text = classification.text
is_voice_transcribed = False is_voice_transcribed = False
is_forward_candidate = ( if classification.is_forward_candidate:
_is_forwarded(msg.raw)
and msg.document is None
and msg.voice is None
and msg.media_group_id is None
)
if is_forward_candidate:
forward_coalescer.attach_forward(msg) forward_coalescer.attach_forward(msg)
return return
forward_key = _forward_key(msg) forward_key = _forward_key(msg)
if ( if classification.is_media_group_document:
cfg.files.enabled
and msg.document is not None
and msg.media_group_id is not None
):
media_group_buffer.add(msg) media_group_buffer.add(msg)
return return
ctx = await build_message_context(msg) ctx = await build_message_context(msg)
@@ -1591,13 +1595,14 @@ async def run_main_loop(
chat_project = ctx.chat_project chat_project = ctx.chat_project
ambient_context = ctx.ambient_context ambient_context = ctx.ambient_context
if is_cancel_command(text): if classification.is_cancel:
tg.start_soon( tg.start_soon(
handle_cancel, cfg, msg, state.running_tasks, scheduler handle_cancel, cfg, msg, state.running_tasks, scheduler
) )
return return
command_id, args_text = parse_slash_command(text) command_id = classification.command_id
args_text = classification.args_text
if command_id == "new": if command_id == "new":
forward_coalescer.cancel(forward_key) forward_coalescer.cancel(forward_key)
if state.topic_store is not None and topic_key is not None: if state.topic_store is not None and topic_key is not None:
@@ -1793,6 +1798,38 @@ async def run_main_loop(
sender_id=sender_id, sender_id=sender_id,
) )
return return
if update.update_id is not None:
update_id = update.update_id
if update_id in state.seen_update_ids:
logger.debug(
"update.ignored",
reason="duplicate_update",
update_id=update_id,
chat_id=update.chat_id,
sender_id=update.sender_id,
)
return
state.seen_update_ids.add(update_id)
state.seen_update_order.append(update_id)
if len(state.seen_update_order) > _SEEN_UPDATES_LIMIT:
oldest_update_id = state.seen_update_order.popleft()
state.seen_update_ids.discard(oldest_update_id)
elif isinstance(update, TelegramIncomingMessage):
key = (update.chat_id, update.message_id)
if key in state.seen_message_keys:
logger.debug(
"update.ignored",
reason="duplicate_message",
chat_id=update.chat_id,
message_id=update.message_id,
sender_id=update.sender_id,
)
return
state.seen_message_keys.add(key)
state.seen_messages_order.append(key)
if len(state.seen_messages_order) > _SEEN_MESSAGES_LIMIT:
oldest = state.seen_messages_order.popleft()
state.seen_message_keys.discard(oldest)
if isinstance(update, TelegramCallbackQuery): if isinstance(update, TelegramCallbackQuery):
if update.data == CANCEL_CALLBACK_DATA: if update.data == CANCEL_CALLBACK_DATA:
tg.start_soon( tg.start_soon(
+6
View File
@@ -36,12 +36,14 @@ def parse_incoming_update(
if update.message is not None: if update.message is not None:
return _parse_incoming_message( return _parse_incoming_message(
update.message, update.message,
update_id=update.update_id,
chat_id=chat_id, chat_id=chat_id,
chat_ids=chat_ids, chat_ids=chat_ids,
) )
if update.callback_query is not None: if update.callback_query is not None:
return _parse_callback_query( return _parse_callback_query(
update.callback_query, update.callback_query,
update_id=update.update_id,
chat_id=chat_id, chat_id=chat_id,
chat_ids=chat_ids, chat_ids=chat_ids,
) )
@@ -51,6 +53,7 @@ def parse_incoming_update(
def _parse_incoming_message( def _parse_incoming_message(
msg: Message, msg: Message,
*, *,
update_id: int | None = None,
chat_id: int | None = None, chat_id: int | None = None,
chat_ids: set[int] | None = None, chat_ids: set[int] | None = None,
) -> TelegramIncomingMessage | None: ) -> TelegramIncomingMessage | None:
@@ -133,12 +136,14 @@ def _parse_incoming_message(
voice=voice_payload, voice=voice_payload,
document=document_payload, document=document_payload,
raw=msgspec.to_builtins(msg), raw=msgspec.to_builtins(msg),
update_id=update_id,
) )
def _parse_callback_query( def _parse_callback_query(
query: CallbackQuery, query: CallbackQuery,
*, *,
update_id: int | None = None,
chat_id: int | None = None, chat_id: int | None = None,
chat_ids: set[int] | None = None, chat_ids: set[int] | None = None,
) -> TelegramCallbackQuery | None: ) -> TelegramCallbackQuery | None:
@@ -162,6 +167,7 @@ def _parse_callback_query(
data=data, data=data,
sender_id=sender_id, sender_id=sender_id,
raw=msgspec.to_builtins(query), raw=msgspec.to_builtins(query),
update_id=update_id,
) )
+2
View File
@@ -41,6 +41,7 @@ class TelegramIncomingMessage:
voice: TelegramVoice | None = None voice: TelegramVoice | None = None
document: TelegramDocument | None = None document: TelegramDocument | None = None
raw: dict[str, Any] | None = None raw: dict[str, Any] | None = None
update_id: int | None = None
@property @property
def is_private(self) -> bool: def is_private(self) -> bool:
@@ -58,6 +59,7 @@ class TelegramCallbackQuery:
data: str | None data: str | None
sender_id: int | None sender_id: int | None
raw: dict[str, Any] | None = None raw: dict[str, Any] | None = None
update_id: int | None = None
TelegramIncomingUpdate = TelegramIncomingMessage | TelegramCallbackQuery TelegramIncomingUpdate = TelegramIncomingMessage | TelegramCallbackQuery
+109
View File
@@ -1649,6 +1649,115 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None:
tg.cancel_scope.cancel() tg.cancel_scope.cancel()
@pytest.mark.anyio
async def test_run_main_loop_ignores_duplicate_message_id_for_replies() -> None:
transport = FakeTransport()
bot = FakeBot()
codex_runner = ScriptRunner([Return(answer="codex")], engine=CODEX_ENGINE)
claude_runner = ScriptRunner([Return(answer="claude")], engine="claude")
router = AutoRouter(
entries=[
RunnerEntry(engine=codex_runner.engine, runner=codex_runner),
RunnerEntry(engine=claude_runner.engine, runner=claude_runner),
],
default_engine=claude_runner.engine,
)
runtime = TransportRuntime(router=router, projects=_empty_projects())
cfg = TelegramBridgeConfig(
bot=bot,
runtime=runtime,
chat_id=123,
startup_msg="",
exec_cfg=ExecBridgeConfig(
transport=transport,
presenter=MarkdownPresenter(),
final_notify=True,
),
forward_coalesce_s=FAST_FORWARD_COALESCE_S,
media_group_debounce_s=FAST_MEDIA_GROUP_DEBOUNCE_S,
)
async def poller(_cfg: TelegramBridgeConfig):
yield TelegramIncomingMessage(
transport="telegram",
chat_id=123,
message_id=42,
text="turn on logging in my lemon config for me",
reply_to_message_id=900,
reply_to_text="done\n`codex resume c-123`",
sender_id=123,
chat_type="private",
)
# Telegram can occasionally redeliver the same message id with less reply metadata.
yield TelegramIncomingMessage(
transport="telegram",
chat_id=123,
message_id=42,
text="turn on logging in my lemon config for me",
reply_to_message_id=900,
reply_to_text=None,
sender_id=123,
chat_type="private",
)
await run_main_loop(cfg, poller)
assert len(codex_runner.calls) == 1
assert codex_runner.calls[0][1] == ResumeToken(engine=CODEX_ENGINE, value="c-123")
assert claude_runner.calls == []
@pytest.mark.anyio
async def test_run_main_loop_ignores_duplicate_update_id() -> None:
transport = FakeTransport()
bot = FakeBot()
runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
runtime = TransportRuntime(router=_make_router(runner), projects=_empty_projects())
cfg = TelegramBridgeConfig(
bot=bot,
runtime=runtime,
chat_id=123,
startup_msg="",
exec_cfg=ExecBridgeConfig(
transport=transport,
presenter=MarkdownPresenter(),
final_notify=True,
),
forward_coalesce_s=FAST_FORWARD_COALESCE_S,
media_group_debounce_s=FAST_MEDIA_GROUP_DEBOUNCE_S,
)
async def poller(_cfg: TelegramBridgeConfig):
yield TelegramIncomingMessage(
transport="telegram",
chat_id=123,
message_id=1,
text="first",
reply_to_message_id=None,
reply_to_text=None,
sender_id=123,
chat_type="private",
update_id=9001,
)
# Same Telegram update id redelivered.
yield TelegramIncomingMessage(
transport="telegram",
chat_id=123,
message_id=2,
text="second",
reply_to_message_id=None,
reply_to_text=None,
sender_id=123,
chat_type="private",
update_id=9001,
)
await run_main_loop(cfg, poller)
assert len(runner.calls) == 1
assert runner.calls[0][0] == "first"
@pytest.mark.anyio @pytest.mark.anyio
async def test_run_main_loop_persists_topic_sessions_in_project_scope( async def test_run_main_loop_persists_topic_sessions_in_project_scope(
tmp_path: Path, tmp_path: Path,
+3
View File
@@ -55,6 +55,7 @@ def test_parse_incoming_update_maps_fields() -> None:
assert msg.document is None assert msg.document is None
assert msg.raw assert msg.raw
assert msg.raw["message_id"] == 10 assert msg.raw["message_id"] == 10
assert msg.update_id == 1
def test_parse_incoming_update_ignores_implicit_topic_reply() -> None: def test_parse_incoming_update_ignores_implicit_topic_reply() -> None:
@@ -83,6 +84,7 @@ def test_parse_incoming_update_ignores_implicit_topic_reply() -> None:
assert msg.reply_to_text is None assert msg.reply_to_text is None
assert msg.reply_to_is_bot is None assert msg.reply_to_is_bot is None
assert msg.reply_to_username is None assert msg.reply_to_username is None
assert msg.update_id == 1
def test_parse_incoming_update_filters_non_matching_chat() -> None: def test_parse_incoming_update_filters_non_matching_chat() -> None:
@@ -295,6 +297,7 @@ def test_parse_incoming_update_callback_query() -> None:
assert msg.callback_query_id == "cbq-1" assert msg.callback_query_id == "cbq-1"
assert msg.data == "takopi:cancel" assert msg.data == "takopi:cancel"
assert msg.sender_id == 321 assert msg.sender_id == 321
assert msg.update_id == 1
def test_parse_incoming_update_topic_fields() -> None: def test_parse_incoming_update_topic_fields() -> None: