refactor: runners and scheduler, fix path handling (#23)

This commit is contained in:
banteg
2026-01-02 14:22:59 +04:00
committed by GitHub
parent 9c59056666
commit 51cdb72d0b
11 changed files with 377 additions and 424 deletions
+14 -81
View File
@@ -4,7 +4,6 @@ from __future__ import annotations
import logging import logging
import time import time
from collections import deque
from collections.abc import AsyncIterator, Awaitable, Callable from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@@ -21,6 +20,7 @@ from .render import (
) )
from .router import AutoRouter, RunnerUnavailableError from .router import AutoRouter, RunnerUnavailableError
from .runner import Runner from .runner import Runner
from .scheduler import ThreadJob, ThreadScheduler
from .telegram import BotClient from .telegram import BotClient
@@ -801,35 +801,6 @@ async def run_main_loop(
try: try:
await _set_command_menu(cfg) await _set_command_menu(cfg)
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
scheduler_lock = anyio.Lock()
@dataclass(frozen=True, slots=True)
class ThreadJob:
chat_id: int
user_msg_id: int
text: str
resume_token: ResumeToken
pending_by_thread: dict[str, deque[ThreadJob]] = {}
active_threads: set[str] = set()
busy_until: dict[str, anyio.Event] = {}
def thread_key(token: ResumeToken) -> str:
return f"{token.engine}:{token.value}"
async def clear_busy(key: str, done: anyio.Event) -> None:
await done.wait()
async with scheduler_lock:
if busy_until.get(key) is done:
busy_until.pop(key, None)
async def note_thread_known(token: ResumeToken, done: anyio.Event) -> None:
key = thread_key(token)
async with scheduler_lock:
current = busy_until.get(key)
if current is None or current.is_set():
busy_until[key] = done
tg.start_soon(clear_busy, key, done)
async def run_job( async def run_job(
chat_id: int, chat_id: int,
@@ -882,55 +853,15 @@ async def run_main_loop(
except Exception: except Exception:
logger.exception("[handle] worker failed") logger.exception("[handle] worker failed")
async def thread_worker(key: str) -> None: async def run_thread_job(job: ThreadJob) -> None:
try: await run_job(
while True: job.chat_id,
async with scheduler_lock: job.user_msg_id,
done = busy_until.get(key) job.text,
queue = pending_by_thread.get(key) job.resume_token,
if not queue: )
pending_by_thread.pop(key, None)
active_threads.discard(key)
return
job = queue.popleft()
if done is not None and not done.is_set(): scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
await done.wait()
await run_job(
job.chat_id,
job.user_msg_id,
job.text,
job.resume_token,
)
finally:
async with scheduler_lock:
active_threads.discard(key)
async def enqueue(
chat_id: int,
user_msg_id: int,
text: str,
resume_token: ResumeToken,
) -> None:
key = thread_key(resume_token)
async with scheduler_lock:
queue = pending_by_thread.get(key)
if queue is None:
queue = deque()
pending_by_thread[key] = queue
queue.append(
ThreadJob(
chat_id=chat_id,
user_msg_id=user_msg_id,
text=text,
resume_token=resume_token,
)
)
if key in active_threads:
return
active_threads.add(key)
tg.start_soon(thread_worker, key)
async for msg in poller(cfg): async for msg in poller(cfg):
text = msg["text"] text = msg["text"]
@@ -953,7 +884,7 @@ async def run_main_loop(
tg.start_soon( tg.start_soon(
_send_with_resume, _send_with_resume,
cfg.bot, cfg.bot,
enqueue, scheduler.enqueue_resume,
running_task, running_task,
msg["chat"]["id"], msg["chat"]["id"],
user_msg_id, user_msg_id,
@@ -968,10 +899,12 @@ async def run_main_loop(
user_msg_id, user_msg_id,
text, text,
None, None,
note_thread_known, scheduler.note_thread_known,
engine_override, engine_override,
) )
else: else:
await enqueue(msg["chat"]["id"], user_msg_id, text, resume_token) await scheduler.enqueue_resume(
msg["chat"]["id"], user_msg_id, text, resume_token
)
finally: finally:
await cfg.bot.close() await cfg.bot.close()
+1 -22
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
import shutil import shutil
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Sequence
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
@@ -110,24 +110,3 @@ def render_setup_guide(result: SetupResult) -> None:
expand=False, expand=False,
) )
console.print(panel) console.print(panel)
def render_engine_choice(backends: Sequence[EngineBackend]) -> None:
console = Console(stderr=True)
parts: list[str] = []
parts.append("[bold]available engines:[/]")
parts.append("")
for idx, backend in enumerate(backends, start=1):
parts.append(f"[bold yellow]{idx}.[/] [dim]$[/] takopi {backend.id}")
parts.append(f" [dim]use {backend.id}[/]")
parts.append("")
panel = Panel(
"\n".join(parts).rstrip(),
title="[bold]welcome to takopi![/]",
subtitle=f"{_OCTOPUS} choose engine",
border_style="yellow",
padding=(1, 2),
expand=False,
)
console.print(panel)
+8 -15
View File
@@ -23,6 +23,9 @@ HARD_BREAK = " \n"
MAX_PROGRESS_CMD_LEN = 300 MAX_PROGRESS_CMD_LEN = 300
MAX_FILE_CHANGES_INLINE = 3 MAX_FILE_CHANGES_INLINE = 3
_MD_RENDERER = MarkdownIt("commonmark", {"html": False})
_BULLET_RE = re.compile(r"(?m)^(\s*)•")
@dataclass(frozen=True) @dataclass(frozen=True)
class MarkdownParts: class MarkdownParts:
@@ -38,11 +41,10 @@ def assemble_markdown_parts(parts: MarkdownParts) -> str:
def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]: def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]:
md_renderer = MarkdownIt("commonmark", {"html": False}) html = _MD_RENDERER.render(md or "")
html = md_renderer.render(md or "")
rendered = transform_html(html) rendered = transform_html(html)
text = re.sub(r"(?m)^(\s*)•", r"\1-", rendered.text) text = _BULLET_RE.sub(r"\1-", rendered.text)
entities = [dict(e) for e in rendered.entities] entities = [dict(e) for e in rendered.entities]
return text, entities return text, entities
@@ -218,7 +220,6 @@ class ExecProgressRenderer:
max_actions: int = 5, max_actions: int = 5,
command_width: int | None = MAX_PROGRESS_CMD_LEN, command_width: int | None = MAX_PROGRESS_CMD_LEN,
resume_formatter: Callable[[ResumeToken], str] | None = None, resume_formatter: Callable[[ResumeToken], str] | None = None,
show_title: bool = False,
) -> None: ) -> None:
self.max_actions = max(0, int(max_actions)) self.max_actions = max(0, int(max_actions))
self.command_width = command_width self.command_width = command_width
@@ -226,16 +227,13 @@ class ExecProgressRenderer:
self.action_count = 0 self.action_count = 0
self.seen_action_ids: set[str] = set() self.seen_action_ids: set[str] = set()
self.resume_token: ResumeToken | None = None self.resume_token: ResumeToken | None = None
self.session_title: str | None = None
self._resume_formatter = resume_formatter self._resume_formatter = resume_formatter
self.show_title = show_title
self.engine = engine self.engine = engine
def note_event(self, event: TakopiEvent) -> bool: def note_event(self, event: TakopiEvent) -> bool:
match event: match event:
case StartedEvent(resume=resume, title=title): case StartedEvent(resume=resume):
self.resume_token = resume self.resume_token = resume
self.session_title = title
return True return True
case ActionEvent(action=action, phase=phase, ok=ok): case ActionEvent(action=action, phase=phase, ok=ok):
if action.kind == "turn": if action.kind == "turn":
@@ -286,7 +284,7 @@ class ExecProgressRenderer:
header = format_header( header = format_header(
elapsed_s, elapsed_s,
step, step,
label=self.label_with_title(label), label=label,
engine=self.engine, engine=self.engine,
) )
body = self.assemble_body([line.text for line in self.lines]) body = self.assemble_body([line.text for line in self.lines])
@@ -299,18 +297,13 @@ class ExecProgressRenderer:
header = format_header( header = format_header(
elapsed_s, elapsed_s,
step, step,
label=self.label_with_title(status), label=status,
engine=self.engine, engine=self.engine,
) )
answer = (answer or "").strip() answer = (answer or "").strip()
body = answer if answer else None body = answer if answer else None
return MarkdownParts(header=header, body=body, footer=self.render_footer()) return MarkdownParts(header=header, body=body, footer=self.render_footer())
def label_with_title(self, label: str) -> str:
if self.show_title and self.session_title:
return f"{label} ({self.session_title})"
return label
def render_footer(self) -> str | None: def render_footer(self) -> str | None:
if not self.resume_token or self._resume_formatter is None: if not self.resume_token or self._resume_formatter is None:
return None return None
+2 -3
View File
@@ -91,11 +91,10 @@ class SessionLockMixin:
class BaseRunner(SessionLockMixin): class BaseRunner(SessionLockMixin):
engine: EngineId engine: EngineId
async def run( def run(
self, prompt: str, resume: ResumeToken | None self, prompt: str, resume: ResumeToken | None
) -> AsyncIterator[TakopiEvent]: ) -> AsyncIterator[TakopiEvent]:
async for evt in self.run_locked(prompt, resume): return self.run_locked(prompt, resume)
yield evt
async def run_locked( async def run_locked(
self, prompt: str, resume: ResumeToken | None self, prompt: str, resume: ResumeToken | None
+115 -146
View File
@@ -142,14 +142,10 @@ def _tool_action(
*, *,
message_id: str | None, message_id: str | None,
parent_tool_use_id: str | None, parent_tool_use_id: str | None,
) -> Action | None: ) -> Action:
tool_id = content.get("id") tool_id = content["id"]
if not isinstance(tool_id, str) or not tool_id:
return None
tool_name = str(content.get("name") or "tool") tool_name = str(content.get("name") or "tool")
tool_input = content.get("input") tool_input = content["input"]
if not isinstance(tool_input, dict):
tool_input = {}
kind, title = _tool_kind_and_title(tool_name, tool_input) kind, title = _tool_kind_and_title(tool_name, tool_input)
@@ -247,153 +243,126 @@ def translate_claude_event(
title: str, title: str,
state: ClaudeStreamState, state: ClaudeStreamState,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
etype = event.get("type") etype = event["type"]
if etype == "system" and event.get("subtype") == "init": match etype:
session_id = event.get("session_id") case "system" if event.get("subtype") == "init":
if not session_id: session_id = event["session_id"]
return [] model = event.get("model")
model = event.get("model") event_title = str(model) if model else title
event_title = str(model) if model else title meta: dict[str, Any] = {}
meta: dict[str, Any] = {} for key in (
for key in ("cwd", "tools", "permissionMode", "output_style", "apiKeySource"): "cwd",
if key in event: "tools",
meta[key] = event.get(key) "permissionMode",
if "mcp_servers" in event: "output_style",
meta["mcp_servers"] = event.get("mcp_servers") "apiKeySource",
):
if key in event:
meta[key] = event.get(key)
if "mcp_servers" in event:
meta["mcp_servers"] = event.get("mcp_servers")
return [ return [
StartedEvent( StartedEvent(
engine=ENGINE, engine=ENGINE,
resume=ResumeToken(engine=ENGINE, value=str(session_id)), resume=ResumeToken(engine=ENGINE, value=str(session_id)),
title=event_title, title=event_title,
meta=meta or None, meta=meta or None,
)
]
if etype == "assistant":
message = event.get("message")
if not isinstance(message, dict):
return []
message_id = message.get("id")
if not isinstance(message_id, str):
message_id = None
parent_tool_use_id = event.get("parent_tool_use_id")
if not isinstance(parent_tool_use_id, str):
parent_tool_use_id = None
content_blocks = message.get("content")
if not isinstance(content_blocks, list):
return []
out: list[TakopiEvent] = []
for content in content_blocks:
if not isinstance(content, dict):
continue
ctype = content.get("type")
if ctype == "tool_use":
action = _tool_action(
content,
message_id=message_id,
parent_tool_use_id=parent_tool_use_id,
) )
if action is None: ]
case "assistant":
message = event["message"]
message_id = message.get("id")
parent_tool_use_id = event.get("parent_tool_use_id")
content_blocks = message["content"]
out: list[TakopiEvent] = []
for content in content_blocks:
match content["type"]:
case "tool_use":
action = _tool_action(
content,
message_id=message_id,
parent_tool_use_id=parent_tool_use_id,
)
state.pending_actions[action.id] = action
out.append(_action_event(phase="started", action=action))
case "text":
text = content["text"]
if text:
state.last_assistant_text = text
case _:
continue
return out
case "user":
message = event["message"]
message_id = message.get("id")
content_blocks = message["content"]
out: list[TakopiEvent] = []
for content in content_blocks:
if content["type"] != "tool_result":
continue continue
state.pending_actions[action.id] = action tool_use_id = content["tool_use_id"]
out.append(_action_event(phase="started", action=action)) action = state.pending_actions.pop(tool_use_id, None)
elif ctype == "text": if action is None:
text = content.get("text") action = Action(
if isinstance(text, str) and text: id=tool_use_id,
state.last_assistant_text = text kind="tool",
return out title="tool result",
detail={},
if etype == "user": )
message = event.get("message") out.append(
if not isinstance(message, dict): _tool_result_event(content, action=action, message_id=message_id)
return [] )
message_id = message.get("id") return out
if not isinstance(message_id, str): case "result":
message_id = None out: list[TakopiEvent] = []
content_blocks = message.get("content") for idx, denial in enumerate(event.get("permission_denials", [])):
if not isinstance(content_blocks, list): tool_name = denial.get("tool_name")
return [] denial_title = "permission denied"
out: list[TakopiEvent] = [] if tool_name:
for content in content_blocks: denial_title = f"permission denied: {tool_name}"
if not isinstance(content, dict): tool_use_id = denial.get("tool_use_id")
continue action_id = (
if content.get("type") != "tool_result": f"claude.permission.{tool_use_id}"
continue if tool_use_id
tool_use_id = content.get("tool_use_id") else f"claude.permission.{idx}"
if not isinstance(tool_use_id, str) or not tool_use_id: )
continue out.append(
action = state.pending_actions.pop(tool_use_id, None) _action_event(
if action is None: phase="completed",
action = Action( action=Action(
id=tool_use_id, id=action_id,
kind="tool", kind="warning",
title="tool result", title=denial_title,
detail={}, detail=denial,
),
ok=False,
level="warning",
)
) )
out.append(
_tool_result_event(content, action=action, message_id=message_id)
)
return out
if etype == "result": ok = not event.get("is_error", False)
out: list[TakopiEvent] = [] result_text = event["result"]
for idx, denial in enumerate(event.get("permission_denials") or []): if ok and not result_text and state.last_assistant_text:
if not isinstance(denial, dict): result_text = state.last_assistant_text
continue
tool_name = denial.get("tool_name") resume = ResumeToken(engine=ENGINE, value=str(event["session_id"]))
denial_title = "permission denied" error = None if ok else _extract_error(event)
if isinstance(tool_name, str) and tool_name: usage = _usage_payload(event)
denial_title = f"permission denied: {tool_name}"
tool_use_id = denial.get("tool_use_id")
action_id = (
f"claude.permission.{tool_use_id}"
if isinstance(tool_use_id, str) and tool_use_id
else f"claude.permission.{idx}"
)
out.append( out.append(
_action_event( CompletedEvent(
phase="completed", engine=ENGINE,
action=Action( ok=ok,
id=action_id, answer=result_text,
kind="warning", resume=resume,
title=denial_title, error=error,
detail=denial, usage=usage or None,
),
ok=False,
level="warning",
) )
) )
return out
ok = not event.get("is_error", False) case _:
result_text = event.get("result") return []
if not isinstance(result_text, str):
result_text = ""
if ok and not result_text and state.last_assistant_text:
result_text = state.last_assistant_text
resume_value = event.get("session_id")
resume = (
ResumeToken(engine=ENGINE, value=str(resume_value))
if resume_value
else None
)
error = None if ok else _extract_error(event)
usage = _usage_payload(event)
out.append(
CompletedEvent(
engine=ENGINE,
ok=ok,
answer=result_text,
resume=resume,
error=error,
usage=usage or None,
)
)
return out
return []
@dataclass @dataclass
+110 -147
View File
@@ -185,27 +185,21 @@ def _todo_title(summary: _TodoSummary) -> str:
def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]: def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]:
item_type = item.get("type") or item.get("item_type") item_type = cast(str, item.get("type") or item.get("item_type"))
if item_type == "assistant_message": if item_type == "assistant_message":
item_type = "agent_message" item_type = "agent_message"
if not item_type:
return []
if item_type == "agent_message": if item_type == "agent_message":
return [] return []
action_id = item.get("id") action_id = str(item["id"])
if not isinstance(action_id, str) or not action_id:
logger.debug("[codex] missing item id in codex event: %r", item)
return []
phase = cast(ActionPhase, etype.split(".")[-1]) phase = cast(ActionPhase, etype.split(".")[-1])
if item_type == "error": if item_type == "error":
if phase != "completed": if phase != "completed":
return [] return []
message = str(item.get("message") or "codex item error") message = str(item["message"])
return [ return [
_action_event( _action_event(
phase="completed", phase="completed",
@@ -224,7 +218,7 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
return [] return []
if kind == "command": if kind == "command":
title = relativize_command(str(item.get("command") or "")) title = relativize_command(str(item["command"]))
if phase in {"started", "updated"}: if phase in {"started", "updated"}:
return [ return [
_action_event( _action_event(
@@ -235,13 +229,13 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
) )
] ]
if phase == "completed": if phase == "completed":
exit_code = item.get("exit_code") exit_code = item["exit_code"]
ok = item.get("status") != "failed" ok = item["status"] != "failed"
if isinstance(exit_code, int): if exit_code is not None:
ok = ok and exit_code == 0 ok = ok and exit_code == 0
detail = { detail = {
"exit_code": exit_code, "exit_code": exit_code,
"status": item.get("status"), "status": item["status"],
} }
return [ return [
_action_event( _action_event(
@@ -255,22 +249,23 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
] ]
if kind == "tool": if kind == "tool":
tool_name = _short_tool_name(item)
title = tool_name
detail = {
"server": item.get("server"),
"tool": item.get("tool"),
"status": item.get("status"),
}
if "arguments" in item:
detail["arguments"] = item.get("arguments")
if item_type == "tool_call": if item_type == "tool_call":
name = item.get("name") name = item["name"]
tool_name = str(name) if name else "tool" title = str(name) if name else "tool"
detail = {
"name": name,
"status": item["status"],
"arguments": item.get("arguments"),
}
else:
tool_name = _short_tool_name(item)
title = tool_name title = tool_name
detail = {"name": name, "status": item.get("status")} detail = {
if "arguments" in item: "server": item["server"],
detail["arguments"] = item.get("arguments") "tool": item["tool"],
"status": item["status"],
"arguments": item.get("arguments"),
}
if phase in {"started", "updated"}: if phase in {"started", "updated"}:
return [ return [
@@ -283,12 +278,10 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
) )
] ]
if phase == "completed": if phase == "completed":
ok = item.get("status") != "failed" and not item.get("error") ok = item["status"] != "failed" and not item["error"]
error = item.get("error") error = item["error"]
if error: if error:
detail["error_message"] = str( detail["error_message"] = str(error.get("message") or error)
error.get("message") if isinstance(error, dict) else error
)
result_summary = _summarize_tool_result(item.get("result")) result_summary = _summarize_tool_result(item.get("result"))
if result_summary is not None: if result_summary is not None:
detail["result_summary"] = result_summary detail["result_summary"] = result_summary
@@ -304,8 +297,8 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
] ]
if kind == "web_search": if kind == "web_search":
title = str(item.get("query") or "") title = str(item["query"])
detail = {"query": item.get("query")} detail = {"query": item["query"]}
if phase in {"started", "updated"}: if phase in {"started", "updated"}:
return [ return [
_action_event( _action_event(
@@ -333,11 +326,11 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
return [] return []
title = _format_change_summary(item) title = _format_change_summary(item)
detail = { detail = {
"changes": item.get("changes") or [], "changes": item["changes"],
"status": item.get("status"), "status": item["status"],
"error": item.get("error"), "error": item["error"],
} }
ok = item.get("status") != "failed" ok = item["status"] != "failed"
return [ return [
_action_event( _action_event(
phase="completed", phase="completed",
@@ -351,11 +344,11 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
if kind == "note": if kind == "note":
if item_type == "todo_list": if item_type == "todo_list":
summary = _summarize_todo_list(item.get("items")) summary = _summarize_todo_list(item["items"])
title = _todo_title(summary) title = _todo_title(summary)
detail = {"done": summary.done, "total": summary.total} detail = {"done": summary.done, "total": summary.total}
else: else:
title = str(item.get("text") or "") title = str(item["text"])
detail = None detail = None
if phase in {"started", "updated"}: if phase in {"started", "updated"}:
@@ -384,20 +377,15 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
def translate_codex_event(event: dict[str, Any], *, title: str) -> list[TakopiEvent]: def translate_codex_event(event: dict[str, Any], *, title: str) -> list[TakopiEvent]:
etype = event.get("type") etype = event["type"]
if etype == "thread.started": match etype:
thread_id = event.get("thread_id") case "thread.started":
if thread_id: token = ResumeToken(engine=ENGINE, value=str(event["thread_id"]))
token = ResumeToken(engine=ENGINE, value=str(thread_id))
return [_started_event(token, title=title)] return [_started_event(token, title=title)]
logger.debug("[codex] codex thread.started missing thread_id: %r", event) case "item.started" | "item.updated" | "item.completed":
return [] return _translate_item_event(etype, event["item"])
case _:
if etype in {"item.started", "item.updated", "item.completed"}: return []
item = event.get("item") or {}
return _translate_item_event(etype, item)
return []
@dataclass(slots=True) @dataclass(slots=True)
@@ -460,33 +448,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
def pipes_error_message(self) -> str: def pipes_error_message(self) -> str:
return "codex exec failed to open subprocess pipes" return "codex exec failed to open subprocess pipes"
def handle_started_event(
self,
event: StartedEvent,
*,
expected_session: ResumeToken | None,
found_session: ResumeToken | None,
) -> tuple[ResumeToken | None, bool]:
if event.engine != ENGINE:
raise RuntimeError(
f"codex emitted session token for engine {event.engine!r}"
)
if expected_session is not None and event.resume != expected_session:
message = (
f"codex emitted session id {event.resume.value} "
f"but expected {expected_session.value}"
)
raise RuntimeError(message)
if found_session is None:
return event.resume, True
if event.resume != found_session:
message = (
f"codex emitted session id {event.resume.value} "
f"but expected {found_session.value}"
)
raise RuntimeError(message)
return found_session, False
def translate( def translate(
self, self,
data: dict[str, Any], data: dict[str, Any],
@@ -495,12 +456,33 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
resume: ResumeToken | None, resume: ResumeToken | None,
found_session: ResumeToken | None, found_session: ResumeToken | None,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
etype = data.get("type") etype = data["type"]
if etype == "error": match etype:
message = str(data.get("message") or "codex error") case "error":
fatal_flag = data.get("fatal") message = str(data["message"])
fatal = fatal_flag is True or fatal_flag is None fatal_flag = data.get("fatal")
if fatal: fatal = fatal_flag is True or fatal_flag is None
if fatal:
resume_for_completed = found_session or resume
return [
_completed_event(
resume=resume_for_completed,
ok=False,
answer=state.final_answer or "",
error=message,
)
]
return [
self.note_event(
message,
state=state,
ok=False,
detail={"code": data.get("code"), "fatal": data.get("fatal")},
)
]
case "turn.failed":
error = data["error"]
message = str(error["message"])
resume_for_completed = found_session or resume resume_for_completed = found_session or resume
return [ return [
_completed_event( _completed_event(
@@ -510,67 +492,48 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
error=message, error=message,
) )
] ]
return [ case "turn.rate_limited":
self.note_event( retry_ms = data.get("retry_after_ms")
message, message = "rate limited"
state=state, if isinstance(retry_ms, int):
ok=False, message = f"rate limited (retry after {retry_ms}ms)"
detail={"code": data.get("code"), "fatal": data.get("fatal")}, return [self.note_event(message, state=state, ok=False)]
) case "turn.started":
] action_id = f"turn_{state.turn_index}"
if etype == "turn.failed": state.turn_index += 1
error = data.get("error") or {} return [
message = str(error.get("message") or "codex turn failed") _action_event(
resume_for_completed = found_session or resume phase="started",
return [ action_id=action_id,
_completed_event( kind="turn",
resume=resume_for_completed, title="turn started",
ok=False,
answer=state.final_answer or "",
error=message,
)
]
if etype == "turn.rate_limited":
retry_ms = data.get("retry_after_ms")
message = "rate limited"
if isinstance(retry_ms, int):
message = f"rate limited (retry after {retry_ms}ms)"
return [self.note_event(message, state=state, ok=False)]
if etype == "turn.started":
action_id = f"turn_{state.turn_index}"
state.turn_index += 1
return [
_action_event(
phase="started",
action_id=action_id,
kind="turn",
title="turn started",
)
]
if etype == "turn.completed":
resume_for_completed = found_session or resume
return [
_completed_event(
resume=resume_for_completed,
ok=True,
answer=state.final_answer or "",
usage=data.get("usage"),
)
]
if data.get("type") == "item.completed":
item = data.get("item") or {}
item_type = item.get("type") or item.get("item_type")
if item_type == "assistant_message":
item_type = "agent_message"
if item_type == "agent_message" and isinstance(item.get("text"), str):
if state.final_answer is None:
state.final_answer = item["text"]
else:
logger.debug(
"[codex] emitted multiple agent messages; using the last one"
) )
state.final_answer = item["text"] ]
case "turn.completed":
resume_for_completed = found_session or resume
return [
_completed_event(
resume=resume_for_completed,
ok=True,
answer=state.final_answer or "",
usage=data.get("usage"),
)
]
case "item.completed":
item = data["item"]
item_type = cast(str, item.get("type") or item.get("item_type"))
if item_type == "assistant_message":
item_type = "agent_message"
if item_type == "agent_message":
if state.final_answer is None:
state.final_answer = item["text"]
else:
logger.debug(
"[codex] emitted multiple agent messages; using the last one"
)
state.final_answer = item["text"]
case _:
pass
return translate_codex_event(data, title=self.session_title) return translate_codex_event(data, title=self.session_title)
+103
View File
@@ -0,0 +1,103 @@
from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Protocol
import anyio
from .model import ResumeToken
@dataclass(frozen=True, slots=True)
class ThreadJob:
chat_id: int
user_msg_id: int
text: str
resume_token: ResumeToken
RunJob = Callable[[ThreadJob], Awaitable[None]]
class TaskGroup(Protocol):
def start_soon(
self, func: Callable[..., Awaitable[object]], *args: Any
) -> None: ...
class ThreadScheduler:
def __init__(self, *, task_group: TaskGroup, run_job: RunJob) -> None:
self._task_group = task_group
self._run_job = run_job
self._lock = anyio.Lock()
self._pending_by_thread: dict[str, deque[ThreadJob]] = {}
self._active_threads: set[str] = set()
self._busy_until: dict[str, anyio.Event] = {}
@staticmethod
def thread_key(token: ResumeToken) -> str:
return f"{token.engine}:{token.value}"
async def note_thread_known(self, token: ResumeToken, done: anyio.Event) -> None:
key = self.thread_key(token)
async with self._lock:
current = self._busy_until.get(key)
if current is None or current.is_set():
self._busy_until[key] = done
self._task_group.start_soon(self._clear_busy, key, done)
async def enqueue(self, job: ThreadJob) -> None:
key = self.thread_key(job.resume_token)
async with self._lock:
queue = self._pending_by_thread.get(key)
if queue is None:
queue = deque()
self._pending_by_thread[key] = queue
queue.append(job)
if key in self._active_threads:
return
self._active_threads.add(key)
self._task_group.start_soon(self._thread_worker, key)
async def enqueue_resume(
self,
chat_id: int,
user_msg_id: int,
text: str,
resume_token: ResumeToken,
) -> None:
await self.enqueue(
ThreadJob(
chat_id=chat_id,
user_msg_id=user_msg_id,
text=text,
resume_token=resume_token,
)
)
async def _clear_busy(self, key: str, done: anyio.Event) -> None:
await done.wait()
async with self._lock:
if self._busy_until.get(key) is done:
self._busy_until.pop(key, None)
async def _thread_worker(self, key: str) -> None:
try:
while True:
async with self._lock:
done = self._busy_until.get(key)
queue = self._pending_by_thread.get(key)
if not queue:
self._pending_by_thread.pop(key, None)
self._active_threads.discard(key)
return
job = queue.popleft()
if done is not None and not done.is_set():
await done.wait()
await self._run_job(job)
finally:
async with self._lock:
self._active_threads.discard(key)
+5 -5
View File
@@ -13,11 +13,11 @@ def relativize_path(value: str, *, base_dir: Path | None = None) -> str:
return value return value
if value == base_str: if value == base_str:
return "." return "."
if value.startswith(base_str): for sep in (os.sep, "/"):
suffix = value[len(base_str) :] prefix = base_str if base_str.endswith(sep) else f"{base_str}{sep}"
if suffix.startswith((os.sep, "/")): if value.startswith(prefix):
suffix = suffix[1:] suffix = value[len(prefix) :]
return suffix or "." return suffix or "."
return value return value
+3 -3
View File
@@ -318,8 +318,8 @@ def test_render_event_cli_ignores_turn_actions() -> None:
assert render_event_cli(event) == [] assert render_event_cli(event) == []
def test_progress_renderer_ignores_missing_action_id_and_titles() -> None: def test_progress_renderer_ignores_missing_action_id() -> None:
renderer = ExecProgressRenderer(engine="codex", show_title=True) renderer = ExecProgressRenderer(engine="codex")
resume = ResumeToken(engine="codex", value="abc") resume = ResumeToken(engine="codex", value="abc")
renderer.note_event(StartedEvent(engine="codex", resume=resume, title="Session")) renderer.note_event(StartedEvent(engine="codex", resume=resume, title="Session"))
@@ -332,4 +332,4 @@ def test_progress_renderer_ignores_missing_action_id_and_titles() -> None:
assert renderer.note_event(event) is False assert renderer.note_event(event) is False
header = assemble_markdown_parts(renderer.render_progress_parts(0.0)) header = assemble_markdown_parts(renderer.render_progress_parts(0.0))
assert header.startswith("working (Session) · codex · 0s") assert header.startswith("working · codex · 0s")
+15 -1
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from pathlib import Path from pathlib import Path
from takopi.utils.paths import relativize_command from takopi.utils.paths import relativize_command, relativize_path
def test_relativize_command_rewrites_cwd_paths(tmp_path: Path) -> None: def test_relativize_command_rewrites_cwd_paths(tmp_path: Path) -> None:
@@ -19,3 +19,17 @@ def test_relativize_command_rewrites_equals_paths(tmp_path: Path) -> None:
command = f'rg -n --files -g "*.py" --path={base}/src' command = f'rg -n --files -g "*.py" --path={base}/src'
expected = 'rg -n --files -g "*.py" --path=src' expected = 'rg -n --files -g "*.py" --path=src'
assert relativize_command(command, base_dir=base) == expected assert relativize_command(command, base_dir=base) == expected
def test_relativize_path_ignores_sibling_prefix(tmp_path: Path) -> None:
base = tmp_path / "repo"
base.mkdir()
value = str(tmp_path / "repo2" / "file.txt")
assert relativize_path(value, base_dir=base) == value
def test_relativize_path_inside_base(tmp_path: Path) -> None:
base = tmp_path / "repo"
base.mkdir()
value = str(base / "src" / "app.py")
assert relativize_path(value, base_dir=base) == "src/app.py"
Generated
+1 -1
View File
@@ -354,7 +354,7 @@ wheels = [
[[package]] [[package]]
name = "takopi" name = "takopi"
version = "0.4.0" version = "0.5.0.dev0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "anyio" }, { name = "anyio" },