refactor(exec_bridge): use taskgroup poller

This commit is contained in:
banteg
2025-12-29 14:06:09 +04:00
parent abbb1a8825
commit 773975bf56
@@ -580,69 +580,72 @@ async def _handle_message(
)
async def _run_main_loop(cfg: BridgeConfig) -> None:
semaphore = asyncio.Semaphore(cfg.max_concurrency)
tasks: set[asyncio.Task[None]] = set()
def _task_done(task: asyncio.Task[None]) -> None:
tasks.discard(task)
try:
task.result()
except asyncio.CancelledError:
pass
except Exception:
logger.exception("[handle] task failed")
async def poll_updates(cfg: BridgeConfig):
offset: int | None = None
offset = await _drain_backlog(cfg, offset)
await _send_startup(cfg)
try:
while True:
try:
updates = await cfg.bot.get_updates(
offset=offset, timeout_s=50, allowed_updates=["message"]
)
except Exception as e:
logger.info("[loop] getUpdates failed: %s", e)
await asyncio.sleep(2)
while True:
try:
updates = await cfg.bot.get_updates(
offset=offset, timeout_s=50, allowed_updates=["message"]
)
except Exception as e:
logger.info("[loop] getUpdates failed: %s", e)
await asyncio.sleep(2)
continue
logger.debug("[loop] updates: %s", updates)
for upd in updates:
offset = upd["update_id"] + 1
msg = upd["message"]
if "text" not in msg:
continue
logger.debug("[loop] updates: %s", updates)
if not (msg["chat"]["id"] == msg["from"]["id"] == cfg.chat_id):
continue
yield msg
for upd in updates:
offset = upd["update_id"] + 1
msg = upd.get("message") or {}
msg_chat_id = msg.get("chat", {}).get("id")
if "text" not in msg:
continue
if int(msg_chat_id) != cfg.chat_id:
continue
if msg.get("from", {}).get("is_bot"):
continue
async def _run_main_loop(cfg: BridgeConfig) -> None:
semaphore = asyncio.Semaphore(cfg.max_concurrency)
async def _handle_message_task(
*,
chat_id: int,
user_msg_id: int,
text: str,
resume_session: str | None,
) -> None:
try:
await _handle_message(
cfg,
semaphore=semaphore,
chat_id=chat_id,
user_msg_id=user_msg_id,
text=text,
resume_session=resume_session,
)
except Exception:
logger.exception("[handle] task failed")
try:
async with asyncio.TaskGroup() as tg:
async for msg in poll_updates(cfg):
text = msg["text"]
user_msg_id = msg["message_id"]
resume_session = extract_session_id(text)
r = msg.get("reply_to_message") or {}
resume_session = resume_session or extract_session_id(r.get("text"))
task = asyncio.create_task(
_handle_message(
cfg,
semaphore=semaphore,
chat_id=msg_chat_id,
tg.create_task(
_handle_message_task(
chat_id=msg["chat"]["id"],
user_msg_id=user_msg_id,
text=text,
resume_session=resume_session,
)
)
tasks.add(task)
task.add_done_callback(_task_done)
finally:
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
await cfg.bot.close()