refactor(exec_bridge): add worker pool backpressure
This commit is contained in:
@@ -432,7 +432,6 @@ async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
|
||||
async def _handle_message(
|
||||
cfg: BridgeConfig,
|
||||
*,
|
||||
semaphore: asyncio.Semaphore,
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
@@ -523,27 +522,26 @@ async def _handle_message(
|
||||
_edit_progress(progress_renderer.render_progress(elapsed))
|
||||
)
|
||||
|
||||
async with semaphore:
|
||||
try:
|
||||
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
|
||||
text, resume_session, on_event=on_event
|
||||
)
|
||||
except Exception as e:
|
||||
if edit_task is not None:
|
||||
await asyncio.gather(edit_task, return_exceptions=True)
|
||||
try:
|
||||
session_id, answer, saw_agent_message = await cfg.runner.run_serialized(
|
||||
text, resume_session, on_event=on_event
|
||||
)
|
||||
except Exception as e:
|
||||
if edit_task is not None:
|
||||
await asyncio.gather(edit_task, return_exceptions=True)
|
||||
|
||||
err = _clamp_tg_text(f"Error:\n{e}")
|
||||
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
|
||||
await _send_or_edit_markdown(
|
||||
cfg.bot,
|
||||
chat_id=chat_id,
|
||||
text=err,
|
||||
edit_message_id=progress_id,
|
||||
reply_to_message_id=user_msg_id,
|
||||
disable_notification=True,
|
||||
limit=TELEGRAM_MARKDOWN_LIMIT,
|
||||
)
|
||||
return
|
||||
err = _clamp_tg_text(f"Error:\n{e}")
|
||||
logger.debug("[error] send reply_to=%s text=%s", user_msg_id, err)
|
||||
await _send_or_edit_markdown(
|
||||
cfg.bot,
|
||||
chat_id=chat_id,
|
||||
text=err,
|
||||
edit_message_id=progress_id,
|
||||
reply_to_message_id=user_msg_id,
|
||||
disable_notification=True,
|
||||
limit=TELEGRAM_MARKDOWN_LIMIT,
|
||||
)
|
||||
return
|
||||
|
||||
if edit_task is not None:
|
||||
await asyncio.gather(edit_task, return_exceptions=True)
|
||||
@@ -621,29 +619,31 @@ async def poll_updates(cfg: BridgeConfig):
|
||||
|
||||
|
||||
async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
semaphore = asyncio.Semaphore(cfg.max_concurrency)
|
||||
worker_count = max(1, min(cfg.max_concurrency, 16))
|
||||
queue: asyncio.Queue[tuple[int, int, str, str | None]] = asyncio.Queue(
|
||||
maxsize=worker_count * 2
|
||||
)
|
||||
|
||||
async def _handle_message_task(
|
||||
*,
|
||||
chat_id: int,
|
||||
user_msg_id: int,
|
||||
text: str,
|
||||
resume_session: str | None,
|
||||
) -> None:
|
||||
try:
|
||||
await _handle_message(
|
||||
cfg,
|
||||
semaphore=semaphore,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
text=text,
|
||||
resume_session=resume_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("[handle] task failed")
|
||||
async def worker() -> None:
|
||||
while True:
|
||||
chat_id, user_msg_id, text, resume_session = await queue.get()
|
||||
try:
|
||||
await _handle_message(
|
||||
cfg,
|
||||
chat_id=chat_id,
|
||||
user_msg_id=user_msg_id,
|
||||
text=text,
|
||||
resume_session=resume_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("[handle] worker failed")
|
||||
finally:
|
||||
queue.task_done()
|
||||
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for _ in range(worker_count):
|
||||
tg.create_task(worker())
|
||||
async for msg in poll_updates(cfg):
|
||||
text = msg["text"]
|
||||
user_msg_id = msg["message_id"]
|
||||
@@ -651,13 +651,8 @@ async def _run_main_loop(cfg: BridgeConfig) -> None:
|
||||
r = msg.get("reply_to_message") or {}
|
||||
resume_session = resume_session or extract_session_id(r.get("text"))
|
||||
|
||||
tg.create_task(
|
||||
_handle_message_task(
|
||||
chat_id=msg["chat"]["id"],
|
||||
user_msg_id=user_msg_id,
|
||||
text=text,
|
||||
resume_session=resume_session,
|
||||
)
|
||||
await queue.put(
|
||||
(msg["chat"]["id"], user_msg_id, text, resume_session)
|
||||
)
|
||||
finally:
|
||||
await cfg.bot.close()
|
||||
|
||||
@@ -111,7 +111,6 @@ def test_final_notify_sends_loud_final_message() -> None:
|
||||
asyncio.run(
|
||||
_handle_message(
|
||||
cfg,
|
||||
semaphore=asyncio.Semaphore(1),
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="hi",
|
||||
@@ -141,7 +140,6 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
||||
asyncio.run(
|
||||
_handle_message(
|
||||
cfg,
|
||||
semaphore=asyncio.Semaphore(1),
|
||||
chat_id=123,
|
||||
user_msg_id=10,
|
||||
text="hi",
|
||||
|
||||
Reference in New Issue
Block a user