refactor(exec_bridge): use taskgroup poller
This commit is contained in:
@@ -580,25 +580,11 @@ async def _handle_message(
|
||||
)
|
||||
|
||||
|
||||
async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
semaphore = asyncio.Semaphore(cfg.max_concurrency)
|
||||
|
||||
tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
def _task_done(task: asyncio.Task[None]) -> None:
|
||||
tasks.discard(task)
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("[handle] task failed")
|
||||
|
||||
async def poll_updates(cfg: BridgeConfig):
|
||||
offset: int | None = None
|
||||
offset = await _drain_backlog(cfg, offset)
|
||||
await _send_startup(cfg)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
updates = await cfg.bot.get_updates(
|
||||
@@ -612,37 +598,54 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
|
||||
for upd in updates:
|
||||
offset = upd["update_id"] + 1
|
||||
msg = upd.get("message") or {}
|
||||
msg_chat_id = msg.get("chat", {}).get("id")
|
||||
msg = upd["message"]
|
||||
if "text" not in msg:
|
||||
continue
|
||||
if int(msg_chat_id) != cfg.chat_id:
|
||||
continue
|
||||
if msg.get("from", {}).get("is_bot"):
|
||||
if not (msg["chat"]["id"] == msg["from"]["id"] == cfg.chat_id):
|
||||
continue
|
||||
yield msg
|
||||
|
||||
|
||||
async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
semaphore = asyncio.Semaphore(cfg.max_concurrency)
|
||||
|
||||
async def _handle_message_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
resume_session: str | None,
|
||||
) -> None:
|
||||
try:
|
||||
await _handle_message(
|
||||
cfg,
|
||||
semaphore=semaphore,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
text=text,
|
||||
resume_session=resume_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("[handle] task failed")
|
||||
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
async for msg in poll_updates(cfg):
|
||||
text = msg["text"]
|
||||
user_msg_id = msg["message_id"]
|
||||
resume_session = extract_session_id(text)
|
||||
r = msg.get("reply_to_message") or {}
|
||||
resume_session = resume_session or extract_session_id(r.get("text"))
|
||||
|
||||
task = asyncio.create_task(
|
||||
_handle_message(
|
||||
cfg,
|
||||
semaphore=semaphore,
|
||||
chat_id=msg_chat_id,
|
||||
tg.create_task(
|
||||
_handle_message_task(
|
||||
chat_id=msg["chat"]["id"],
|
||||
user_msg_id=user_msg_id,
|
||||
text=text,
|
||||
resume_session=resume_session,
|
||||
)
|
||||
)
|
||||
tasks.add(task)
|
||||
task.add_done_callback(_task_done)
|
||||
finally:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
await cfg.bot.close()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user