fix(telegram): dedupe duplicate incoming messages (#198)

This commit is contained in:
banteg
2026-02-08 19:10:34 +04:00
committed by GitHub
parent b293818195
commit fb51708473
5 changed files with 298 additions and 141 deletions
+178 -141
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
from collections import deque
from collections.abc import AsyncIterator, Awaitable, Callable, Mapping
from dataclasses import dataclass
from functools import partial
@@ -77,6 +78,9 @@ logger = get_logger(__name__)
__all__ = ["poll_updates", "run_main_loop", "send_with_resume"]
ForwardKey = tuple[int, int, int]
MessageKey = tuple[int, int]
_SEEN_MESSAGES_LIMIT = 2048
_SEEN_UPDATES_LIMIT = 4096
_handle_file_put_default = handle_file_put_default
@@ -360,6 +364,16 @@ class TelegramMsgContext:
ambient_context: RunContext | None
@dataclass(frozen=True, slots=True)
class MessageClassification:
text: str
command_id: str | None
args_text: str
is_cancel: bool
is_forward_candidate: bool
is_media_group_document: bool
@dataclass(frozen=True, slots=True)
class TelegramCommandContext:
cfg: TelegramBridgeConfig
@@ -374,6 +388,30 @@ class TelegramCommandContext:
task_group: TaskGroup
def _classify_message(
    msg: TelegramIncomingMessage, *, files_enabled: bool
) -> MessageClassification:
    """Derive the routing flags for ``msg`` in a single pass.

    Args:
        msg: The incoming Telegram message to inspect.
        files_enabled: Whether file handling is turned on; gates the
            media-group-document flag.

    Returns:
        A ``MessageClassification`` holding the text, the parsed slash
        command parts, and the cancel / forward / media-group flags.
    """
    body = msg.text
    command_id, args_text = parse_slash_command(body)

    has_document = msg.document is not None
    in_media_group = msg.media_group_id is not None

    # A forward only counts as a candidate when it is a plain message:
    # no document, no voice note, and not part of a media group.
    forwarded = _is_forwarded(msg.raw)
    forward_candidate = forwarded and not (
        has_document or msg.voice is not None or in_media_group
    )

    # Documents that arrive as part of a media group are only singled
    # out when file handling is enabled.
    grouped_document = files_enabled and has_document and in_media_group

    cancel_requested = is_cancel_command(body)
    return MessageClassification(
        text=body,
        command_id=command_id,
        args_text=args_text,
        is_cancel=cancel_requested,
        is_forward_candidate=forward_candidate,
        is_media_group_document=grouped_document,
    )
@dataclass(slots=True)
class TelegramLoopState:
running_tasks: RunningTasks
@@ -392,6 +430,10 @@ class TelegramLoopState:
forward_coalesce_s: float
media_group_debounce_s: float
transport_id: str | None
seen_update_ids: set[int]
seen_update_order: deque[int]
seen_message_keys: set[MessageKey]
seen_messages_order: deque[MessageKey]
if TYPE_CHECKING:
@@ -931,6 +973,10 @@ async def run_main_loop(
forward_coalesce_s=max(0.0, float(cfg.forward_coalesce_s)),
media_group_debounce_s=max(0.0, float(cfg.media_group_debounce_s)),
transport_id=transport_id,
seen_update_ids=set(),
seen_update_order=deque(),
seen_message_keys=set(),
seen_messages_order=deque(),
)
def refresh_topics_scope() -> None:
@@ -1199,6 +1245,47 @@ async def run_main_loop(
await reply(text=f"error:\n{exc}")
return None
topic_key = resolve_topic_key(msg)
chat_project = (
_topics_chat_project(cfg, msg.chat_id)
if cfg.topics.enabled
else None
)
_, ok = await ensure_topic_context(
resolved=resolved,
ambient_context=ambient_context,
topic_key=topic_key,
chat_project=chat_project,
reply=reply,
)
if not ok:
return None
return resolved
async def resolve_engine_defaults(
    *,
    explicit_engine: EngineId | None,
    context: RunContext | None,
    chat_id: int,
    topic_key: tuple[int, int] | None,
):
    """Resolve which engine applies to a message.

    Forwards the per-message arguments to ``resolve_engine_for_message``
    together with the runtime, topic store and chat preferences taken
    from the enclosing scope (``cfg`` and ``state``).

    Args:
        explicit_engine: Engine requested explicitly, if any.
        context: Run context for the message, if any.
        chat_id: Chat the message belongs to.
        topic_key: Topic key for the message, or ``None``.

    Returns:
        Whatever ``resolve_engine_for_message`` returns, unchanged.
    """
    resolution = await resolve_engine_for_message(
        runtime=cfg.runtime,
        context=context,
        explicit_engine=explicit_engine,
        chat_id=chat_id,
        topic_key=topic_key,
        topic_store=state.topic_store,
        chat_prefs=state.chat_prefs,
    )
    return resolution
async def ensure_topic_context(
*,
resolved: ResolvedMessage,
ambient_context: RunContext | None,
topic_key: tuple[int, int] | None,
chat_project: str | None,
reply: Callable[..., Awaitable[None]],
) -> tuple[RunContext | None, bool]:
effective_context = ambient_context
if (
state.topic_store is not None
@@ -1221,35 +1308,13 @@ async def run_main_loop(
and effective_context is None
and resolved.context_source not in {"directives", "reply_ctx"}
):
chat_project = (
_topics_chat_project(cfg, msg.chat_id)
if cfg.topics.enabled
else None
)
await reply(
text="this topic isn't bound to a project yet.\n"
f"{_usage_ctx_set(chat_project=chat_project)} or "
f"{_usage_topic(chat_project=chat_project)}",
)
return None
return resolved
async def resolve_engine_defaults(
*,
explicit_engine: EngineId | None,
context: RunContext | None,
chat_id: int,
topic_key: tuple[int, int] | None,
):
return await resolve_engine_for_message(
runtime=cfg.runtime,
context=context,
explicit_engine=explicit_engine,
chat_id=chat_id,
topic_key=topic_key,
topic_store=state.topic_store,
chat_prefs=state.chat_prefs,
)
return effective_context, False
return effective_context, True
resume_resolver = ResumeResolver(
cfg=cfg,
@@ -1260,29 +1325,19 @@ async def run_main_loop(
chat_session_store=state.chat_session_store,
)
async def run_prompt_from_upload(
async def dispatch_prompt_run(
*,
msg: TelegramIncomingMessage,
prompt_text: str,
resolved: ResolvedMessage,
topic_key: tuple[int, int] | None,
chat_session_key: tuple[int, int | None] | None,
reply_ref: MessageRef | None,
reply_id: int | None,
) -> None:
chat_id = msg.chat_id
user_msg_id = msg.message_id
reply_id = msg.reply_to_message_id
reply_ref = (
MessageRef(
channel_id=msg.chat_id,
message_id=msg.reply_to_message_id,
thread_id=msg.thread_id,
)
if msg.reply_to_message_id is not None
else None
)
resume_token = resolved.resume_token
context = resolved.context
chat_session_key = _chat_session_key(
msg, store=state.chat_session_store
)
topic_key = resolve_topic_key(msg)
engine_resolution = await resolve_engine_defaults(
explicit_engine=resolved.engine_override,
context=context,
@@ -1291,7 +1346,7 @@ async def run_main_loop(
)
engine_override = engine_resolution.engine
resume_decision = await resume_resolver.resolve(
resume_token=resume_token,
resume_token=resolved.resume_token,
reply_id=reply_id,
chat_id=chat_id,
user_msg_id=user_msg_id,
@@ -1337,17 +1392,44 @@ async def run_main_loop(
progress_ref,
)
async def run_prompt_from_upload(
    msg: TelegramIncomingMessage,
    prompt_text: str,
    resolved: ResolvedMessage,
) -> None:
    """Start a prompt run for an upload, deriving routing keys from ``msg``.

    Computes the reply reference, chat session key and topic key from
    the message itself and hands everything to ``dispatch_prompt_run``.

    Args:
        msg: The incoming message that triggered the run.
        prompt_text: Prompt text to run.
        resolved: The already-resolved message.
    """
    replied_to_id = msg.reply_to_message_id

    # Only build a reference when the message actually replies to one.
    replied_to_ref: MessageRef | None = None
    if replied_to_id is not None:
        replied_to_ref = MessageRef(
            channel_id=msg.chat_id,
            message_id=replied_to_id,
            thread_id=msg.thread_id,
        )

    session_key = _chat_session_key(msg, store=state.chat_session_store)
    await dispatch_prompt_run(
        msg=msg,
        prompt_text=prompt_text,
        resolved=resolved,
        topic_key=resolve_topic_key(msg),
        chat_session_key=session_key,
        reply_ref=replied_to_ref,
        reply_id=replied_to_id,
    )
async def _dispatch_pending_prompt(pending: _PendingPrompt) -> None:
msg = pending.msg
chat_id = msg.chat_id
user_msg_id = msg.message_id
reply = make_reply(cfg, msg)
try:
resolved = cfg.runtime.resolve_message(
text=pending.text,
reply_text=msg.reply_to_text,
ambient_context=pending.ambient_context,
chat_id=chat_id,
chat_id=msg.chat_id,
)
except DirectiveError as exc:
await reply(text=f"error:\n{exc}")
@@ -1375,92 +1457,23 @@ async def run_main_loop(
prompt_text,
)
resume_token = resolved.resume_token
context = resolved.context
engine_resolution = await resolve_engine_defaults(
explicit_engine=resolved.engine_override,
context=context,
chat_id=chat_id,
_effective_context, ok = await ensure_topic_context(
resolved=resolved,
ambient_context=pending.ambient_context,
topic_key=pending.topic_key,
chat_project=pending.chat_project,
reply=reply,
)
engine_override = engine_resolution.engine
effective_context = pending.ambient_context
if (
state.topic_store is not None
and pending.topic_key is not None
and resolved.context is not None
and resolved.context_source == "directives"
):
await state.topic_store.set_context(
*pending.topic_key, resolved.context
)
await _maybe_rename_topic(
cfg,
state.topic_store,
chat_id=pending.topic_key[0],
thread_id=pending.topic_key[1],
context=resolved.context,
)
effective_context = resolved.context
if (
state.topic_store is not None
and pending.topic_key is not None
and effective_context is None
and resolved.context_source not in {"directives", "reply_ctx"}
):
await reply(
text="this topic isn't bound to a project yet.\n"
f"{_usage_ctx_set(chat_project=pending.chat_project)} or "
f"{_usage_topic(chat_project=pending.chat_project)}",
)
if not ok:
return
resume_decision = await resume_resolver.resolve(
resume_token=resume_token,
reply_id=pending.reply_id,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=msg.thread_id,
chat_session_key=pending.chat_session_key,
topic_key=pending.topic_key,
engine_for_session=engine_resolution.engine,
await dispatch_prompt_run(
msg=msg,
prompt_text=prompt_text,
)
if resume_decision.handled_by_running_task:
return
resume_token = resume_decision.resume_token
if resume_token is None:
tg.start_soon(
run_job,
chat_id,
user_msg_id,
prompt_text,
None,
context,
msg.thread_id,
pending.chat_session_key,
pending.reply_ref,
scheduler.note_thread_known,
engine_override,
)
return
progress_ref = await _send_queued_progress(
cfg,
chat_id=chat_id,
user_msg_id=user_msg_id,
thread_id=msg.thread_id,
resume_token=resume_token,
context=context,
)
await scheduler.enqueue_resume(
chat_id,
user_msg_id,
prompt_text,
resume_token,
context,
msg.thread_id,
pending.chat_session_key,
progress_ref,
resolved=resolved,
topic_key=pending.topic_key,
chat_session_key=pending.chat_session_key,
reply_ref=pending.reply_ref,
reply_id=pending.reply_id,
)
forward_coalescer = ForwardCoalescer(
@@ -1562,23 +1575,14 @@ async def run_main_loop(
async def route_message(msg: TelegramIncomingMessage) -> None:
reply = make_reply(cfg, msg)
text = msg.text
classification = _classify_message(msg, files_enabled=cfg.files.enabled)
text = classification.text
is_voice_transcribed = False
is_forward_candidate = (
_is_forwarded(msg.raw)
and msg.document is None
and msg.voice is None
and msg.media_group_id is None
)
if is_forward_candidate:
if classification.is_forward_candidate:
forward_coalescer.attach_forward(msg)
return
forward_key = _forward_key(msg)
if (
cfg.files.enabled
and msg.document is not None
and msg.media_group_id is not None
):
if classification.is_media_group_document:
media_group_buffer.add(msg)
return
ctx = await build_message_context(msg)
@@ -1591,13 +1595,14 @@ async def run_main_loop(
chat_project = ctx.chat_project
ambient_context = ctx.ambient_context
if is_cancel_command(text):
if classification.is_cancel:
tg.start_soon(
handle_cancel, cfg, msg, state.running_tasks, scheduler
)
return
command_id, args_text = parse_slash_command(text)
command_id = classification.command_id
args_text = classification.args_text
if command_id == "new":
forward_coalescer.cancel(forward_key)
if state.topic_store is not None and topic_key is not None:
@@ -1793,6 +1798,38 @@ async def run_main_loop(
sender_id=sender_id,
)
return
# De-duplicate incoming updates before routing them.
# Primary key: the update_id, when the update carries one.
if update.update_id is not None:
update_id = update.update_id
if update_id in state.seen_update_ids:
# Already processed this update_id: log and drop it.
logger.debug(
"update.ignored",
reason="duplicate_update",
update_id=update_id,
chat_id=update.chat_id,
sender_id=update.sender_id,
)
return
# Remember the id. The set gives O(1) membership tests; the deque
# records insertion order so the oldest id can be evicted.
state.seen_update_ids.add(update_id)
state.seen_update_order.append(update_id)
# Keep at most _SEEN_UPDATES_LIMIT ids: evict the oldest one once exceeded.
if len(state.seen_update_order) > _SEEN_UPDATES_LIMIT:
oldest_update_id = state.seen_update_order.popleft()
state.seen_update_ids.discard(oldest_update_id)
# Fallback key, used only when the update has no update_id and is a
# message: the (chat_id, message_id) pair.
# NOTE(review): because this is an `elif`, the message-key check is never
# run (and the key never recorded) for updates that do carry an update_id,
# so the same message arriving under a different update_id is not caught
# here - confirm that is intended.
elif isinstance(update, TelegramIncomingMessage):
key = (update.chat_id, update.message_id)
if key in state.seen_message_keys:
# Same chat/message pair seen before: log and drop it.
logger.debug(
"update.ignored",
reason="duplicate_message",
chat_id=update.chat_id,
message_id=update.message_id,
sender_id=update.sender_id,
)
return
state.seen_message_keys.add(key)
state.seen_messages_order.append(key)
# Keep at most _SEEN_MESSAGES_LIMIT keys, evicting oldest-first.
if len(state.seen_messages_order) > _SEEN_MESSAGES_LIMIT:
oldest = state.seen_messages_order.popleft()
state.seen_message_keys.discard(oldest)
if isinstance(update, TelegramCallbackQuery):
if update.data == CANCEL_CALLBACK_DATA:
tg.start_soon(
+6
View File
@@ -36,12 +36,14 @@ def parse_incoming_update(
if update.message is not None:
return _parse_incoming_message(
update.message,
update_id=update.update_id,
chat_id=chat_id,
chat_ids=chat_ids,
)
if update.callback_query is not None:
return _parse_callback_query(
update.callback_query,
update_id=update.update_id,
chat_id=chat_id,
chat_ids=chat_ids,
)
@@ -51,6 +53,7 @@ def parse_incoming_update(
def _parse_incoming_message(
msg: Message,
*,
update_id: int | None = None,
chat_id: int | None = None,
chat_ids: set[int] | None = None,
) -> TelegramIncomingMessage | None:
@@ -133,12 +136,14 @@ def _parse_incoming_message(
voice=voice_payload,
document=document_payload,
raw=msgspec.to_builtins(msg),
update_id=update_id,
)
def _parse_callback_query(
query: CallbackQuery,
*,
update_id: int | None = None,
chat_id: int | None = None,
chat_ids: set[int] | None = None,
) -> TelegramCallbackQuery | None:
@@ -162,6 +167,7 @@ def _parse_callback_query(
data=data,
sender_id=sender_id,
raw=msgspec.to_builtins(query),
update_id=update_id,
)
+2
View File
@@ -41,6 +41,7 @@ class TelegramIncomingMessage:
voice: TelegramVoice | None = None
document: TelegramDocument | None = None
raw: dict[str, Any] | None = None
update_id: int | None = None
@property
def is_private(self) -> bool:
@@ -58,6 +59,7 @@ class TelegramCallbackQuery:
data: str | None
sender_id: int | None
raw: dict[str, Any] | None = None
update_id: int | None = None
TelegramIncomingUpdate = TelegramIncomingMessage | TelegramCallbackQuery
+109
View File
@@ -1649,6 +1649,115 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None:
tg.cancel_scope.cancel()
@pytest.mark.anyio
async def test_run_main_loop_ignores_duplicate_message_id_for_replies() -> None:
    """A repeated (chat_id, message_id) pair must trigger only one run.

    The same message id is delivered twice; the second copy has no
    ``reply_to_text``. Only the first copy, whose reply text carries the
    codex resume line, may reach a runner.
    """
    codex_runner = ScriptRunner([Return(answer="codex")], engine=CODEX_ENGINE)
    claude_runner = ScriptRunner([Return(answer="claude")], engine="claude")
    runner_entries = [
        RunnerEntry(engine=r.engine, runner=r)
        for r in (codex_runner, claude_runner)
    ]
    router = AutoRouter(
        entries=runner_entries,
        default_engine=claude_runner.engine,
    )
    exec_cfg = ExecBridgeConfig(
        transport=FakeTransport(),
        presenter=MarkdownPresenter(),
        final_notify=True,
    )
    cfg = TelegramBridgeConfig(
        bot=FakeBot(),
        runtime=TransportRuntime(router=router, projects=_empty_projects()),
        chat_id=123,
        startup_msg="",
        exec_cfg=exec_cfg,
        forward_coalesce_s=FAST_FORWARD_COALESCE_S,
        media_group_debounce_s=FAST_MEDIA_GROUP_DEBOUNCE_S,
    )

    async def poller(_cfg: TelegramBridgeConfig):
        shared_fields = dict(
            transport="telegram",
            chat_id=123,
            message_id=42,
            text="turn on logging in my lemon config for me",
            reply_to_message_id=900,
            sender_id=123,
            chat_type="private",
        )
        # Telegram can occasionally redeliver the same message id with less
        # reply metadata: the second delivery drops the reply text.
        for reply_to_text in ("done\n`codex resume c-123`", None):
            yield TelegramIncomingMessage(
                reply_to_text=reply_to_text, **shared_fields
            )

    await run_main_loop(cfg, poller)

    expected_token = ResumeToken(engine=CODEX_ENGINE, value="c-123")
    assert len(codex_runner.calls) == 1
    assert codex_runner.calls[0][1] == expected_token
    # The default (claude) engine must never have been invoked.
    assert claude_runner.calls == []
@pytest.mark.anyio
async def test_run_main_loop_ignores_duplicate_update_id() -> None:
    """Two updates sharing one ``update_id`` must produce a single run.

    The messages differ in ``message_id`` and text, so only the
    update-id check can reject the second one; the runner must see
    just ``"first"``.
    """
    runner = ScriptRunner([Return(answer="ok")], engine=CODEX_ENGINE)
    exec_cfg = ExecBridgeConfig(
        transport=FakeTransport(),
        presenter=MarkdownPresenter(),
        final_notify=True,
    )
    cfg = TelegramBridgeConfig(
        bot=FakeBot(),
        runtime=TransportRuntime(
            router=_make_router(runner), projects=_empty_projects()
        ),
        chat_id=123,
        startup_msg="",
        exec_cfg=exec_cfg,
        forward_coalesce_s=FAST_FORWARD_COALESCE_S,
        media_group_debounce_s=FAST_MEDIA_GROUP_DEBOUNCE_S,
    )

    async def poller(_cfg: TelegramBridgeConfig):
        # Same Telegram update id redelivered with a different payload.
        for message_id, text in ((1, "first"), (2, "second")):
            yield TelegramIncomingMessage(
                transport="telegram",
                chat_id=123,
                message_id=message_id,
                text=text,
                reply_to_message_id=None,
                reply_to_text=None,
                sender_id=123,
                chat_type="private",
                update_id=9001,
            )

    await run_main_loop(cfg, poller)

    assert len(runner.calls) == 1
    assert runner.calls[0][0] == "first"
@pytest.mark.anyio
async def test_run_main_loop_persists_topic_sessions_in_project_scope(
tmp_path: Path,
+3
View File
@@ -55,6 +55,7 @@ def test_parse_incoming_update_maps_fields() -> None:
assert msg.document is None
assert msg.raw
assert msg.raw["message_id"] == 10
assert msg.update_id == 1
def test_parse_incoming_update_ignores_implicit_topic_reply() -> None:
@@ -83,6 +84,7 @@ def test_parse_incoming_update_ignores_implicit_topic_reply() -> None:
assert msg.reply_to_text is None
assert msg.reply_to_is_bot is None
assert msg.reply_to_username is None
assert msg.update_id == 1
def test_parse_incoming_update_filters_non_matching_chat() -> None:
@@ -295,6 +297,7 @@ def test_parse_incoming_update_callback_query() -> None:
assert msg.callback_query_id == "cbq-1"
assert msg.data == "takopi:cancel"
assert msg.sender_id == 321
assert msg.update_id == 1
def test_parse_incoming_update_topic_fields() -> None: