refactor(exec_bridge): add worker pool backpressure

This commit is contained in:
banteg
2025-12-29 14:34:10 +04:00
parent 25655ed200
commit 0bae7baad6
2 changed files with 42 additions and 49 deletions
@@ -432,7 +432,6 @@ async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
async def _handle_message( async def _handle_message(
cfg: BridgeConfig, cfg: BridgeConfig,
*, *,
semaphore: asyncio.Semaphore,
chat_id: int, chat_id: int,
user_msg_id: int, user_msg_id: int,
text: str, text: str,
@@ -523,27 +522,26 @@ async def _handle_message(
_edit_progress(progress_renderer.render_progress(elapsed)) _edit_progress(progress_renderer.render_progress(elapsed))
) )
async with semaphore: try:
try: session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
session_id, answer, saw_agent_message = await cfg.runner.run_serialized( text, resume_session, on_event=on_event
text, resume_session, on_event=on_event )
) except Exception as e:
except Exception as e: if edit_task is not None:
if edit_task is not None: await asyncio.gather(edit_task, return_exceptions=True)
await asyncio.gather(edit_task, return_exceptions=True)
err = _clamp_tg_text(f"Error:\n{e}") err = _clamp_tg_text(f"Error:\n{e}")
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err) logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
await _send_or_edit_markdown( await _send_or_edit_markdown(
cfg.bot, cfg.bot,
chat_id=chat_id, chat_id=chat_id,
text=err, text=err,
edit_message_id=progress_id, edit_message_id=progress_id,
reply_to_message_id=user_msg_id, reply_to_message_id=user_msg_id,
disable_notification=True, disable_notification=True,
limit=TELEGRAM_MARKDOWN_LIMIT, limit=TELEGRAM_MARKDOWN_LIMIT,
) )
return return
if edit_task is not None: if edit_task is not None:
await asyncio.gather(edit_task, return_exceptions=True) await asyncio.gather(edit_task, return_exceptions=True)
@@ -621,29 +619,31 @@ async def poll_updates(cfg: BridgeConfig):
async def _run_main_loop(cfg: BridgeConfig) -> None: async def _run_main_loop(cfg: BridgeConfig) -> None:
semaphore = asyncio.Semaphore(cfg.max_concurrency) worker_count = max(1, min(cfg.max_concurrency, 16))
queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue(
maxsize=worker_count * 2
)
async def _handle_message_task( async def worker() -> None:
*, while True:
chat_id: int, chat_id, user_msg_id, text, resume_session = await queue.get()
user_msg_id: int, try:
text: str, await _handle_message(
resume_session: str | None, cfg,
) -> None: chat_id=chat_id,
try: user_msg_id=user_msg_id,
await _handle_message( text=text,
cfg, resume_session=resume_session,
semaphore=semaphore, )
chat_id=chat_id, except Exception:
user_msg_id=user_msg_id, logger.exception("[handle] worker failed")
text=text, finally:
resume_session=resume_session, queue.task_done()
)
except Exception:
logger.exception("[handle] task failed")
try: try:
async with asyncio.TaskGroup() as tg: async with asyncio.TaskGroup() as tg:
for _ in range(worker_count):
tg.create_task(worker())
async for msg in poll_updates(cfg): async for msg in poll_updates(cfg):
text = msg["text"] text = msg["text"]
user_msg_id = msg["message_id"] user_msg_id = msg["message_id"]
@@ -651,13 +651,8 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
r = msg.get("reply_to_message") or {} r = msg.get("reply_to_message") or {}
resume_session = resume_session or extract_session_id(r.get("text")) resume_session = resume_session or extract_session_id(r.get("text"))
tg.create_task( await queue.put(
_handle_message_task( (msg["chat"]["id"], user_msg_id, text, resume_session)
chat_id=msg["chat"]["id"],
user_msg_id=user_msg_id,
text=text,
resume_session=resume_session,
)
) )
finally: finally:
await cfg.bot.close() await cfg.bot.close()
@@ -111,7 +111,6 @@ def test_final_notify_sends_loud_final_message() -> None:
asyncio.run( asyncio.run(
_handle_message( _handle_message(
cfg, cfg,
semaphore=asyncio.Semaphore(1),
chat_id=123, chat_id=123,
user_msg_id=10, user_msg_id=10,
text="hi", text="hi",
@@ -141,7 +140,6 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
asyncio.run( asyncio.run(
_handle_message( _handle_message(
cfg, cfg,
semaphore=asyncio.Semaphore(1),
chat_id=123, chat_id=123,
user_msg_id=10, user_msg_id=10,
text="hi", text="hi",