From 51cdb72d0b519a3df2b60d76477b783a42d28a52 Mon Sep 17 00:00:00 2001 From: banteg <4562643+banteg@users.noreply.github.com> Date: Fri, 2 Jan 2026 14:22:59 +0400 Subject: [PATCH] refactor: runners and scheduler, fix path handling (#23) --- src/takopi/bridge.py | 95 ++----------- src/takopi/onboarding.py | 23 +-- src/takopi/render.py | 23 ++- src/takopi/runner.py | 5 +- src/takopi/runners/claude.py | 261 +++++++++++++++-------------------- src/takopi/runners/codex.py | 257 +++++++++++++++------------------- src/takopi/scheduler.py | 103 ++++++++++++++ src/takopi/utils/paths.py | 10 +- tests/test_exec_render.py | 6 +- tests/test_paths.py | 16 ++- uv.lock | 2 +- 11 files changed, 377 insertions(+), 424 deletions(-) create mode 100644 src/takopi/scheduler.py diff --git a/src/takopi/bridge.py b/src/takopi/bridge.py index ac6cdc1..0e50699 100644 --- a/src/takopi/bridge.py +++ b/src/takopi/bridge.py @@ -4,7 +4,6 @@ from __future__ import annotations import logging import time -from collections import deque from collections.abc import AsyncIterator, Awaitable, Callable from dataclasses import dataclass, field from typing import Any @@ -21,6 +20,7 @@ from .render import ( ) from .router import AutoRouter, RunnerUnavailableError from .runner import Runner +from .scheduler import ThreadJob, ThreadScheduler from .telegram import BotClient @@ -801,35 +801,6 @@ async def run_main_loop( try: await _set_command_menu(cfg) async with anyio.create_task_group() as tg: - scheduler_lock = anyio.Lock() - - @dataclass(frozen=True, slots=True) - class ThreadJob: - chat_id: int - user_msg_id: int - text: str - resume_token: ResumeToken - - pending_by_thread: dict[str, deque[ThreadJob]] = {} - active_threads: set[str] = set() - busy_until: dict[str, anyio.Event] = {} - - def thread_key(token: ResumeToken) -> str: - return f"{token.engine}:{token.value}" - - async def clear_busy(key: str, done: anyio.Event) -> None: - await done.wait() - async with scheduler_lock: - if busy_until.get(key) is done: - busy_until.pop(key, None) - - async def note_thread_known(token: ResumeToken, done: anyio.Event) -> None: - key = thread_key(token) - async with scheduler_lock: - current = busy_until.get(key) - if current is None or current.is_set(): - busy_until[key] = done - tg.start_soon(clear_busy, key, done) async def run_job( chat_id: int, @@ -882,55 +853,15 @@ async def run_main_loop( except Exception: logger.exception("[handle] worker failed") - async def thread_worker(key: str) -> None: - try: - while True: - async with scheduler_lock: - done = busy_until.get(key) - queue = pending_by_thread.get(key) - if not queue: - pending_by_thread.pop(key, None) - active_threads.discard(key) - return - job = queue.popleft() + async def run_thread_job(job: ThreadJob) -> None: + await run_job( + job.chat_id, + job.user_msg_id, + job.text, + job.resume_token, + ) - if done is not None and not done.is_set(): - await done.wait() - - await run_job( - job.chat_id, - job.user_msg_id, - job.text, - job.resume_token, - ) - finally: - async with scheduler_lock: - active_threads.discard(key) - - async def enqueue( - chat_id: int, - user_msg_id: int, - text: str, - resume_token: ResumeToken, - ) -> None: - key = thread_key(resume_token) - async with scheduler_lock: - queue = pending_by_thread.get(key) - if queue is None: - queue = deque() - pending_by_thread[key] = queue - queue.append( - ThreadJob( - chat_id=chat_id, - user_msg_id=user_msg_id, - text=text, - resume_token=resume_token, - ) - ) - if key in active_threads: - return - active_threads.add(key) - tg.start_soon(thread_worker, key) + scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job) async for msg in poller(cfg): text = msg["text"] @@ -953,7 +884,7 @@ async def run_main_loop( tg.start_soon( _send_with_resume, cfg.bot, - enqueue, + scheduler.enqueue_resume, running_task, msg["chat"]["id"], user_msg_id, @@ -968,10 +899,12 @@ async def run_main_loop( user_msg_id, text, None, - note_thread_known, + scheduler.note_thread_known, engine_override, ) else: - await enqueue(msg["chat"]["id"], user_msg_id, text, resume_token) + await scheduler.enqueue_resume( + msg["chat"]["id"], user_msg_id, text, resume_token + ) finally: await cfg.bot.close() diff --git a/src/takopi/onboarding.py b/src/takopi/onboarding.py index cb7720c..04740ea 100644 --- a/src/takopi/onboarding.py +++ b/src/takopi/onboarding.py @@ -3,7 +3,7 @@ from __future__ import annotations import shutil from dataclasses import dataclass from pathlib import Path -from typing import Sequence + from rich.console import Console from rich.panel import Panel @@ -110,24 +110,3 @@ def render_setup_guide(result: SetupResult) -> None: expand=False, ) console.print(panel) - - -def render_engine_choice(backends: Sequence[EngineBackend]) -> None: - console = Console(stderr=True) - parts: list[str] = [] - parts.append("[bold]available engines:[/]") - parts.append("") - for idx, backend in enumerate(backends, start=1): - parts.append(f"[bold yellow]{idx}.[/] [dim]$[/] takopi {backend.id}") - parts.append(f" [dim]use {backend.id}[/]") - parts.append("") - - panel = Panel( - "\n".join(parts).rstrip(), - title="[bold]welcome to takopi![/]", - subtitle=f"{_OCTOPUS} choose engine", - border_style="yellow", - padding=(1, 2), - expand=False, - ) - console.print(panel) diff --git a/src/takopi/render.py b/src/takopi/render.py index cce62cf..42a4fcb 100644 --- a/src/takopi/render.py +++ b/src/takopi/render.py @@ -23,6 +23,9 @@ HARD_BREAK = " \n" MAX_PROGRESS_CMD_LEN = 300 MAX_FILE_CHANGES_INLINE = 3 +_MD_RENDERER = MarkdownIt("commonmark", {"html": False}) +_BULLET_RE = re.compile(r"(?m)^(\s*)•") + @dataclass(frozen=True) class MarkdownParts: @@ -38,11 +41,10 @@ def assemble_markdown_parts(parts: MarkdownParts) -> str: def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]: - md_renderer = MarkdownIt("commonmark", {"html": False}) - html = md_renderer.render(md or "") + html = _MD_RENDERER.render(md or "") rendered = transform_html(html) - text = re.sub(r"(?m)^(\s*)•", r"\1-", rendered.text) + text = _BULLET_RE.sub(r"\1-", rendered.text) entities = [dict(e) for e in rendered.entities] return text, entities @@ -218,7 +220,6 @@ class ExecProgressRenderer: max_actions: int = 5, command_width: int | None = MAX_PROGRESS_CMD_LEN, resume_formatter: Callable[[ResumeToken], str] | None = None, - show_title: bool = False, ) -> None: self.max_actions = max(0, int(max_actions)) self.command_width = command_width @@ -226,16 +227,13 @@ class ExecProgressRenderer: self.action_count = 0 self.seen_action_ids: set[str] = set() self.resume_token: ResumeToken | None = None - self.session_title: str | None = None self._resume_formatter = resume_formatter - self.show_title = show_title self.engine = engine def note_event(self, event: TakopiEvent) -> bool: match event: - case StartedEvent(resume=resume, title=title): + case StartedEvent(resume=resume): self.resume_token = resume - self.session_title = title return True case ActionEvent(action=action, phase=phase, ok=ok): if action.kind == "turn": @@ -286,7 +284,7 @@ class ExecProgressRenderer: header = format_header( elapsed_s, step, - label=self.label_with_title(label), + label=label, engine=self.engine, ) body = self.assemble_body([line.text for line in self.lines]) @@ -299,18 +297,13 @@ class ExecProgressRenderer: header = format_header( elapsed_s, step, - label=self.label_with_title(status), + label=status, engine=self.engine, ) answer = (answer or "").strip() body = answer if answer else None return MarkdownParts(header=header, body=body, footer=self.render_footer()) - def label_with_title(self, label: str) -> str: - if self.show_title and self.session_title: - return f"{label} ({self.session_title})" - return label - def render_footer(self) -> str | None: if not self.resume_token or self._resume_formatter is None: return None diff --git a/src/takopi/runner.py b/src/takopi/runner.py index ed8a8cd..3073983 100644 --- a/src/takopi/runner.py +++ b/src/takopi/runner.py @@ -91,11 +91,10 @@ class SessionLockMixin: class BaseRunner(SessionLockMixin): engine: EngineId - async def run( + def run( self, prompt: str, resume: ResumeToken | None ) -> AsyncIterator[TakopiEvent]: - async for evt in self.run_locked(prompt, resume): - yield evt + return self.run_locked(prompt, resume) async def run_locked( self, prompt: str, resume: ResumeToken | None diff --git a/src/takopi/runners/claude.py b/src/takopi/runners/claude.py index 2307278..d1756c5 100644 --- a/src/takopi/runners/claude.py +++ b/src/takopi/runners/claude.py @@ -142,14 +142,10 @@ def _tool_action( *, message_id: str | None, parent_tool_use_id: str | None, -) -> Action | None: - tool_id = content.get("id") - if not isinstance(tool_id, str) or not tool_id: - return None +) -> Action: + tool_id = content["id"] tool_name = str(content.get("name") or "tool") - tool_input = content.get("input") - if not isinstance(tool_input, dict): - tool_input = {} + tool_input = content["input"] kind, title = _tool_kind_and_title(tool_name, tool_input) @@ -247,153 +243,126 @@ def translate_claude_event( title: str, state: ClaudeStreamState, ) -> list[TakopiEvent]: - etype = event.get("type") - if etype == "system" and event.get("subtype") == "init": - session_id = event.get("session_id") - if not session_id: - return [] - model = event.get("model") - event_title = str(model) if model else title - meta: dict[str, Any] = {} - for key in ("cwd", "tools", "permissionMode", "output_style", "apiKeySource"): - if key in event: - meta[key] = event.get(key) - if "mcp_servers" in event: - meta["mcp_servers"] = event.get("mcp_servers") + etype = event["type"] + match etype: + case "system" if event.get("subtype") == "init": + session_id = event["session_id"] + model = event.get("model") + event_title = str(model) if model else title + meta: dict[str, Any] = {} + for key in ( + "cwd", + "tools", + "permissionMode", + "output_style", + "apiKeySource", + ): + if key in event: + meta[key] = event.get(key) + if "mcp_servers" in event: + meta["mcp_servers"] = event.get("mcp_servers") - return [ - StartedEvent( - engine=ENGINE, - resume=ResumeToken(engine=ENGINE, value=str(session_id)), - title=event_title, - meta=meta or None, - ) - ] - - if etype == "assistant": - message = event.get("message") - if not isinstance(message, dict): - return [] - message_id = message.get("id") - if not isinstance(message_id, str): - message_id = None - parent_tool_use_id = event.get("parent_tool_use_id") - if not isinstance(parent_tool_use_id, str): - parent_tool_use_id = None - content_blocks = message.get("content") - if not isinstance(content_blocks, list): - return [] - out: list[TakopiEvent] = [] - for content in content_blocks: - if not isinstance(content, dict): - continue - ctype = content.get("type") - if ctype == "tool_use": - action = _tool_action( - content, - message_id=message_id, - parent_tool_use_id=parent_tool_use_id, + return [ + StartedEvent( + engine=ENGINE, + resume=ResumeToken(engine=ENGINE, value=str(session_id)), + title=event_title, + meta=meta or None, ) - if action is None: + ] + case "assistant": + message = event["message"] + message_id = message.get("id") + parent_tool_use_id = event.get("parent_tool_use_id") + content_blocks = message["content"] + out: list[TakopiEvent] = [] + for content in content_blocks: + match content["type"]: + case "tool_use": + action = _tool_action( + content, + message_id=message_id, + parent_tool_use_id=parent_tool_use_id, + ) + state.pending_actions[action.id] = action + out.append(_action_event(phase="started", action=action)) + case "text": + text = content["text"] + if text: + state.last_assistant_text = text + case _: + continue + return out + case "user": + message = event["message"] + message_id = message.get("id") + content_blocks = message["content"] + out: list[TakopiEvent] = [] + for content in content_blocks: + if content["type"] != "tool_result": continue - state.pending_actions[action.id] = action - out.append(_action_event(phase="started", action=action)) - elif ctype == "text": - text = content.get("text") - if isinstance(text, str) and text: - state.last_assistant_text = text - return out - - if etype == "user": - message = event.get("message") - if not isinstance(message, dict): - return [] - message_id = message.get("id") - if not isinstance(message_id, str): - message_id = None - content_blocks = message.get("content") - if not isinstance(content_blocks, list): - return [] - out: list[TakopiEvent] = [] - for content in content_blocks: - if not isinstance(content, dict): - continue - if content.get("type") != "tool_result": - continue - tool_use_id = content.get("tool_use_id") - if not isinstance(tool_use_id, str) or not tool_use_id: - continue - action = state.pending_actions.pop(tool_use_id, None) - if action is None: - action = Action( - id=tool_use_id, - kind="tool", - title="tool result", - detail={}, + tool_use_id = content["tool_use_id"] + action = state.pending_actions.pop(tool_use_id, None) + if action is None: + action = Action( + id=tool_use_id, + kind="tool", + title="tool result", + detail={}, + ) + out.append( + _tool_result_event(content, action=action, message_id=message_id) + ) + return out + case "result": + out: list[TakopiEvent] = [] + for idx, denial in enumerate(event.get("permission_denials", [])): + tool_name = denial.get("tool_name") + denial_title = "permission denied" + if tool_name: + denial_title = f"permission denied: {tool_name}" + tool_use_id = denial.get("tool_use_id") + action_id = ( + f"claude.permission.{tool_use_id}" + if tool_use_id + else f"claude.permission.{idx}" + ) + out.append( + _action_event( + phase="completed", + action=Action( + id=action_id, + kind="warning", + title=denial_title, + detail=denial, + ), + ok=False, + level="warning", + ) ) - out.append( - _tool_result_event(content, action=action, message_id=message_id) - ) - return out - if etype == "result": - out: list[TakopiEvent] = [] - for idx, denial in enumerate(event.get("permission_denials") or []): - if not isinstance(denial, dict): - continue - tool_name = denial.get("tool_name") - denial_title = "permission denied" - if isinstance(tool_name, str) and tool_name: - denial_title = f"permission denied: {tool_name}" - tool_use_id = denial.get("tool_use_id") - action_id = ( - f"claude.permission.{tool_use_id}" - if isinstance(tool_use_id, str) and tool_use_id - else f"claude.permission.{idx}" - ) + ok = not event.get("is_error", False) + result_text = event["result"] + if ok and not result_text and state.last_assistant_text: + result_text = state.last_assistant_text + + resume = ResumeToken(engine=ENGINE, value=str(event["session_id"])) + error = None if ok else _extract_error(event) + usage = _usage_payload(event) + out.append( - _action_event( - phase="completed", - action=Action( - id=action_id, - kind="warning", - title=denial_title, - detail=denial, - ), - ok=False, - level="warning", + CompletedEvent( + engine=ENGINE, + ok=ok, + answer=result_text, + resume=resume, + error=error, + usage=usage or None, ) ) - - ok = not event.get("is_error", False) - result_text = event.get("result") - if not isinstance(result_text, str): - result_text = "" - if ok and not result_text and state.last_assistant_text: - result_text = state.last_assistant_text - - resume_value = event.get("session_id") - resume = ( - ResumeToken(engine=ENGINE, value=str(resume_value)) - if resume_value - else None - ) - error = None if ok else _extract_error(event) - usage = _usage_payload(event) - - out.append( - CompletedEvent( - engine=ENGINE, - ok=ok, - answer=result_text, - resume=resume, - error=error, - usage=usage or None, - ) - ) - return out - - return [] + return out + case _: + return [] @dataclass diff --git a/src/takopi/runners/codex.py b/src/takopi/runners/codex.py index bb37719..acaba57 100644 --- a/src/takopi/runners/codex.py +++ b/src/takopi/runners/codex.py @@ -185,27 +185,21 @@ def _todo_title(summary: _TodoSummary) -> str: def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]: - item_type = item.get("type") or item.get("item_type") + item_type = cast(str, item.get("type") or item.get("item_type")) if item_type == "assistant_message": item_type = "agent_message" - if not item_type: - return [] - if item_type == "agent_message": return [] - action_id = item.get("id") - if not isinstance(action_id, str) or not action_id: - logger.debug("[codex] missing item id in codex event: %r", item) - return [] + action_id = str(item["id"]) phase = cast(ActionPhase, etype.split(".")[-1]) if item_type == "error": if phase != "completed": return [] - message = str(item.get("message") or "codex item error") + message = str(item["message"]) return [ _action_event( phase="completed", @@ -224,7 +218,7 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent] return [] if kind == "command": - title = relativize_command(str(item.get("command") or "")) + title = relativize_command(str(item["command"])) if phase in {"started", "updated"}: return [ _action_event( @@ -235,13 +229,13 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent] ) ] if phase == "completed": - exit_code = item.get("exit_code") - ok = item.get("status") != "failed" - if isinstance(exit_code, int): + exit_code = item["exit_code"] + ok = item["status"] != "failed" + if exit_code is not None: ok = ok and exit_code == 0 detail = { "exit_code": exit_code, - "status": item.get("status"), + "status": item["status"], } return [ _action_event( @@ -255,22 +249,23 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent] ] if kind == "tool": - tool_name = _short_tool_name(item) - title = tool_name - detail = { - "server": item.get("server"), - "tool": item.get("tool"), - "status": item.get("status"), - } - if "arguments" in item: - detail["arguments"] = item.get("arguments") if item_type == "tool_call": - name = item.get("name") - tool_name = str(name) if name else "tool" + name = item["name"] + title = str(name) if name else "tool" + detail = { + "name": name, + "status": item["status"], + "arguments": item.get("arguments"), + } + else: + tool_name = _short_tool_name(item) title = tool_name - detail = {"name": name, "status": item.get("status")} - if "arguments" in item: - detail["arguments"] = item.get("arguments") + detail = { + "server": item["server"], + "tool": item["tool"], + "status": item["status"], + "arguments": item.get("arguments"), + } if phase in {"started", "updated"}: return [ @@ -283,12 +278,10 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent] ) ] if phase == "completed": - ok = item.get("status") != "failed" and not item.get("error") - error = item.get("error") + ok = item["status"] != "failed" and not item["error"] + error = item["error"] if error: - detail["error_message"] = str( - error.get("message") if isinstance(error, dict) else error - ) + detail["error_message"] = str(error.get("message") or error) result_summary = _summarize_tool_result(item.get("result")) if result_summary is not None: detail["result_summary"] = result_summary @@ -304,8 +297,8 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent] ] if kind == "web_search": - title = str(item.get("query") or "") - detail = {"query": item.get("query")} + title = str(item["query"]) + detail = {"query": item["query"]} if phase in {"started", "updated"}: return [ _action_event( @@ -333,11 +326,11 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent] return [] title = _format_change_summary(item) detail = { - "changes": item.get("changes") or [], - "status": item.get("status"), - "error": item.get("error"), + "changes": item["changes"], + "status": item["status"], + "error": item["error"], } - ok = item.get("status") != "failed" + ok = item["status"] != "failed" return [ _action_event( phase="completed", @@ -351,11 +344,11 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent] if kind == "note": if item_type == "todo_list": - summary = _summarize_todo_list(item.get("items")) + summary = _summarize_todo_list(item["items"]) title = _todo_title(summary) detail = {"done": summary.done, "total": summary.total} else: - title = str(item.get("text") or "") + title = str(item["text"]) detail = None if phase in {"started", "updated"}: @@ -384,20 +377,15 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent] def translate_codex_event(event: dict[str, Any], *, title: str) -> list[TakopiEvent]: - etype = event.get("type") - if etype == "thread.started": - thread_id = event.get("thread_id") - if thread_id: - token = ResumeToken(engine=ENGINE, value=str(thread_id)) + etype = event["type"] + match etype: + case "thread.started": + token = ResumeToken(engine=ENGINE, value=str(event["thread_id"])) return [_started_event(token, title=title)] - logger.debug("[codex] codex thread.started missing thread_id: %r", event) - return [] - - if etype in {"item.started", "item.updated", "item.completed"}: - item = event.get("item") or {} - return _translate_item_event(etype, item) - - return [] + case "item.started" | "item.updated" | "item.completed": + return _translate_item_event(etype, event["item"]) + case _: + return [] @dataclass(slots=True) @@ -460,33 +448,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner): def pipes_error_message(self) -> str: return "codex exec failed to open subprocess pipes" - def handle_started_event( - self, - event: StartedEvent, - *, - expected_session: ResumeToken | None, - found_session: ResumeToken | None, - ) -> tuple[ResumeToken | None, bool]: - if event.engine != ENGINE: - raise RuntimeError( - f"codex emitted session token for engine {event.engine!r}" - ) - if expected_session is not None and event.resume != expected_session: - message = ( - f"codex emitted session id {event.resume.value} " - f"but expected {expected_session.value}" - ) - raise RuntimeError(message) - if found_session is None: - return event.resume, True - if event.resume != found_session: - message = ( - f"codex emitted session id {event.resume.value} " - f"but expected {found_session.value}" - ) - raise RuntimeError(message) - return found_session, False - def translate( self, data: dict[str, Any], @@ -495,12 +456,33 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner): resume: ResumeToken | None, found_session: ResumeToken | None, ) -> list[TakopiEvent]: - etype = data.get("type") - if etype == "error": - message = str(data.get("message") or "codex error") - fatal_flag = data.get("fatal") - fatal = fatal_flag is True or fatal_flag is None - if fatal: + etype = data["type"] + match etype: + case "error": + message = str(data["message"]) + fatal_flag = data.get("fatal") + fatal = fatal_flag is True or fatal_flag is None + if fatal: + resume_for_completed = found_session or resume + return [ + _completed_event( + resume=resume_for_completed, + ok=False, + answer=state.final_answer or "", + error=message, + ) + ] + return [ + self.note_event( + message, + state=state, + ok=False, + detail={"code": data.get("code"), "fatal": data.get("fatal")}, + ) + ] + case "turn.failed": + error = data["error"] + message = str(error["message"]) resume_for_completed = found_session or resume return [ _completed_event( @@ -510,67 +492,48 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner): error=message, ) ] - return [ - self.note_event( - message, - state=state, - ok=False, - detail={"code": data.get("code"), "fatal": data.get("fatal")}, - ) - ] - if etype == "turn.failed": - error = data.get("error") or {} - message = str(error.get("message") or "codex turn failed") - resume_for_completed = found_session or resume - return [ - _completed_event( - resume=resume_for_completed, - ok=False, - answer=state.final_answer or "", - error=message, - ) - ] - if etype == "turn.rate_limited": - retry_ms = data.get("retry_after_ms") - message = "rate limited" - if isinstance(retry_ms, int): - message = f"rate limited (retry after {retry_ms}ms)" - return [self.note_event(message, state=state, ok=False)] - if etype == "turn.started": - action_id = f"turn_{state.turn_index}" - state.turn_index += 1 - return [ - _action_event( - phase="started", - action_id=action_id, - kind="turn", - title="turn started", - ) - ] - if etype == "turn.completed": - resume_for_completed = found_session or resume - return [ - _completed_event( - resume=resume_for_completed, - ok=True, - answer=state.final_answer or "", - usage=data.get("usage"), - ) - ] - - if data.get("type") == "item.completed": - item = data.get("item") or {} - item_type = item.get("type") or item.get("item_type") - if item_type == "assistant_message": - item_type = "agent_message" - if item_type == "agent_message" and isinstance(item.get("text"), str): - if state.final_answer is None: - state.final_answer = item["text"] - else: - logger.debug( - "[codex] emitted multiple agent messages; using the last one" + case "turn.rate_limited": + retry_ms = data.get("retry_after_ms") + message = "rate limited" + if isinstance(retry_ms, int): + message = f"rate limited (retry after {retry_ms}ms)" + return [self.note_event(message, state=state, ok=False)] + case "turn.started": + action_id = f"turn_{state.turn_index}" + state.turn_index += 1 + return [ + _action_event( + phase="started", + action_id=action_id, + kind="turn", + title="turn started", ) - state.final_answer = item["text"] + ] + case "turn.completed": + resume_for_completed = found_session or resume + return [ + _completed_event( + resume=resume_for_completed, + ok=True, + answer=state.final_answer or "", + usage=data.get("usage"), + ) + ] + case "item.completed": + item = data["item"] + item_type = cast(str, item.get("type") or item.get("item_type")) + if item_type == "assistant_message": + item_type = "agent_message" + if item_type == "agent_message": + if state.final_answer is None: + state.final_answer = item["text"] + else: + logger.debug( + "[codex] emitted multiple agent messages; using the last one" + ) + state.final_answer = item["text"] + case _: + pass return translate_codex_event(data, title=self.session_title) diff --git a/src/takopi/scheduler.py b/src/takopi/scheduler.py new file mode 100644 index 0000000..9becdf0 --- /dev/null +++ b/src/takopi/scheduler.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Protocol + +import anyio + +from .model import ResumeToken + + +@dataclass(frozen=True, slots=True) +class ThreadJob: + chat_id: int + user_msg_id: int + text: str + resume_token: ResumeToken + + +RunJob = Callable[[ThreadJob], Awaitable[None]] + + +class TaskGroup(Protocol): + def start_soon( + self, func: Callable[..., Awaitable[object]], *args: Any + ) -> None: ... + + +class ThreadScheduler: + def __init__(self, *, task_group: TaskGroup, run_job: RunJob) -> None: + self._task_group = task_group + self._run_job = run_job + self._lock = anyio.Lock() + self._pending_by_thread: dict[str, deque[ThreadJob]] = {} + self._active_threads: set[str] = set() + self._busy_until: dict[str, anyio.Event] = {} + + @staticmethod + def thread_key(token: ResumeToken) -> str: + return f"{token.engine}:{token.value}" + + async def note_thread_known(self, token: ResumeToken, done: anyio.Event) -> None: + key = self.thread_key(token) + async with self._lock: + current = self._busy_until.get(key) + if current is None or current.is_set(): + self._busy_until[key] = done + self._task_group.start_soon(self._clear_busy, key, done) + + async def enqueue(self, job: ThreadJob) -> None: + key = self.thread_key(job.resume_token) + async with self._lock: + queue = self._pending_by_thread.get(key) + if queue is None: + queue = deque() + self._pending_by_thread[key] = queue + queue.append(job) + if key in self._active_threads: + return + self._active_threads.add(key) + self._task_group.start_soon(self._thread_worker, key) + + async def enqueue_resume( + self, + chat_id: int, + user_msg_id: int, + text: str, + resume_token: ResumeToken, + ) -> None: + await self.enqueue( + ThreadJob( + chat_id=chat_id, + user_msg_id=user_msg_id, + text=text, + resume_token=resume_token, + ) + ) + + async def _clear_busy(self, key: str, done: anyio.Event) -> None: + await done.wait() + async with self._lock: + if self._busy_until.get(key) is done: + self._busy_until.pop(key, None) + + async def _thread_worker(self, key: str) -> None: + try: + while True: + async with self._lock: + done = self._busy_until.get(key) + queue = self._pending_by_thread.get(key) + if not queue: + self._pending_by_thread.pop(key, None) + self._active_threads.discard(key) + return + job = queue.popleft() + + if done is not None and not done.is_set(): + await done.wait() + + await self._run_job(job) + finally: + async with self._lock: + self._active_threads.discard(key) diff --git a/src/takopi/utils/paths.py b/src/takopi/utils/paths.py index 91f053a..a4518dc 100644 --- a/src/takopi/utils/paths.py +++ b/src/takopi/utils/paths.py @@ -13,11 +13,11 @@ def relativize_path(value: str, *, base_dir: Path | None = None) -> str: return value if value == base_str: return "." - if value.startswith(base_str): - suffix = value[len(base_str) :] - if suffix.startswith((os.sep, "/")): - suffix = suffix[1:] - return suffix or "." + for sep in (os.sep, "/"): + prefix = base_str if base_str.endswith(sep) else f"{base_str}{sep}" + if value.startswith(prefix): + suffix = value[len(prefix) :] + return suffix or "." return value diff --git a/tests/test_exec_render.py b/tests/test_exec_render.py index ecbcfad..136f97f 100644 --- a/tests/test_exec_render.py +++ b/tests/test_exec_render.py @@ -318,8 +318,8 @@ def test_render_event_cli_ignores_turn_actions() -> None: assert render_event_cli(event) == [] -def test_progress_renderer_ignores_missing_action_id_and_titles() -> None: - renderer = ExecProgressRenderer(engine="codex", show_title=True) +def test_progress_renderer_ignores_missing_action_id() -> None: + renderer = ExecProgressRenderer(engine="codex") resume = ResumeToken(engine="codex", value="abc") renderer.note_event(StartedEvent(engine="codex", resume=resume, title="Session")) @@ -332,4 +332,4 @@ def test_progress_renderer_ignores_missing_action_id_and_titles() -> None: assert renderer.note_event(event) is False header = assemble_markdown_parts(renderer.render_progress_parts(0.0)) - assert header.startswith("working (Session) · codex · 0s") + assert header.startswith("working · codex · 0s") diff --git a/tests/test_paths.py b/tests/test_paths.py index 21bf74b..0f3d772 100644 --- a/tests/test_paths.py +++ b/tests/test_paths.py @@ -2,7 +2,7 @@ from __future__ import annotations from pathlib import Path -from takopi.utils.paths import relativize_command +from takopi.utils.paths import relativize_command, relativize_path def test_relativize_command_rewrites_cwd_paths(tmp_path: Path) -> None: @@ -19,3 +19,17 @@ def test_relativize_command_rewrites_equals_paths(tmp_path: Path) -> None: command = f'rg -n --files -g "*.py" --path={base}/src' expected = 'rg -n --files -g "*.py" --path=src' assert relativize_command(command, base_dir=base) == expected + + +def test_relativize_path_ignores_sibling_prefix(tmp_path: Path) -> None: + base = tmp_path / "repo" + base.mkdir() + value = str(tmp_path / "repo2" / "file.txt") + assert relativize_path(value, base_dir=base) == value + + +def test_relativize_path_inside_base(tmp_path: Path) -> None: + base = tmp_path / "repo" + base.mkdir() + value = str(base / "src" / "app.py") + assert relativize_path(value, base_dir=base) == "src/app.py" diff --git a/uv.lock b/uv.lock index 89079a4..968900c 100644 --- a/uv.lock +++ b/uv.lock @@ -354,7 +354,7 @@ wheels = [ [[package]] name = "takopi" -version = "0.4.0" +version = "0.5.0.dev0" source = { editable = "." } dependencies = [ { name = "anyio" },