refactor: runners and scheduler, fix path handling (#23)

This commit is contained in:
banteg
2026-01-02 14:22:59 +04:00
committed by GitHub
parent 9c59056666
commit 51cdb72d0b
11 changed files with 377 additions and 424 deletions
+8 -75
View File
@@ -4,7 +4,6 @@ from __future__ import annotations
import logging import logging
import time import time
from collections import deque
from collections.abc import AsyncIterator, Awaitable, Callable from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@@ -21,6 +20,7 @@ from .render import (
) )
from .router import AutoRouter, RunnerUnavailableError from .router import AutoRouter, RunnerUnavailableError
from .runner import Runner from .runner import Runner
from .scheduler import ThreadJob, ThreadScheduler
from .telegram import BotClient from .telegram import BotClient
@@ -801,35 +801,6 @@ async def run_main_loop(
try: try:
await _set_command_menu(cfg) await _set_command_menu(cfg)
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
scheduler_lock = anyio.Lock()
@dataclass(frozen=True, slots=True)
class ThreadJob:
chat_id: int
user_msg_id: int
text: str
resume_token: ResumeToken
pending_by_thread: dict[str, deque[ThreadJob]] = {}
active_threads: set[str] = set()
busy_until: dict[str, anyio.Event] = {}
def thread_key(token: ResumeToken) -> str:
return f"{token.engine}:{token.value}"
async def clear_busy(key: str, done: anyio.Event) -> None:
await done.wait()
async with scheduler_lock:
if busy_until.get(key) is done:
busy_until.pop(key, None)
async def note_thread_known(token: ResumeToken, done: anyio.Event) -> None:
key = thread_key(token)
async with scheduler_lock:
current = busy_until.get(key)
if current is None or current.is_set():
busy_until[key] = done
tg.start_soon(clear_busy, key, done)
async def run_job( async def run_job(
chat_id: int, chat_id: int,
@@ -882,55 +853,15 @@ async def run_main_loop(
except Exception: except Exception:
logger.exception("[handle] worker failed") logger.exception("[handle] worker failed")
async def thread_worker(key: str) -> None: async def run_thread_job(job: ThreadJob) -> None:
try:
while True:
async with scheduler_lock:
done = busy_until.get(key)
queue = pending_by_thread.get(key)
if not queue:
pending_by_thread.pop(key, None)
active_threads.discard(key)
return
job = queue.popleft()
if done is not None and not done.is_set():
await done.wait()
await run_job( await run_job(
job.chat_id, job.chat_id,
job.user_msg_id, job.user_msg_id,
job.text, job.text,
job.resume_token, job.resume_token,
) )
finally:
async with scheduler_lock:
active_threads.discard(key)
async def enqueue( scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
chat_id: int,
user_msg_id: int,
text: str,
resume_token: ResumeToken,
) -> None:
key = thread_key(resume_token)
async with scheduler_lock:
queue = pending_by_thread.get(key)
if queue is None:
queue = deque()
pending_by_thread[key] = queue
queue.append(
ThreadJob(
chat_id=chat_id,
user_msg_id=user_msg_id,
text=text,
resume_token=resume_token,
)
)
if key in active_threads:
return
active_threads.add(key)
tg.start_soon(thread_worker, key)
async for msg in poller(cfg): async for msg in poller(cfg):
text = msg["text"] text = msg["text"]
@@ -953,7 +884,7 @@ async def run_main_loop(
tg.start_soon( tg.start_soon(
_send_with_resume, _send_with_resume,
cfg.bot, cfg.bot,
enqueue, scheduler.enqueue_resume,
running_task, running_task,
msg["chat"]["id"], msg["chat"]["id"],
user_msg_id, user_msg_id,
@@ -968,10 +899,12 @@ async def run_main_loop(
user_msg_id, user_msg_id,
text, text,
None, None,
note_thread_known, scheduler.note_thread_known,
engine_override, engine_override,
) )
else: else:
await enqueue(msg["chat"]["id"], user_msg_id, text, resume_token) await scheduler.enqueue_resume(
msg["chat"]["id"], user_msg_id, text, resume_token
)
finally: finally:
await cfg.bot.close() await cfg.bot.close()
+1 -22
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
import shutil import shutil
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Sequence
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
@@ -110,24 +110,3 @@ def render_setup_guide(result: SetupResult) -> None:
expand=False, expand=False,
) )
console.print(panel) console.print(panel)
def render_engine_choice(backends: Sequence[EngineBackend]) -> None:
console = Console(stderr=True)
parts: list[str] = []
parts.append("[bold]available engines:[/]")
parts.append("")
for idx, backend in enumerate(backends, start=1):
parts.append(f"[bold yellow]{idx}.[/] [dim]$[/] takopi {backend.id}")
parts.append(f" [dim]use {backend.id}[/]")
parts.append("")
panel = Panel(
"\n".join(parts).rstrip(),
title="[bold]welcome to takopi![/]",
subtitle=f"{_OCTOPUS} choose engine",
border_style="yellow",
padding=(1, 2),
expand=False,
)
console.print(panel)
+8 -15
View File
@@ -23,6 +23,9 @@ HARD_BREAK = " \n"
MAX_PROGRESS_CMD_LEN = 300 MAX_PROGRESS_CMD_LEN = 300
MAX_FILE_CHANGES_INLINE = 3 MAX_FILE_CHANGES_INLINE = 3
_MD_RENDERER = MarkdownIt("commonmark", {"html": False})
_BULLET_RE = re.compile(r"(?m)^(\s*)•")
@dataclass(frozen=True) @dataclass(frozen=True)
class MarkdownParts: class MarkdownParts:
@@ -38,11 +41,10 @@ def assemble_markdown_parts(parts: MarkdownParts) -> str:
def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]: def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]:
md_renderer = MarkdownIt("commonmark", {"html": False}) html = _MD_RENDERER.render(md or "")
html = md_renderer.render(md or "")
rendered = transform_html(html) rendered = transform_html(html)
text = re.sub(r"(?m)^(\s*)•", r"\1-", rendered.text) text = _BULLET_RE.sub(r"\1-", rendered.text)
entities = [dict(e) for e in rendered.entities] entities = [dict(e) for e in rendered.entities]
return text, entities return text, entities
@@ -218,7 +220,6 @@ class ExecProgressRenderer:
max_actions: int = 5, max_actions: int = 5,
command_width: int | None = MAX_PROGRESS_CMD_LEN, command_width: int | None = MAX_PROGRESS_CMD_LEN,
resume_formatter: Callable[[ResumeToken], str] | None = None, resume_formatter: Callable[[ResumeToken], str] | None = None,
show_title: bool = False,
) -> None: ) -> None:
self.max_actions = max(0, int(max_actions)) self.max_actions = max(0, int(max_actions))
self.command_width = command_width self.command_width = command_width
@@ -226,16 +227,13 @@ class ExecProgressRenderer:
self.action_count = 0 self.action_count = 0
self.seen_action_ids: set[str] = set() self.seen_action_ids: set[str] = set()
self.resume_token: ResumeToken | None = None self.resume_token: ResumeToken | None = None
self.session_title: str | None = None
self._resume_formatter = resume_formatter self._resume_formatter = resume_formatter
self.show_title = show_title
self.engine = engine self.engine = engine
def note_event(self, event: TakopiEvent) -> bool: def note_event(self, event: TakopiEvent) -> bool:
match event: match event:
case StartedEvent(resume=resume, title=title): case StartedEvent(resume=resume):
self.resume_token = resume self.resume_token = resume
self.session_title = title
return True return True
case ActionEvent(action=action, phase=phase, ok=ok): case ActionEvent(action=action, phase=phase, ok=ok):
if action.kind == "turn": if action.kind == "turn":
@@ -286,7 +284,7 @@ class ExecProgressRenderer:
header = format_header( header = format_header(
elapsed_s, elapsed_s,
step, step,
label=self.label_with_title(label), label=label,
engine=self.engine, engine=self.engine,
) )
body = self.assemble_body([line.text for line in self.lines]) body = self.assemble_body([line.text for line in self.lines])
@@ -299,18 +297,13 @@ class ExecProgressRenderer:
header = format_header( header = format_header(
elapsed_s, elapsed_s,
step, step,
label=self.label_with_title(status), label=status,
engine=self.engine, engine=self.engine,
) )
answer = (answer or "").strip() answer = (answer or "").strip()
body = answer if answer else None body = answer if answer else None
return MarkdownParts(header=header, body=body, footer=self.render_footer()) return MarkdownParts(header=header, body=body, footer=self.render_footer())
def label_with_title(self, label: str) -> str:
if self.show_title and self.session_title:
return f"{label} ({self.session_title})"
return label
def render_footer(self) -> str | None: def render_footer(self) -> str | None:
if not self.resume_token or self._resume_formatter is None: if not self.resume_token or self._resume_formatter is None:
return None return None
+2 -3
View File
@@ -91,11 +91,10 @@ class SessionLockMixin:
class BaseRunner(SessionLockMixin): class BaseRunner(SessionLockMixin):
engine: EngineId engine: EngineId
async def run( def run(
self, prompt: str, resume: ResumeToken | None self, prompt: str, resume: ResumeToken | None
) -> AsyncIterator[TakopiEvent]: ) -> AsyncIterator[TakopiEvent]:
async for evt in self.run_locked(prompt, resume): return self.run_locked(prompt, resume)
yield evt
async def run_locked( async def run_locked(
self, prompt: str, resume: ResumeToken | None self, prompt: str, resume: ResumeToken | None
+36 -67
View File
@@ -142,14 +142,10 @@ def _tool_action(
*, *,
message_id: str | None, message_id: str | None,
parent_tool_use_id: str | None, parent_tool_use_id: str | None,
) -> Action | None: ) -> Action:
tool_id = content.get("id") tool_id = content["id"]
if not isinstance(tool_id, str) or not tool_id:
return None
tool_name = str(content.get("name") or "tool") tool_name = str(content.get("name") or "tool")
tool_input = content.get("input") tool_input = content["input"]
if not isinstance(tool_input, dict):
tool_input = {}
kind, title = _tool_kind_and_title(tool_name, tool_input) kind, title = _tool_kind_and_title(tool_name, tool_input)
@@ -247,15 +243,20 @@ def translate_claude_event(
title: str, title: str,
state: ClaudeStreamState, state: ClaudeStreamState,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
etype = event.get("type") etype = event["type"]
if etype == "system" and event.get("subtype") == "init": match etype:
session_id = event.get("session_id") case "system" if event.get("subtype") == "init":
if not session_id: session_id = event["session_id"]
return []
model = event.get("model") model = event.get("model")
event_title = str(model) if model else title event_title = str(model) if model else title
meta: dict[str, Any] = {} meta: dict[str, Any] = {}
for key in ("cwd", "tools", "permissionMode", "output_style", "apiKeySource"): for key in (
"cwd",
"tools",
"permissionMode",
"output_style",
"apiKeySource",
):
if key in event: if key in event:
meta[key] = event.get(key) meta[key] = event.get(key)
if "mcp_servers" in event: if "mcp_servers" in event:
@@ -269,60 +270,38 @@ def translate_claude_event(
meta=meta or None, meta=meta or None,
) )
] ]
case "assistant":
if etype == "assistant": message = event["message"]
message = event.get("message")
if not isinstance(message, dict):
return []
message_id = message.get("id") message_id = message.get("id")
if not isinstance(message_id, str):
message_id = None
parent_tool_use_id = event.get("parent_tool_use_id") parent_tool_use_id = event.get("parent_tool_use_id")
if not isinstance(parent_tool_use_id, str): content_blocks = message["content"]
parent_tool_use_id = None
content_blocks = message.get("content")
if not isinstance(content_blocks, list):
return []
out: list[TakopiEvent] = [] out: list[TakopiEvent] = []
for content in content_blocks: for content in content_blocks:
if not isinstance(content, dict): match content["type"]:
continue case "tool_use":
ctype = content.get("type")
if ctype == "tool_use":
action = _tool_action( action = _tool_action(
content, content,
message_id=message_id, message_id=message_id,
parent_tool_use_id=parent_tool_use_id, parent_tool_use_id=parent_tool_use_id,
) )
if action is None:
continue
state.pending_actions[action.id] = action state.pending_actions[action.id] = action
out.append(_action_event(phase="started", action=action)) out.append(_action_event(phase="started", action=action))
elif ctype == "text": case "text":
text = content.get("text") text = content["text"]
if isinstance(text, str) and text: if text:
state.last_assistant_text = text state.last_assistant_text = text
case _:
continue
return out return out
case "user":
if etype == "user": message = event["message"]
message = event.get("message")
if not isinstance(message, dict):
return []
message_id = message.get("id") message_id = message.get("id")
if not isinstance(message_id, str): content_blocks = message["content"]
message_id = None
content_blocks = message.get("content")
if not isinstance(content_blocks, list):
return []
out: list[TakopiEvent] = [] out: list[TakopiEvent] = []
for content in content_blocks: for content in content_blocks:
if not isinstance(content, dict): if content["type"] != "tool_result":
continue
if content.get("type") != "tool_result":
continue
tool_use_id = content.get("tool_use_id")
if not isinstance(tool_use_id, str) or not tool_use_id:
continue continue
tool_use_id = content["tool_use_id"]
action = state.pending_actions.pop(tool_use_id, None) action = state.pending_actions.pop(tool_use_id, None)
if action is None: if action is None:
action = Action( action = Action(
@@ -335,20 +314,17 @@ def translate_claude_event(
_tool_result_event(content, action=action, message_id=message_id) _tool_result_event(content, action=action, message_id=message_id)
) )
return out return out
case "result":
if etype == "result":
out: list[TakopiEvent] = [] out: list[TakopiEvent] = []
for idx, denial in enumerate(event.get("permission_denials") or []): for idx, denial in enumerate(event.get("permission_denials", [])):
if not isinstance(denial, dict):
continue
tool_name = denial.get("tool_name") tool_name = denial.get("tool_name")
denial_title = "permission denied" denial_title = "permission denied"
if isinstance(tool_name, str) and tool_name: if tool_name:
denial_title = f"permission denied: {tool_name}" denial_title = f"permission denied: {tool_name}"
tool_use_id = denial.get("tool_use_id") tool_use_id = denial.get("tool_use_id")
action_id = ( action_id = (
f"claude.permission.{tool_use_id}" f"claude.permission.{tool_use_id}"
if isinstance(tool_use_id, str) and tool_use_id if tool_use_id
else f"claude.permission.{idx}" else f"claude.permission.{idx}"
) )
out.append( out.append(
@@ -366,18 +342,11 @@ def translate_claude_event(
) )
ok = not event.get("is_error", False) ok = not event.get("is_error", False)
result_text = event.get("result") result_text = event["result"]
if not isinstance(result_text, str):
result_text = ""
if ok and not result_text and state.last_assistant_text: if ok and not result_text and state.last_assistant_text:
result_text = state.last_assistant_text result_text = state.last_assistant_text
resume_value = event.get("session_id") resume = ResumeToken(engine=ENGINE, value=str(event["session_id"]))
resume = (
ResumeToken(engine=ENGINE, value=str(resume_value))
if resume_value
else None
)
error = None if ok else _extract_error(event) error = None if ok else _extract_error(event)
usage = _usage_payload(event) usage = _usage_payload(event)
@@ -392,7 +361,7 @@ def translate_claude_event(
) )
) )
return out return out
case _:
return [] return []
+55 -92
View File
@@ -185,27 +185,21 @@ def _todo_title(summary: _TodoSummary) -> str:
def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]: def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]:
item_type = item.get("type") or item.get("item_type") item_type = cast(str, item.get("type") or item.get("item_type"))
if item_type == "assistant_message": if item_type == "assistant_message":
item_type = "agent_message" item_type = "agent_message"
if not item_type:
return []
if item_type == "agent_message": if item_type == "agent_message":
return [] return []
action_id = item.get("id") action_id = str(item["id"])
if not isinstance(action_id, str) or not action_id:
logger.debug("[codex] missing item id in codex event: %r", item)
return []
phase = cast(ActionPhase, etype.split(".")[-1]) phase = cast(ActionPhase, etype.split(".")[-1])
if item_type == "error": if item_type == "error":
if phase != "completed": if phase != "completed":
return [] return []
message = str(item.get("message") or "codex item error") message = str(item["message"])
return [ return [
_action_event( _action_event(
phase="completed", phase="completed",
@@ -224,7 +218,7 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
return [] return []
if kind == "command": if kind == "command":
title = relativize_command(str(item.get("command") or "")) title = relativize_command(str(item["command"]))
if phase in {"started", "updated"}: if phase in {"started", "updated"}:
return [ return [
_action_event( _action_event(
@@ -235,13 +229,13 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
) )
] ]
if phase == "completed": if phase == "completed":
exit_code = item.get("exit_code") exit_code = item["exit_code"]
ok = item.get("status") != "failed" ok = item["status"] != "failed"
if isinstance(exit_code, int): if exit_code is not None:
ok = ok and exit_code == 0 ok = ok and exit_code == 0
detail = { detail = {
"exit_code": exit_code, "exit_code": exit_code,
"status": item.get("status"), "status": item["status"],
} }
return [ return [
_action_event( _action_event(
@@ -255,22 +249,23 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
] ]
if kind == "tool": if kind == "tool":
if item_type == "tool_call":
name = item["name"]
title = str(name) if name else "tool"
detail = {
"name": name,
"status": item["status"],
"arguments": item.get("arguments"),
}
else:
tool_name = _short_tool_name(item) tool_name = _short_tool_name(item)
title = tool_name title = tool_name
detail = { detail = {
"server": item.get("server"), "server": item["server"],
"tool": item.get("tool"), "tool": item["tool"],
"status": item.get("status"), "status": item["status"],
"arguments": item.get("arguments"),
} }
if "arguments" in item:
detail["arguments"] = item.get("arguments")
if item_type == "tool_call":
name = item.get("name")
tool_name = str(name) if name else "tool"
title = tool_name
detail = {"name": name, "status": item.get("status")}
if "arguments" in item:
detail["arguments"] = item.get("arguments")
if phase in {"started", "updated"}: if phase in {"started", "updated"}:
return [ return [
@@ -283,12 +278,10 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
) )
] ]
if phase == "completed": if phase == "completed":
ok = item.get("status") != "failed" and not item.get("error") ok = item["status"] != "failed" and not item["error"]
error = item.get("error") error = item["error"]
if error: if error:
detail["error_message"] = str( detail["error_message"] = str(error.get("message") or error)
error.get("message") if isinstance(error, dict) else error
)
result_summary = _summarize_tool_result(item.get("result")) result_summary = _summarize_tool_result(item.get("result"))
if result_summary is not None: if result_summary is not None:
detail["result_summary"] = result_summary detail["result_summary"] = result_summary
@@ -304,8 +297,8 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
] ]
if kind == "web_search": if kind == "web_search":
title = str(item.get("query") or "") title = str(item["query"])
detail = {"query": item.get("query")} detail = {"query": item["query"]}
if phase in {"started", "updated"}: if phase in {"started", "updated"}:
return [ return [
_action_event( _action_event(
@@ -333,11 +326,11 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
return [] return []
title = _format_change_summary(item) title = _format_change_summary(item)
detail = { detail = {
"changes": item.get("changes") or [], "changes": item["changes"],
"status": item.get("status"), "status": item["status"],
"error": item.get("error"), "error": item["error"],
} }
ok = item.get("status") != "failed" ok = item["status"] != "failed"
return [ return [
_action_event( _action_event(
phase="completed", phase="completed",
@@ -351,11 +344,11 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
if kind == "note": if kind == "note":
if item_type == "todo_list": if item_type == "todo_list":
summary = _summarize_todo_list(item.get("items")) summary = _summarize_todo_list(item["items"])
title = _todo_title(summary) title = _todo_title(summary)
detail = {"done": summary.done, "total": summary.total} detail = {"done": summary.done, "total": summary.total}
else: else:
title = str(item.get("text") or "") title = str(item["text"])
detail = None detail = None
if phase in {"started", "updated"}: if phase in {"started", "updated"}:
@@ -384,19 +377,14 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
def translate_codex_event(event: dict[str, Any], *, title: str) -> list[TakopiEvent]: def translate_codex_event(event: dict[str, Any], *, title: str) -> list[TakopiEvent]:
etype = event.get("type") etype = event["type"]
if etype == "thread.started": match etype:
thread_id = event.get("thread_id") case "thread.started":
if thread_id: token = ResumeToken(engine=ENGINE, value=str(event["thread_id"]))
token = ResumeToken(engine=ENGINE, value=str(thread_id))
return [_started_event(token, title=title)] return [_started_event(token, title=title)]
logger.debug("[codex] codex thread.started missing thread_id: %r", event) case "item.started" | "item.updated" | "item.completed":
return [] return _translate_item_event(etype, event["item"])
case _:
if etype in {"item.started", "item.updated", "item.completed"}:
item = event.get("item") or {}
return _translate_item_event(etype, item)
return [] return []
@@ -460,33 +448,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
def pipes_error_message(self) -> str: def pipes_error_message(self) -> str:
return "codex exec failed to open subprocess pipes" return "codex exec failed to open subprocess pipes"
def handle_started_event(
self,
event: StartedEvent,
*,
expected_session: ResumeToken | None,
found_session: ResumeToken | None,
) -> tuple[ResumeToken | None, bool]:
if event.engine != ENGINE:
raise RuntimeError(
f"codex emitted session token for engine {event.engine!r}"
)
if expected_session is not None and event.resume != expected_session:
message = (
f"codex emitted session id {event.resume.value} "
f"but expected {expected_session.value}"
)
raise RuntimeError(message)
if found_session is None:
return event.resume, True
if event.resume != found_session:
message = (
f"codex emitted session id {event.resume.value} "
f"but expected {found_session.value}"
)
raise RuntimeError(message)
return found_session, False
def translate( def translate(
self, self,
data: dict[str, Any], data: dict[str, Any],
@@ -495,9 +456,10 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
resume: ResumeToken | None, resume: ResumeToken | None,
found_session: ResumeToken | None, found_session: ResumeToken | None,
) -> list[TakopiEvent]: ) -> list[TakopiEvent]:
etype = data.get("type") etype = data["type"]
if etype == "error": match etype:
message = str(data.get("message") or "codex error") case "error":
message = str(data["message"])
fatal_flag = data.get("fatal") fatal_flag = data.get("fatal")
fatal = fatal_flag is True or fatal_flag is None fatal = fatal_flag is True or fatal_flag is None
if fatal: if fatal:
@@ -518,9 +480,9 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
detail={"code": data.get("code"), "fatal": data.get("fatal")}, detail={"code": data.get("code"), "fatal": data.get("fatal")},
) )
] ]
if etype == "turn.failed": case "turn.failed":
error = data.get("error") or {} error = data["error"]
message = str(error.get("message") or "codex turn failed") message = str(error["message"])
resume_for_completed = found_session or resume resume_for_completed = found_session or resume
return [ return [
_completed_event( _completed_event(
@@ -530,13 +492,13 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
error=message, error=message,
) )
] ]
if etype == "turn.rate_limited": case "turn.rate_limited":
retry_ms = data.get("retry_after_ms") retry_ms = data.get("retry_after_ms")
message = "rate limited" message = "rate limited"
if isinstance(retry_ms, int): if isinstance(retry_ms, int):
message = f"rate limited (retry after {retry_ms}ms)" message = f"rate limited (retry after {retry_ms}ms)"
return [self.note_event(message, state=state, ok=False)] return [self.note_event(message, state=state, ok=False)]
if etype == "turn.started": case "turn.started":
action_id = f"turn_{state.turn_index}" action_id = f"turn_{state.turn_index}"
state.turn_index += 1 state.turn_index += 1
return [ return [
@@ -547,7 +509,7 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
title="turn started", title="turn started",
) )
] ]
if etype == "turn.completed": case "turn.completed":
resume_for_completed = found_session or resume resume_for_completed = found_session or resume
return [ return [
_completed_event( _completed_event(
@@ -557,13 +519,12 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
usage=data.get("usage"), usage=data.get("usage"),
) )
] ]
case "item.completed":
if data.get("type") == "item.completed": item = data["item"]
item = data.get("item") or {} item_type = cast(str, item.get("type") or item.get("item_type"))
item_type = item.get("type") or item.get("item_type")
if item_type == "assistant_message": if item_type == "assistant_message":
item_type = "agent_message" item_type = "agent_message"
if item_type == "agent_message" and isinstance(item.get("text"), str): if item_type == "agent_message":
if state.final_answer is None: if state.final_answer is None:
state.final_answer = item["text"] state.final_answer = item["text"]
else: else:
@@ -571,6 +532,8 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
"[codex] emitted multiple agent messages; using the last one" "[codex] emitted multiple agent messages; using the last one"
) )
state.final_answer = item["text"] state.final_answer = item["text"]
case _:
pass
return translate_codex_event(data, title=self.session_title) return translate_codex_event(data, title=self.session_title)
+103
View File
@@ -0,0 +1,103 @@
from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Protocol
import anyio
from .model import ResumeToken
@dataclass(frozen=True, slots=True)
class ThreadJob:
chat_id: int
user_msg_id: int
text: str
resume_token: ResumeToken
RunJob = Callable[[ThreadJob], Awaitable[None]]
class TaskGroup(Protocol):
def start_soon(
self, func: Callable[..., Awaitable[object]], *args: Any
) -> None: ...
class ThreadScheduler:
def __init__(self, *, task_group: TaskGroup, run_job: RunJob) -> None:
self._task_group = task_group
self._run_job = run_job
self._lock = anyio.Lock()
self._pending_by_thread: dict[str, deque[ThreadJob]] = {}
self._active_threads: set[str] = set()
self._busy_until: dict[str, anyio.Event] = {}
@staticmethod
def thread_key(token: ResumeToken) -> str:
return f"{token.engine}:{token.value}"
async def note_thread_known(self, token: ResumeToken, done: anyio.Event) -> None:
key = self.thread_key(token)
async with self._lock:
current = self._busy_until.get(key)
if current is None or current.is_set():
self._busy_until[key] = done
self._task_group.start_soon(self._clear_busy, key, done)
async def enqueue(self, job: ThreadJob) -> None:
key = self.thread_key(job.resume_token)
async with self._lock:
queue = self._pending_by_thread.get(key)
if queue is None:
queue = deque()
self._pending_by_thread[key] = queue
queue.append(job)
if key in self._active_threads:
return
self._active_threads.add(key)
self._task_group.start_soon(self._thread_worker, key)
async def enqueue_resume(
self,
chat_id: int,
user_msg_id: int,
text: str,
resume_token: ResumeToken,
) -> None:
await self.enqueue(
ThreadJob(
chat_id=chat_id,
user_msg_id=user_msg_id,
text=text,
resume_token=resume_token,
)
)
async def _clear_busy(self, key: str, done: anyio.Event) -> None:
await done.wait()
async with self._lock:
if self._busy_until.get(key) is done:
self._busy_until.pop(key, None)
async def _thread_worker(self, key: str) -> None:
try:
while True:
async with self._lock:
done = self._busy_until.get(key)
queue = self._pending_by_thread.get(key)
if not queue:
self._pending_by_thread.pop(key, None)
self._active_threads.discard(key)
return
job = queue.popleft()
if done is not None and not done.is_set():
await done.wait()
await self._run_job(job)
finally:
async with self._lock:
self._active_threads.discard(key)
+4 -4
View File
@@ -13,10 +13,10 @@ def relativize_path(value: str, *, base_dir: Path | None = None) -> str:
return value return value
if value == base_str: if value == base_str:
return "." return "."
if value.startswith(base_str): for sep in (os.sep, "/"):
suffix = value[len(base_str) :] prefix = base_str if base_str.endswith(sep) else f"{base_str}{sep}"
if suffix.startswith((os.sep, "/")): if value.startswith(prefix):
suffix = suffix[1:] suffix = value[len(prefix) :]
return suffix or "." return suffix or "."
return value return value
+3 -3
View File
@@ -318,8 +318,8 @@ def test_render_event_cli_ignores_turn_actions() -> None:
assert render_event_cli(event) == [] assert render_event_cli(event) == []
def test_progress_renderer_ignores_missing_action_id_and_titles() -> None: def test_progress_renderer_ignores_missing_action_id() -> None:
renderer = ExecProgressRenderer(engine="codex", show_title=True) renderer = ExecProgressRenderer(engine="codex")
resume = ResumeToken(engine="codex", value="abc") resume = ResumeToken(engine="codex", value="abc")
renderer.note_event(StartedEvent(engine="codex", resume=resume, title="Session")) renderer.note_event(StartedEvent(engine="codex", resume=resume, title="Session"))
@@ -332,4 +332,4 @@ def test_progress_renderer_ignores_missing_action_id_and_titles() -> None:
assert renderer.note_event(event) is False assert renderer.note_event(event) is False
header = assemble_markdown_parts(renderer.render_progress_parts(0.0)) header = assemble_markdown_parts(renderer.render_progress_parts(0.0))
assert header.startswith("working (Session) · codex · 0s") assert header.startswith("working · codex · 0s")
+15 -1
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from pathlib import Path from pathlib import Path
from takopi.utils.paths import relativize_command from takopi.utils.paths import relativize_command, relativize_path
def test_relativize_command_rewrites_cwd_paths(tmp_path: Path) -> None: def test_relativize_command_rewrites_cwd_paths(tmp_path: Path) -> None:
@@ -19,3 +19,17 @@ def test_relativize_command_rewrites_equals_paths(tmp_path: Path) -> None:
command = f'rg -n --files -g "*.py" --path={base}/src' command = f'rg -n --files -g "*.py" --path={base}/src'
expected = 'rg -n --files -g "*.py" --path=src' expected = 'rg -n --files -g "*.py" --path=src'
assert relativize_command(command, base_dir=base) == expected assert relativize_command(command, base_dir=base) == expected
def test_relativize_path_ignores_sibling_prefix(tmp_path: Path) -> None:
base = tmp_path / "repo"
base.mkdir()
value = str(tmp_path / "repo2" / "file.txt")
assert relativize_path(value, base_dir=base) == value
def test_relativize_path_inside_base(tmp_path: Path) -> None:
base = tmp_path / "repo"
base.mkdir()
value = str(base / "src" / "app.py")
assert relativize_path(value, base_dir=base) == "src/app.py"
Generated
+1 -1
View File
@@ -354,7 +354,7 @@ wheels = [
[[package]] [[package]]
name = "takopi" name = "takopi"
version = "0.4.0" version = "0.5.0.dev0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "anyio" }, { name = "anyio" },