refactor: handle message (#19)

This commit is contained in:
banteg
2026-01-02 01:55:46 +04:00
committed by GitHub
parent ac844b5305
commit 73ba4836c1
4 changed files with 208 additions and 148 deletions
+2 -2
View File
@@ -36,7 +36,7 @@ The orchestrator module containing:
|-----------|---------| |-----------|---------|
| `BridgeConfig` | Frozen dataclass holding runtime config | | `BridgeConfig` | Frozen dataclass holding runtime config |
| `poll_updates()` | Async generator that drains backlog, long-polls updates, filters messages | | `poll_updates()` | Async generator that drains backlog, long-polls updates, filters messages |
| `_run_main_loop()` | TaskGroup-based main loop that spawns per-message handlers | | `run_main_loop()` | TaskGroup-based main loop that spawns per-message handlers |
| `handle_message()` | Per-message handler with progress updates and final render | | `handle_message()` | Per-message handler with progress updates and final render |
| `ProgressEdits` | Throttled progress edit worker | | `ProgressEdits` | Throttled progress edit worker |
| `_handle_cancel()` | `/cancel` routing | | `_handle_cancel()` | `/cancel` routing |
@@ -162,7 +162,7 @@ Telegram Update
poll_updates() drains backlog, long-polls, filters chat_id == from_id == cfg.chat_id poll_updates() drains backlog, long-polls, filters chat_id == from_id == cfg.chat_id
_run_main_loop() spawns tasks in TaskGroup run_main_loop() spawns tasks in TaskGroup
handle_message() spawned as task handle_message() spawned as task
+202 -142
View File
@@ -232,6 +232,147 @@ async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None:
drained += len(updates) drained += len(updates)
@dataclass(frozen=True, slots=True)
class ProgressMessageState:
message_id: int | None
last_edit_at: float
last_rendered: str | None
async def send_initial_progress(
cfg: BridgeConfig,
*,
chat_id: int,
user_msg_id: int,
label: str,
renderer: ExecProgressRenderer,
is_resume_line: Callable[[str], bool],
clock: Callable[[], float],
limit: int,
) -> ProgressMessageState:
progress_id: int | None = None
last_edit_at = 0.0
last_rendered: str | None = None
initial_md = renderer.render_progress(0.0, label=label)
initial_rendered, initial_entities = prepare_telegram(
initial_md, limit=limit, is_resume_line=is_resume_line
)
logger.debug(
"[progress] send reply_to=%s md=%s rendered=%s entities=%s",
user_msg_id,
initial_md,
initial_rendered,
initial_entities,
)
progress_msg = await cfg.bot.send_message(
chat_id=chat_id,
text=initial_rendered,
entities=initial_entities,
reply_to_message_id=user_msg_id,
disable_notification=True,
)
if progress_msg is not None:
progress_id = int(progress_msg["message_id"])
last_edit_at = clock()
last_rendered = initial_rendered
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
return ProgressMessageState(
message_id=progress_id,
last_edit_at=last_edit_at,
last_rendered=last_rendered,
)
@dataclass(slots=True)
class RunOutcome:
cancelled: bool = False
completed: CompletedEvent | None = None
resume: ResumeToken | None = None
async def run_runner_with_cancel(
runner: Runner,
*,
prompt: str,
resume_token: ResumeToken | None,
edits: ProgressEdits,
running_task: RunningTask | None,
on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None,
) -> RunOutcome:
outcome = RunOutcome()
async with anyio.create_task_group() as tg:
async def run_runner() -> None:
try:
async for evt in runner.run(prompt, resume_token):
_log_runner_event(evt)
if isinstance(evt, StartedEvent):
outcome.resume = evt.resume
if running_task is not None and running_task.resume is None:
running_task.resume = evt.resume
running_task.resume_ready.set()
if on_thread_known is not None:
await on_thread_known(evt.resume, running_task.done)
elif isinstance(evt, CompletedEvent):
outcome.resume = evt.resume or outcome.resume
outcome.completed = evt
await edits.on_event(evt)
finally:
tg.cancel_scope.cancel()
async def wait_cancel(task: RunningTask) -> None:
await task.cancel_requested.wait()
outcome.cancelled = True
tg.cancel_scope.cancel()
tg.start_soon(run_runner)
if running_task is not None:
tg.start_soon(wait_cancel, running_task)
return outcome
def sync_resume_token(
renderer: ExecProgressRenderer, resume: ResumeToken | None
) -> ResumeToken | None:
resume = resume or renderer.resume_token
renderer.resume_token = resume
return resume
async def send_result_message(
cfg: BridgeConfig,
*,
chat_id: int,
user_msg_id: int,
progress_id: int | None,
markdown: str,
disable_notification: bool,
edit_message_id: int | None,
is_resume_line: Callable[[str], bool],
prepared: tuple[str, list[dict[str, Any]] | None] | None = None,
delete_tag: str = "final",
) -> None:
final_msg, edited = await _send_or_edit_markdown(
cfg.bot,
chat_id=chat_id,
text=markdown,
edit_message_id=edit_message_id,
reply_to_message_id=user_msg_id,
disable_notification=disable_notification,
limit=TELEGRAM_MARKDOWN_LIMIT,
is_resume_line=is_resume_line,
prepared=prepared,
)
if final_msg is None:
return
if progress_id is not None and (edit_message_id is None or not edited):
logger.debug("[%s] delete progress message_id=%s", delete_tag, progress_id)
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
async def handle_message( async def handle_message(
cfg: BridgeConfig, cfg: BridgeConfig,
*, *,
@@ -262,35 +403,17 @@ async def handle_message(
max_actions=5, resume_formatter=runner.format_resume max_actions=5, resume_formatter=runner.format_resume
) )
progress_id: int | None = None progress_state = await send_initial_progress(
last_edit_at = 0.0 cfg,
last_rendered: str | None = None
initial_md = progress_renderer.render_progress(
0.0, label=f"working ({runner.engine})"
)
initial_rendered, initial_entities = prepare_telegram(
initial_md, limit=TELEGRAM_MARKDOWN_LIMIT, is_resume_line=is_resume_line
)
logger.debug(
"[progress] send reply_to=%s md=%s rendered=%s entities=%s",
user_msg_id,
initial_md,
initial_rendered,
initial_entities,
)
progress_msg = await cfg.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=initial_rendered, user_msg_id=user_msg_id,
entities=initial_entities, label=f"working ({runner.engine})",
reply_to_message_id=user_msg_id, renderer=progress_renderer,
disable_notification=True, is_resume_line=is_resume_line,
clock=clock,
limit=TELEGRAM_MARKDOWN_LIMIT,
) )
if progress_msg is not None: progress_id = progress_state.message_id
progress_id = int(progress_msg["message_id"])
last_edit_at = clock()
last_rendered = initial_rendered
logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id)
edits = ProgressEdits( edits = ProgressEdits(
bot=cfg.bot, bot=cfg.bot,
@@ -302,23 +425,17 @@ async def handle_message(
clock=clock, clock=clock,
sleep=sleep, sleep=sleep,
limit=TELEGRAM_MARKDOWN_LIMIT, limit=TELEGRAM_MARKDOWN_LIMIT,
last_edit_at=last_edit_at, last_edit_at=progress_state.last_edit_at,
last_rendered=last_rendered, last_rendered=progress_state.last_rendered,
is_resume_line=is_resume_line, is_resume_line=is_resume_line,
) )
cancel_exc_type = anyio.get_cancelled_exc_class()
cancelled = False
error: Exception | None = None
resume_token_value: ResumeToken | None = None
answer: str | None = None
run_ok: bool | None = None
run_error: str | None = None
running_task: RunningTask | None = None running_task: RunningTask | None = None
if running_tasks is not None and progress_id is not None: if running_tasks is not None and progress_id is not None:
running_task = RunningTask() running_task = RunningTask()
running_tasks[progress_id] = running_task running_tasks[progress_id] = running_task
cancel_exc_type = anyio.get_cancelled_exc_class()
edits_scope = anyio.CancelScope() edits_scope = anyio.CancelScope()
async def run_edits() -> None: async def run_edits() -> None:
@@ -329,67 +446,22 @@ async def handle_message(
# Edits are best-effort; cancellation should not bubble into the task group. # Edits are best-effort; cancellation should not bubble into the task group.
return return
outcome = RunOutcome()
error: Exception | None = None
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
if progress_id is not None: if progress_id is not None:
tg.start_soon(run_edits) tg.start_soon(run_edits)
async def run_exec() -> CompletedEvent | None:
nonlocal cancelled
cancel_flag = False
completed: CompletedEvent | None = None
async with anyio.create_task_group() as exec_tg:
async def run_runner() -> None:
nonlocal resume_token_value, completed, answer, run_ok, run_error
try:
async for evt in runner.run(runner_text, resume_token):
_log_runner_event(evt)
if isinstance(evt, StartedEvent):
resume_token_value = evt.resume
if (
running_task is not None
and running_task.resume is None
):
running_task.resume = resume_token_value
running_task.resume_ready.set()
if on_thread_known is not None:
await on_thread_known(
resume_token_value, running_task.done
)
elif isinstance(evt, CompletedEvent):
resume_token_value = evt.resume or resume_token_value
answer = evt.answer
run_ok = evt.ok
run_error = evt.error
completed = evt
await edits.on_event(evt)
finally:
exec_tg.cancel_scope.cancel()
async def wait_cancel() -> None:
nonlocal cancel_flag
if running_task is None:
return
await running_task.cancel_requested.wait()
cancel_flag = True
exec_tg.cancel_scope.cancel()
exec_tg.start_soon(run_runner)
if running_task is not None:
exec_tg.start_soon(wait_cancel)
if cancel_flag:
cancelled = True
return completed
try: try:
completed = await run_exec() outcome = await run_runner_with_cancel(
if completed is not None: runner,
resume_token_value = completed.resume or resume_token_value prompt=runner_text,
answer = completed.answer resume_token=resume_token,
run_ok = completed.ok edits=edits,
run_error = completed.error running_task=running_task,
on_thread_known=on_thread_known,
)
except Exception as e: except Exception as e:
error = e error = e
finally: finally:
@@ -400,67 +472,60 @@ async def handle_message(
): ):
running_task.done.set() running_task.done.set()
running_tasks.pop(progress_id, None) running_tasks.pop(progress_id, None)
if not cancelled and error is None: if not outcome.cancelled and error is None:
# Give pending progress edits a chance to flush if they're ready.
await anyio.sleep(0) await anyio.sleep(0)
edits_scope.cancel() edits_scope.cancel()
elapsed = clock() - started_at
if error is not None: if error is not None:
elapsed = clock() - started_at sync_resume_token(progress_renderer, outcome.resume)
if resume_token_value is None:
resume_token_value = progress_renderer.resume_token
progress_renderer.resume_token = resume_token_value
err_body = str(error) err_body = str(error)
final_md = progress_renderer.render_final(elapsed, err_body, status="error") final_md = progress_renderer.render_final(elapsed, err_body, status="error")
logger.debug("[error] markdown: %s", final_md) logger.debug("[error] markdown: %s", final_md)
final_msg, edited = await _send_or_edit_markdown( await send_result_message(
cfg.bot, cfg,
chat_id=chat_id, chat_id=chat_id,
text=final_md, user_msg_id=user_msg_id,
edit_message_id=progress_id, progress_id=progress_id,
reply_to_message_id=user_msg_id, markdown=final_md,
disable_notification=True, disable_notification=True,
limit=TELEGRAM_MARKDOWN_LIMIT, edit_message_id=progress_id,
is_resume_line=is_resume_line, is_resume_line=is_resume_line,
delete_tag="error",
) )
if final_msg is None:
return
if progress_id is not None and not edited:
logger.debug("[error] delete progress message_id=%s", progress_id)
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
return return
elapsed = clock() - started_at if outcome.cancelled:
if cancelled: resume = sync_resume_token(progress_renderer, outcome.resume)
if resume_token_value is None:
resume_token_value = progress_renderer.resume_token
logger.info( logger.info(
"[handle] cancelled resume=%s elapsed=%.1fs", "[handle] cancelled resume=%s elapsed=%.1fs",
resume_token_value.value if resume_token_value else None, resume.value if resume else None,
elapsed, elapsed,
) )
progress_renderer.resume_token = resume_token_value
final_md = progress_renderer.render_progress(elapsed, label="`cancelled`") final_md = progress_renderer.render_progress(elapsed, label="`cancelled`")
final_msg, edited = await _send_or_edit_markdown( await send_result_message(
cfg.bot, cfg,
chat_id=chat_id, chat_id=chat_id,
text=final_md, user_msg_id=user_msg_id,
edit_message_id=progress_id, progress_id=progress_id,
reply_to_message_id=user_msg_id, markdown=final_md,
disable_notification=True, disable_notification=True,
limit=TELEGRAM_MARKDOWN_LIMIT, edit_message_id=progress_id,
is_resume_line=is_resume_line, is_resume_line=is_resume_line,
delete_tag="cancel",
) )
if final_msg is None:
return
if progress_id is not None and not edited:
logger.debug("[cancel] delete progress message_id=%s", progress_id)
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
return return
if answer is None: if outcome.completed is None:
raise RuntimeError("runner finished without a completed event") raise RuntimeError("runner finished without a completed event")
final_answer = answer completed = outcome.completed
run_ok = completed.ok
run_error = completed.error
final_answer = completed.answer
if run_ok is False and run_error: if run_ok is False and run_error:
if final_answer.strip(): if final_answer.strip():
final_answer = f"{final_answer}\n\n{run_error}" final_answer = f"{final_answer}\n\n{run_error}"
@@ -470,11 +535,10 @@ async def handle_message(
status = ( status = (
"error" if run_ok is False else ("done" if final_answer.strip() else "error") "error" if run_ok is False else ("done" if final_answer.strip() else "error")
) )
if resume_token_value is None: sync_resume_token(progress_renderer, completed.resume or outcome.resume)
resume_token_value = progress_renderer.resume_token
progress_renderer.resume_token = resume_token_value
final_md = progress_renderer.render_final(elapsed, final_answer, status=status) final_md = progress_renderer.render_final(elapsed, final_answer, status=status)
logger.debug("[final] markdown: %s", final_md) logger.debug("[final] markdown: %s", final_md)
final_rendered, final_entities = prepare_telegram( final_rendered, final_entities = prepare_telegram(
final_md, limit=TELEGRAM_MARKDOWN_LIMIT, is_resume_line=is_resume_line final_md, limit=TELEGRAM_MARKDOWN_LIMIT, is_resume_line=is_resume_line
) )
@@ -496,22 +560,18 @@ async def handle_message(
final_entities, final_entities,
) )
final_msg, edited = await _send_or_edit_markdown( await send_result_message(
cfg.bot, cfg,
chat_id=chat_id, chat_id=chat_id,
text=final_md, user_msg_id=user_msg_id,
edit_message_id=edit_message_id, progress_id=progress_id,
reply_to_message_id=user_msg_id, markdown=final_md,
disable_notification=False, disable_notification=False,
limit=TELEGRAM_MARKDOWN_LIMIT, edit_message_id=edit_message_id,
is_resume_line=is_resume_line, is_resume_line=is_resume_line,
prepared=(final_rendered, final_entities), prepared=(final_rendered, final_entities),
delete_tag="final",
) )
if final_msg is None:
return
if progress_id is not None and (edit_message_id is None or not edited):
logger.debug("[final] delete progress message_id=%s", progress_id)
await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id)
async def poll_updates(cfg: BridgeConfig) -> AsyncIterator[dict[str, Any]]: async def poll_updates(cfg: BridgeConfig) -> AsyncIterator[dict[str, Any]]:
@@ -621,7 +681,7 @@ async def _send_with_resume(
await enqueue(chat_id, user_msg_id, text, resume) await enqueue(chat_id, user_msg_id, text, resume)
async def _run_main_loop( async def run_main_loop(
cfg: BridgeConfig, cfg: BridgeConfig,
poller: Callable[[BridgeConfig], AsyncIterator[dict[str, Any]]] = poll_updates, poller: Callable[[BridgeConfig], AsyncIterator[dict[str, Any]]] = poll_updates,
) -> None: ) -> None:
+2 -2
View File
@@ -8,7 +8,7 @@ import typer
from . import __version__ from . import __version__
from .backends import EngineBackend from .backends import EngineBackend
from .bridge import BridgeConfig, _run_main_loop from .bridge import BridgeConfig, run_main_loop
from .config import ConfigError, load_telegram_config from .config import ConfigError, load_telegram_config
from .engines import get_backend, get_engine_config, list_backends from .engines import get_backend, get_engine_config, list_backends
from .logging import setup_logging from .logging import setup_logging
@@ -90,7 +90,7 @@ def _run_engine(*, engine: str, final_notify: bool, debug: bool) -> None:
except ConfigError as e: except ConfigError as e:
typer.echo(str(e), err=True) typer.echo(str(e), err=True)
raise typer.Exit(code=1) raise typer.Exit(code=1)
anyio.run(_run_main_loop, cfg) anyio.run(run_main_loop, cfg)
app = typer.Typer( app = typer.Typer(
+2 -2
View File
@@ -774,7 +774,7 @@ async def test_send_with_resume_reports_when_missing() -> None:
@pytest.mark.anyio @pytest.mark.anyio
async def test_run_main_loop_routes_reply_to_running_resume() -> None: async def test_run_main_loop_routes_reply_to_running_resume() -> None:
from takopi.bridge import BridgeConfig, _run_main_loop from takopi.bridge import BridgeConfig, run_main_loop
progress_ready = anyio.Event() progress_ready = anyio.Event()
stop_polling = anyio.Event() stop_polling = anyio.Event()
@@ -843,7 +843,7 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None:
await stop_polling.wait() await stop_polling.wait()
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
tg.start_soon(_run_main_loop, cfg, poller) tg.start_soon(run_main_loop, cfg, poller)
try: try:
with anyio.fail_after(2): with anyio.fail_after(2):
await reply_ready.wait() await reply_ready.wait()