diff --git a/docs/developing.md b/docs/developing.md index 0455631..fd92456 100644 --- a/docs/developing.md +++ b/docs/developing.md @@ -36,7 +36,7 @@ The orchestrator module containing: |-----------|---------| | `BridgeConfig` | Frozen dataclass holding runtime config | | `poll_updates()` | Async generator that drains backlog, long-polls updates, filters messages | -| `_run_main_loop()` | TaskGroup-based main loop that spawns per-message handlers | +| `run_main_loop()` | TaskGroup-based main loop that spawns per-message handlers | | `handle_message()` | Per-message handler with progress updates and final render | | `ProgressEdits` | Throttled progress edit worker | | `_handle_cancel()` | `/cancel` routing | @@ -162,7 +162,7 @@ Telegram Update ↓ poll_updates() drains backlog, long-polls, filters chat_id == from_id == cfg.chat_id ↓ -_run_main_loop() spawns tasks in TaskGroup +run_main_loop() spawns tasks in TaskGroup ↓ handle_message() spawned as task ↓ diff --git a/src/takopi/bridge.py b/src/takopi/bridge.py index d61c3ae..f5d3d9f 100644 --- a/src/takopi/bridge.py +++ b/src/takopi/bridge.py @@ -232,6 +232,147 @@ async def _drain_backlog(cfg: BridgeConfig, offset: int | None) -> int | None: drained += len(updates) +@dataclass(frozen=True, slots=True) +class ProgressMessageState: + message_id: int | None + last_edit_at: float + last_rendered: str | None + + +async def send_initial_progress( + cfg: BridgeConfig, + *, + chat_id: int, + user_msg_id: int, + label: str, + renderer: ExecProgressRenderer, + is_resume_line: Callable[[str], bool], + clock: Callable[[], float], + limit: int, +) -> ProgressMessageState: + progress_id: int | None = None + last_edit_at = 0.0 + last_rendered: str | None = None + + initial_md = renderer.render_progress(0.0, label=label) + initial_rendered, initial_entities = prepare_telegram( + initial_md, limit=limit, is_resume_line=is_resume_line + ) + logger.debug( + "[progress] send reply_to=%s md=%s rendered=%s entities=%s", + user_msg_id, + initial_md, + initial_rendered, + initial_entities, + ) + progress_msg = await cfg.bot.send_message( + chat_id=chat_id, + text=initial_rendered, + entities=initial_entities, + reply_to_message_id=user_msg_id, + disable_notification=True, + ) + if progress_msg is not None: + progress_id = int(progress_msg["message_id"]) + last_edit_at = clock() + last_rendered = initial_rendered + logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id) + + return ProgressMessageState( + message_id=progress_id, + last_edit_at=last_edit_at, + last_rendered=last_rendered, + ) + + +@dataclass(slots=True) +class RunOutcome: + cancelled: bool = False + completed: CompletedEvent | None = None + resume: ResumeToken | None = None + + +async def run_runner_with_cancel( + runner: Runner, + *, + prompt: str, + resume_token: ResumeToken | None, + edits: ProgressEdits, + running_task: RunningTask | None, + on_thread_known: Callable[[ResumeToken, anyio.Event], Awaitable[None]] | None, +) -> RunOutcome: + outcome = RunOutcome() + async with anyio.create_task_group() as tg: + + async def run_runner() -> None: + try: + async for evt in runner.run(prompt, resume_token): + _log_runner_event(evt) + if isinstance(evt, StartedEvent): + outcome.resume = evt.resume + if running_task is not None and running_task.resume is None: + running_task.resume = evt.resume + running_task.resume_ready.set() + if on_thread_known is not None: + await on_thread_known(evt.resume, running_task.done) + elif isinstance(evt, CompletedEvent): + outcome.resume = evt.resume or outcome.resume + outcome.completed = evt + await edits.on_event(evt) + finally: + tg.cancel_scope.cancel() + + async def wait_cancel(task: RunningTask) -> None: + await task.cancel_requested.wait() + outcome.cancelled = True + tg.cancel_scope.cancel() + + tg.start_soon(run_runner) + if running_task is not None: + tg.start_soon(wait_cancel, running_task) + + return outcome + + +def sync_resume_token( + renderer: ExecProgressRenderer, resume: ResumeToken | None +) -> ResumeToken | None: + resume = resume or renderer.resume_token + renderer.resume_token = resume + return resume + + +async def send_result_message( + cfg: BridgeConfig, + *, + chat_id: int, + user_msg_id: int, + progress_id: int | None, + markdown: str, + disable_notification: bool, + edit_message_id: int | None, + is_resume_line: Callable[[str], bool], + prepared: tuple[str, list[dict[str, Any]] | None] | None = None, + delete_tag: str = "final", +) -> None: + final_msg, edited = await _send_or_edit_markdown( + cfg.bot, + chat_id=chat_id, + text=markdown, + edit_message_id=edit_message_id, + reply_to_message_id=user_msg_id, + disable_notification=disable_notification, + limit=TELEGRAM_MARKDOWN_LIMIT, + is_resume_line=is_resume_line, + prepared=prepared, + ) + if final_msg is None: + return + if progress_id is not None and (edit_message_id is None or not edited): + logger.debug("[%s] delete progress message_id=%s", delete_tag, progress_id) + await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id) + + async def handle_message( cfg: BridgeConfig, *, @@ -262,35 +403,17 @@ async def handle_message( max_actions=5, resume_formatter=runner.format_resume ) - progress_id: int | None = None - last_edit_at = 0.0 - last_rendered: str | None = None - - initial_md = progress_renderer.render_progress( - 0.0, label=f"working ({runner.engine})" - ) - initial_rendered, initial_entities = prepare_telegram( - initial_md, limit=TELEGRAM_MARKDOWN_LIMIT, is_resume_line=is_resume_line - ) - logger.debug( - "[progress] send reply_to=%s md=%s rendered=%s entities=%s", - user_msg_id, - initial_md, - initial_rendered, - initial_entities, - ) - progress_msg = await cfg.bot.send_message( + progress_state = await send_initial_progress( + cfg, chat_id=chat_id, - text=initial_rendered, - entities=initial_entities, - reply_to_message_id=user_msg_id, - disable_notification=True, + user_msg_id=user_msg_id, + label=f"working ({runner.engine})", + renderer=progress_renderer, + is_resume_line=is_resume_line, + clock=clock, + limit=TELEGRAM_MARKDOWN_LIMIT, ) - if progress_msg is not None: - progress_id = int(progress_msg["message_id"]) - last_edit_at = clock() - last_rendered = initial_rendered - logger.debug("[progress] sent chat_id=%s message_id=%s", chat_id, progress_id) + progress_id = progress_state.message_id edits = ProgressEdits( bot=cfg.bot, @@ -302,23 +425,17 @@ async def handle_message( clock=clock, sleep=sleep, limit=TELEGRAM_MARKDOWN_LIMIT, - last_edit_at=last_edit_at, - last_rendered=last_rendered, + last_edit_at=progress_state.last_edit_at, + last_rendered=progress_state.last_rendered, is_resume_line=is_resume_line, ) - cancel_exc_type = anyio.get_cancelled_exc_class() - cancelled = False - error: Exception | None = None - resume_token_value: ResumeToken | None = None - answer: str | None = None - run_ok: bool | None = None - run_error: str | None = None running_task: RunningTask | None = None if running_tasks is not None and progress_id is not None: running_task = RunningTask() running_tasks[progress_id] = running_task + cancel_exc_type = anyio.get_cancelled_exc_class() edits_scope = anyio.CancelScope() async def run_edits() -> None: @@ -329,67 +446,22 @@ async def handle_message( # Edits are best-effort; cancellation should not bubble into the task group. return + outcome = RunOutcome() + error: Exception | None = None + async with anyio.create_task_group() as tg: if progress_id is not None: tg.start_soon(run_edits) - async def run_exec() -> CompletedEvent | None: - nonlocal cancelled - cancel_flag = False - completed: CompletedEvent | None = None - - async with anyio.create_task_group() as exec_tg: - - async def run_runner() -> None: - nonlocal resume_token_value, completed, answer, run_ok, run_error - try: - async for evt in runner.run(runner_text, resume_token): - _log_runner_event(evt) - if isinstance(evt, StartedEvent): - resume_token_value = evt.resume - if ( - running_task is not None - and running_task.resume is None - ): - running_task.resume = resume_token_value - running_task.resume_ready.set() - if on_thread_known is not None: - await on_thread_known( - resume_token_value, running_task.done - ) - elif isinstance(evt, CompletedEvent): - resume_token_value = evt.resume or resume_token_value - answer = evt.answer - run_ok = evt.ok - run_error = evt.error - completed = evt - await edits.on_event(evt) - finally: - exec_tg.cancel_scope.cancel() - - async def wait_cancel() -> None: - nonlocal cancel_flag - if running_task is None: - return - await running_task.cancel_requested.wait() - cancel_flag = True - exec_tg.cancel_scope.cancel() - - exec_tg.start_soon(run_runner) - if running_task is not None: - exec_tg.start_soon(wait_cancel) - - if cancel_flag: - cancelled = True - return completed - try: - completed = await run_exec() - if completed is not None: - resume_token_value = completed.resume or resume_token_value - answer = completed.answer - run_ok = completed.ok - run_error = completed.error + outcome = await run_runner_with_cancel( + runner, + prompt=runner_text, + resume_token=resume_token, + edits=edits, + running_task=running_task, + on_thread_known=on_thread_known, + ) except Exception as e: error = e finally: @@ -400,67 +472,60 @@ async def handle_message( ): running_task.done.set() running_tasks.pop(progress_id, None) - if not cancelled and error is None: + if not outcome.cancelled and error is None: + # Give pending progress edits a chance to flush if they're ready. await anyio.sleep(0) edits_scope.cancel() + elapsed = clock() - started_at + if error is not None: - elapsed = clock() - started_at - if resume_token_value is None: - resume_token_value = progress_renderer.resume_token - progress_renderer.resume_token = resume_token_value + sync_resume_token(progress_renderer, outcome.resume) err_body = str(error) final_md = progress_renderer.render_final(elapsed, err_body, status="error") logger.debug("[error] markdown: %s", final_md) - final_msg, edited = await _send_or_edit_markdown( - cfg.bot, + await send_result_message( + cfg, chat_id=chat_id, - text=final_md, - edit_message_id=progress_id, - reply_to_message_id=user_msg_id, + user_msg_id=user_msg_id, + progress_id=progress_id, + markdown=final_md, disable_notification=True, - limit=TELEGRAM_MARKDOWN_LIMIT, + edit_message_id=progress_id, is_resume_line=is_resume_line, + delete_tag="error", ) - if final_msg is None: - return - if progress_id is not None and not edited: - logger.debug("[error] delete progress message_id=%s", progress_id) - await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id) return - elapsed = clock() - started_at - if cancelled: - if resume_token_value is None: - resume_token_value = progress_renderer.resume_token + if outcome.cancelled: + resume = sync_resume_token(progress_renderer, outcome.resume) logger.info( "[handle] cancelled resume=%s elapsed=%.1fs", - resume_token_value.value if resume_token_value else None, + resume.value if resume else None, elapsed, ) - progress_renderer.resume_token = resume_token_value final_md = progress_renderer.render_progress(elapsed, label="`cancelled`") - final_msg, edited = await _send_or_edit_markdown( - cfg.bot, + await send_result_message( + cfg, chat_id=chat_id, - text=final_md, - edit_message_id=progress_id, - reply_to_message_id=user_msg_id, + user_msg_id=user_msg_id, + progress_id=progress_id, + markdown=final_md, disable_notification=True, - limit=TELEGRAM_MARKDOWN_LIMIT, + edit_message_id=progress_id, is_resume_line=is_resume_line, + delete_tag="cancel", ) - if final_msg is None: - return - if progress_id is not None and not edited: - logger.debug("[cancel] delete progress message_id=%s", progress_id) - await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id) return - if answer is None: + if outcome.completed is None: raise RuntimeError("runner finished without a completed event") - final_answer = answer + completed = outcome.completed + run_ok = completed.ok + run_error = completed.error + + final_answer = completed.answer if run_ok is False and run_error: if final_answer.strip(): final_answer = f"{final_answer}\n\n{run_error}" @@ -470,11 +535,10 @@ async def handle_message( status = ( "error" if run_ok is False else ("done" if final_answer.strip() else "error") ) - if resume_token_value is None: - resume_token_value = progress_renderer.resume_token - progress_renderer.resume_token = resume_token_value + sync_resume_token(progress_renderer, completed.resume or outcome.resume) final_md = progress_renderer.render_final(elapsed, final_answer, status=status) logger.debug("[final] markdown: %s", final_md) + final_rendered, final_entities = prepare_telegram( final_md, limit=TELEGRAM_MARKDOWN_LIMIT, is_resume_line=is_resume_line ) @@ -496,22 +560,18 @@ async def handle_message( final_entities, ) - final_msg, edited = await _send_or_edit_markdown( - cfg.bot, + await send_result_message( + cfg, chat_id=chat_id, - text=final_md, - edit_message_id=edit_message_id, - reply_to_message_id=user_msg_id, + user_msg_id=user_msg_id, + progress_id=progress_id, + markdown=final_md, disable_notification=False, - limit=TELEGRAM_MARKDOWN_LIMIT, + edit_message_id=edit_message_id, is_resume_line=is_resume_line, prepared=(final_rendered, final_entities), + delete_tag="final", ) - if final_msg is None: - return - if progress_id is not None and (edit_message_id is None or not edited): - logger.debug("[final] delete progress message_id=%s", progress_id) - await cfg.bot.delete_message(chat_id=chat_id, message_id=progress_id) async def poll_updates(cfg: BridgeConfig) -> AsyncIterator[dict[str, Any]]: @@ -621,7 +681,7 @@ async def _send_with_resume( await enqueue(chat_id, user_msg_id, text, resume) -async def _run_main_loop( +async def run_main_loop( cfg: BridgeConfig, poller: Callable[[BridgeConfig], AsyncIterator[dict[str, Any]]] = poll_updates, ) -> None: diff --git a/src/takopi/cli.py b/src/takopi/cli.py index 3a63f53..48511b7 100644 --- a/src/takopi/cli.py +++ b/src/takopi/cli.py @@ -8,7 +8,7 @@ import typer from . import __version__ from .backends import EngineBackend -from .bridge import BridgeConfig, _run_main_loop +from .bridge import BridgeConfig, run_main_loop from .config import ConfigError, load_telegram_config from .engines import get_backend, get_engine_config, list_backends from .logging import setup_logging @@ -90,7 +90,7 @@ def _run_engine(*, engine: str, final_notify: bool, debug: bool) -> None: except ConfigError as e: typer.echo(str(e), err=True) raise typer.Exit(code=1) - anyio.run(_run_main_loop, cfg) + anyio.run(run_main_loop, cfg) app = typer.Typer( diff --git a/tests/test_exec_bridge.py b/tests/test_exec_bridge.py index b72af83..4546690 100644 --- a/tests/test_exec_bridge.py +++ b/tests/test_exec_bridge.py @@ -774,7 +774,7 @@ async def test_send_with_resume_reports_when_missing() -> None: @pytest.mark.anyio async def test_run_main_loop_routes_reply_to_running_resume() -> None: - from takopi.bridge import BridgeConfig, _run_main_loop + from takopi.bridge import BridgeConfig, run_main_loop progress_ready = anyio.Event() stop_polling = anyio.Event() @@ -843,7 +843,7 @@ async def test_run_main_loop_routes_reply_to_running_resume() -> None: await stop_polling.wait() async with anyio.create_task_group() as tg: - tg.start_soon(_run_main_loop, cfg, poller) + tg.start_soon(run_main_loop, cfg, poller) try: with anyio.fail_after(2): await reply_ready.wait()