From 773975bf5676004736eaeba6c9e5c065c2f11d73 Mon Sep 17 00:00:00 2001 From: banteg <4562643+banteg@users.noreply.github.com> Date: Mon, 29 Dec 2025 14:06:09 +0400 Subject: [PATCH] refactor(exec_bridge): use taskgroup poller --- .../src/codex_telegram_bridge/exec_bridge.py | 91 ++++++++++--------- 1 file changed, 47 insertions(+), 44 deletions(-) diff --git a/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py b/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py index 04fb383..cebb560 100644 --- a/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py +++ b/codex_telegram_bridge/src/codex_telegram_bridge/exec_bridge.py @@ -580,69 +580,72 @@ async def _handle_message( ) -async def _run_main_loop(cfg: BridgeConfig) -> None: - semaphore = asyncio.Semaphore(cfg.max_concurrency) - - tasks: set[asyncio.Task[None]] = set() - - def _task_done(task: asyncio.Task[None]) -> None: - tasks.discard(task) - try: - task.result() - except asyncio.CancelledError: - pass - except Exception: - logger.exception("[handle] task failed") - +async def poll_updates(cfg: BridgeConfig): offset: int | None = None offset = await _drain_backlog(cfg, offset) await _send_startup(cfg) - try: - while True: - try: - updates = await cfg.bot.get_updates( - offset=offset, timeout_s=50, allowed_updates=["message"] - ) - except Exception as e: - logger.info("[loop] getUpdates failed: %s", e) - await asyncio.sleep(2) + while True: + try: + updates = await cfg.bot.get_updates( + offset=offset, timeout_s=50, allowed_updates=["message"] + ) + except Exception as e: + logger.info("[loop] getUpdates failed: %s", e) + await asyncio.sleep(2) + continue + logger.debug("[loop] updates: %s", updates) + + for upd in updates: + offset = upd["update_id"] + 1 + msg = upd["message"] + if "text" not in msg: continue - logger.debug("[loop] updates: %s", updates) + if not (msg["chat"]["id"] == msg["from"]["id"] == cfg.chat_id): + continue + yield msg - for upd in updates: - offset = upd["update_id"] + 1 - msg = upd.get("message") or {} - msg_chat_id = msg.get("chat", {}).get("id") - if "text" not in msg: - continue - if int(msg_chat_id) != cfg.chat_id: - continue - if msg.get("from", {}).get("is_bot"): - continue +async def _run_main_loop(cfg: BridgeConfig) -> None: + semaphore = asyncio.Semaphore(cfg.max_concurrency) + + async def _handle_message_task( + *, + chat_id: int, + user_msg_id: int, + text: str, + resume_session: str | None, + ) -> None: + try: + await _handle_message( + cfg, + semaphore=semaphore, + chat_id=chat_id, + user_msg_id=user_msg_id, + text=text, + resume_session=resume_session, + ) + except Exception: + logger.exception("[handle] task failed") + + try: + async with asyncio.TaskGroup() as tg: + async for msg in poll_updates(cfg): text = msg["text"] user_msg_id = msg["message_id"] resume_session = extract_session_id(text) r = msg.get("reply_to_message") or {} resume_session = resume_session or extract_session_id(r.get("text")) - task = asyncio.create_task( - _handle_message( - cfg, - semaphore=semaphore, - chat_id=msg_chat_id, + tg.create_task( + _handle_message_task( + chat_id=msg["chat"]["id"], user_msg_id=user_msg_id, text=text, resume_session=resume_session, ) ) - tasks.add(task) - task.add_done_callback(_task_done) finally: - for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) await cfg.bot.close()