refactor: runners and scheduler, fix path handling (#23)

This commit is contained in:
banteg
2026-01-02 14:22:59 +04:00
committed by GitHub
parent 9c59056666
commit 51cdb72d0b
11 changed files with 377 additions and 424 deletions
+14 -81
View File
@@ -4,7 +4,6 @@ from __future__ import annotations
import logging
import time
from collections import deque
from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
@@ -21,6 +20,7 @@ from .render import (
)
from .router import AutoRouter, RunnerUnavailableError
from .runner import Runner
from .scheduler import ThreadJob, ThreadScheduler
from .telegram import BotClient
@@ -801,35 +801,6 @@ async def run_main_loop(
try:
await _set_command_menu(cfg)
async with anyio.create_task_group() as tg:
scheduler_lock = anyio.Lock()
@dataclass(frozen=True, slots=True)
class ThreadJob:
chat_id: int
user_msg_id: int
text: str
resume_token: ResumeToken
pending_by_thread: dict[str, deque[ThreadJob]] = {}
active_threads: set[str] = set()
busy_until: dict[str, anyio.Event] = {}
def thread_key(token: ResumeToken) -> str:
return f"{token.engine}:{token.value}"
async def clear_busy(key: str, done: anyio.Event) -> None:
await done.wait()
async with scheduler_lock:
if busy_until.get(key) is done:
busy_until.pop(key, None)
async def note_thread_known(token: ResumeToken, done: anyio.Event) -> None:
key = thread_key(token)
async with scheduler_lock:
current = busy_until.get(key)
if current is None or current.is_set():
busy_until[key] = done
tg.start_soon(clear_busy, key, done)
async def run_job(
chat_id: int,
@@ -882,55 +853,15 @@ async def run_main_loop(
except Exception:
logger.exception("[handle] worker failed")
async def thread_worker(key: str) -> None:
try:
while True:
async with scheduler_lock:
done = busy_until.get(key)
queue = pending_by_thread.get(key)
if not queue:
pending_by_thread.pop(key, None)
active_threads.discard(key)
return
job = queue.popleft()
async def run_thread_job(job: ThreadJob) -> None:
await run_job(
job.chat_id,
job.user_msg_id,
job.text,
job.resume_token,
)
if done is not None and not done.is_set():
await done.wait()
await run_job(
job.chat_id,
job.user_msg_id,
job.text,
job.resume_token,
)
finally:
async with scheduler_lock:
active_threads.discard(key)
async def enqueue(
chat_id: int,
user_msg_id: int,
text: str,
resume_token: ResumeToken,
) -> None:
key = thread_key(resume_token)
async with scheduler_lock:
queue = pending_by_thread.get(key)
if queue is None:
queue = deque()
pending_by_thread[key] = queue
queue.append(
ThreadJob(
chat_id=chat_id,
user_msg_id=user_msg_id,
text=text,
resume_token=resume_token,
)
)
if key in active_threads:
return
active_threads.add(key)
tg.start_soon(thread_worker, key)
scheduler = ThreadScheduler(task_group=tg, run_job=run_thread_job)
async for msg in poller(cfg):
text = msg["text"]
@@ -953,7 +884,7 @@ async def run_main_loop(
tg.start_soon(
_send_with_resume,
cfg.bot,
enqueue,
scheduler.enqueue_resume,
running_task,
msg["chat"]["id"],
user_msg_id,
@@ -968,10 +899,12 @@ async def run_main_loop(
user_msg_id,
text,
None,
note_thread_known,
scheduler.note_thread_known,
engine_override,
)
else:
await enqueue(msg["chat"]["id"], user_msg_id, text, resume_token)
await scheduler.enqueue_resume(
msg["chat"]["id"], user_msg_id, text, resume_token
)
finally:
await cfg.bot.close()
+1 -22
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence
from rich.console import Console
from rich.panel import Panel
@@ -110,24 +110,3 @@ def render_setup_guide(result: SetupResult) -> None:
expand=False,
)
console.print(panel)
def render_engine_choice(backends: Sequence[EngineBackend]) -> None:
console = Console(stderr=True)
parts: list[str] = []
parts.append("[bold]available engines:[/]")
parts.append("")
for idx, backend in enumerate(backends, start=1):
parts.append(f"[bold yellow]{idx}.[/] [dim]$[/] takopi {backend.id}")
parts.append(f" [dim]use {backend.id}[/]")
parts.append("")
panel = Panel(
"\n".join(parts).rstrip(),
title="[bold]welcome to takopi![/]",
subtitle=f"{_OCTOPUS} choose engine",
border_style="yellow",
padding=(1, 2),
expand=False,
)
console.print(panel)
+8 -15
View File
@@ -23,6 +23,9 @@ HARD_BREAK = " \n"
MAX_PROGRESS_CMD_LEN = 300
MAX_FILE_CHANGES_INLINE = 3
_MD_RENDERER = MarkdownIt("commonmark", {"html": False})
_BULLET_RE = re.compile(r"(?m)^(\s*)•")
@dataclass(frozen=True)
class MarkdownParts:
@@ -38,11 +41,10 @@ def assemble_markdown_parts(parts: MarkdownParts) -> str:
def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]:
md_renderer = MarkdownIt("commonmark", {"html": False})
html = md_renderer.render(md or "")
html = _MD_RENDERER.render(md or "")
rendered = transform_html(html)
text = re.sub(r"(?m)^(\s*)•", r"\1-", rendered.text)
text = _BULLET_RE.sub(r"\1-", rendered.text)
entities = [dict(e) for e in rendered.entities]
return text, entities
@@ -218,7 +220,6 @@ class ExecProgressRenderer:
max_actions: int = 5,
command_width: int | None = MAX_PROGRESS_CMD_LEN,
resume_formatter: Callable[[ResumeToken], str] | None = None,
show_title: bool = False,
) -> None:
self.max_actions = max(0, int(max_actions))
self.command_width = command_width
@@ -226,16 +227,13 @@ class ExecProgressRenderer:
self.action_count = 0
self.seen_action_ids: set[str] = set()
self.resume_token: ResumeToken | None = None
self.session_title: str | None = None
self._resume_formatter = resume_formatter
self.show_title = show_title
self.engine = engine
def note_event(self, event: TakopiEvent) -> bool:
match event:
case StartedEvent(resume=resume, title=title):
case StartedEvent(resume=resume):
self.resume_token = resume
self.session_title = title
return True
case ActionEvent(action=action, phase=phase, ok=ok):
if action.kind == "turn":
@@ -286,7 +284,7 @@ class ExecProgressRenderer:
header = format_header(
elapsed_s,
step,
label=self.label_with_title(label),
label=label,
engine=self.engine,
)
body = self.assemble_body([line.text for line in self.lines])
@@ -299,18 +297,13 @@ class ExecProgressRenderer:
header = format_header(
elapsed_s,
step,
label=self.label_with_title(status),
label=status,
engine=self.engine,
)
answer = (answer or "").strip()
body = answer if answer else None
return MarkdownParts(header=header, body=body, footer=self.render_footer())
def label_with_title(self, label: str) -> str:
if self.show_title and self.session_title:
return f"{label} ({self.session_title})"
return label
def render_footer(self) -> str | None:
if not self.resume_token or self._resume_formatter is None:
return None
+2 -3
View File
@@ -91,11 +91,10 @@ class SessionLockMixin:
class BaseRunner(SessionLockMixin):
engine: EngineId
async def run(
def run(
self, prompt: str, resume: ResumeToken | None
) -> AsyncIterator[TakopiEvent]:
async for evt in self.run_locked(prompt, resume):
yield evt
return self.run_locked(prompt, resume)
async def run_locked(
self, prompt: str, resume: ResumeToken | None
+115 -146
View File
@@ -142,14 +142,10 @@ def _tool_action(
*,
message_id: str | None,
parent_tool_use_id: str | None,
) -> Action | None:
tool_id = content.get("id")
if not isinstance(tool_id, str) or not tool_id:
return None
) -> Action:
tool_id = content["id"]
tool_name = str(content.get("name") or "tool")
tool_input = content.get("input")
if not isinstance(tool_input, dict):
tool_input = {}
tool_input = content["input"]
kind, title = _tool_kind_and_title(tool_name, tool_input)
@@ -247,153 +243,126 @@ def translate_claude_event(
title: str,
state: ClaudeStreamState,
) -> list[TakopiEvent]:
etype = event.get("type")
if etype == "system" and event.get("subtype") == "init":
session_id = event.get("session_id")
if not session_id:
return []
model = event.get("model")
event_title = str(model) if model else title
meta: dict[str, Any] = {}
for key in ("cwd", "tools", "permissionMode", "output_style", "apiKeySource"):
if key in event:
meta[key] = event.get(key)
if "mcp_servers" in event:
meta["mcp_servers"] = event.get("mcp_servers")
etype = event["type"]
match etype:
case "system" if event.get("subtype") == "init":
session_id = event["session_id"]
model = event.get("model")
event_title = str(model) if model else title
meta: dict[str, Any] = {}
for key in (
"cwd",
"tools",
"permissionMode",
"output_style",
"apiKeySource",
):
if key in event:
meta[key] = event.get(key)
if "mcp_servers" in event:
meta["mcp_servers"] = event.get("mcp_servers")
return [
StartedEvent(
engine=ENGINE,
resume=ResumeToken(engine=ENGINE, value=str(session_id)),
title=event_title,
meta=meta or None,
)
]
if etype == "assistant":
message = event.get("message")
if not isinstance(message, dict):
return []
message_id = message.get("id")
if not isinstance(message_id, str):
message_id = None
parent_tool_use_id = event.get("parent_tool_use_id")
if not isinstance(parent_tool_use_id, str):
parent_tool_use_id = None
content_blocks = message.get("content")
if not isinstance(content_blocks, list):
return []
out: list[TakopiEvent] = []
for content in content_blocks:
if not isinstance(content, dict):
continue
ctype = content.get("type")
if ctype == "tool_use":
action = _tool_action(
content,
message_id=message_id,
parent_tool_use_id=parent_tool_use_id,
return [
StartedEvent(
engine=ENGINE,
resume=ResumeToken(engine=ENGINE, value=str(session_id)),
title=event_title,
meta=meta or None,
)
if action is None:
]
case "assistant":
message = event["message"]
message_id = message.get("id")
parent_tool_use_id = event.get("parent_tool_use_id")
content_blocks = message["content"]
out: list[TakopiEvent] = []
for content in content_blocks:
match content["type"]:
case "tool_use":
action = _tool_action(
content,
message_id=message_id,
parent_tool_use_id=parent_tool_use_id,
)
state.pending_actions[action.id] = action
out.append(_action_event(phase="started", action=action))
case "text":
text = content["text"]
if text:
state.last_assistant_text = text
case _:
continue
return out
case "user":
message = event["message"]
message_id = message.get("id")
content_blocks = message["content"]
out: list[TakopiEvent] = []
for content in content_blocks:
if content["type"] != "tool_result":
continue
state.pending_actions[action.id] = action
out.append(_action_event(phase="started", action=action))
elif ctype == "text":
text = content.get("text")
if isinstance(text, str) and text:
state.last_assistant_text = text
return out
if etype == "user":
message = event.get("message")
if not isinstance(message, dict):
return []
message_id = message.get("id")
if not isinstance(message_id, str):
message_id = None
content_blocks = message.get("content")
if not isinstance(content_blocks, list):
return []
out: list[TakopiEvent] = []
for content in content_blocks:
if not isinstance(content, dict):
continue
if content.get("type") != "tool_result":
continue
tool_use_id = content.get("tool_use_id")
if not isinstance(tool_use_id, str) or not tool_use_id:
continue
action = state.pending_actions.pop(tool_use_id, None)
if action is None:
action = Action(
id=tool_use_id,
kind="tool",
title="tool result",
detail={},
tool_use_id = content["tool_use_id"]
action = state.pending_actions.pop(tool_use_id, None)
if action is None:
action = Action(
id=tool_use_id,
kind="tool",
title="tool result",
detail={},
)
out.append(
_tool_result_event(content, action=action, message_id=message_id)
)
return out
case "result":
out: list[TakopiEvent] = []
for idx, denial in enumerate(event.get("permission_denials", [])):
tool_name = denial.get("tool_name")
denial_title = "permission denied"
if tool_name:
denial_title = f"permission denied: {tool_name}"
tool_use_id = denial.get("tool_use_id")
action_id = (
f"claude.permission.{tool_use_id}"
if tool_use_id
else f"claude.permission.{idx}"
)
out.append(
_action_event(
phase="completed",
action=Action(
id=action_id,
kind="warning",
title=denial_title,
detail=denial,
),
ok=False,
level="warning",
)
)
out.append(
_tool_result_event(content, action=action, message_id=message_id)
)
return out
if etype == "result":
out: list[TakopiEvent] = []
for idx, denial in enumerate(event.get("permission_denials") or []):
if not isinstance(denial, dict):
continue
tool_name = denial.get("tool_name")
denial_title = "permission denied"
if isinstance(tool_name, str) and tool_name:
denial_title = f"permission denied: {tool_name}"
tool_use_id = denial.get("tool_use_id")
action_id = (
f"claude.permission.{tool_use_id}"
if isinstance(tool_use_id, str) and tool_use_id
else f"claude.permission.{idx}"
)
ok = not event.get("is_error", False)
result_text = event["result"]
if ok and not result_text and state.last_assistant_text:
result_text = state.last_assistant_text
resume = ResumeToken(engine=ENGINE, value=str(event["session_id"]))
error = None if ok else _extract_error(event)
usage = _usage_payload(event)
out.append(
_action_event(
phase="completed",
action=Action(
id=action_id,
kind="warning",
title=denial_title,
detail=denial,
),
ok=False,
level="warning",
CompletedEvent(
engine=ENGINE,
ok=ok,
answer=result_text,
resume=resume,
error=error,
usage=usage or None,
)
)
ok = not event.get("is_error", False)
result_text = event.get("result")
if not isinstance(result_text, str):
result_text = ""
if ok and not result_text and state.last_assistant_text:
result_text = state.last_assistant_text
resume_value = event.get("session_id")
resume = (
ResumeToken(engine=ENGINE, value=str(resume_value))
if resume_value
else None
)
error = None if ok else _extract_error(event)
usage = _usage_payload(event)
out.append(
CompletedEvent(
engine=ENGINE,
ok=ok,
answer=result_text,
resume=resume,
error=error,
usage=usage or None,
)
)
return out
return []
return out
case _:
return []
@dataclass
+110 -147
View File
@@ -185,27 +185,21 @@ def _todo_title(summary: _TodoSummary) -> str:
def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]:
item_type = item.get("type") or item.get("item_type")
item_type = cast(str, item.get("type") or item.get("item_type"))
if item_type == "assistant_message":
item_type = "agent_message"
if not item_type:
return []
if item_type == "agent_message":
return []
action_id = item.get("id")
if not isinstance(action_id, str) or not action_id:
logger.debug("[codex] missing item id in codex event: %r", item)
return []
action_id = str(item["id"])
phase = cast(ActionPhase, etype.split(".")[-1])
if item_type == "error":
if phase != "completed":
return []
message = str(item.get("message") or "codex item error")
message = str(item["message"])
return [
_action_event(
phase="completed",
@@ -224,7 +218,7 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
return []
if kind == "command":
title = relativize_command(str(item.get("command") or ""))
title = relativize_command(str(item["command"]))
if phase in {"started", "updated"}:
return [
_action_event(
@@ -235,13 +229,13 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
)
]
if phase == "completed":
exit_code = item.get("exit_code")
ok = item.get("status") != "failed"
if isinstance(exit_code, int):
exit_code = item["exit_code"]
ok = item["status"] != "failed"
if exit_code is not None:
ok = ok and exit_code == 0
detail = {
"exit_code": exit_code,
"status": item.get("status"),
"status": item["status"],
}
return [
_action_event(
@@ -255,22 +249,23 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
]
if kind == "tool":
tool_name = _short_tool_name(item)
title = tool_name
detail = {
"server": item.get("server"),
"tool": item.get("tool"),
"status": item.get("status"),
}
if "arguments" in item:
detail["arguments"] = item.get("arguments")
if item_type == "tool_call":
name = item.get("name")
tool_name = str(name) if name else "tool"
name = item["name"]
title = str(name) if name else "tool"
detail = {
"name": name,
"status": item["status"],
"arguments": item.get("arguments"),
}
else:
tool_name = _short_tool_name(item)
title = tool_name
detail = {"name": name, "status": item.get("status")}
if "arguments" in item:
detail["arguments"] = item.get("arguments")
detail = {
"server": item["server"],
"tool": item["tool"],
"status": item["status"],
"arguments": item.get("arguments"),
}
if phase in {"started", "updated"}:
return [
@@ -283,12 +278,10 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
)
]
if phase == "completed":
ok = item.get("status") != "failed" and not item.get("error")
error = item.get("error")
ok = item["status"] != "failed" and not item["error"]
error = item["error"]
if error:
detail["error_message"] = str(
error.get("message") if isinstance(error, dict) else error
)
detail["error_message"] = str(error.get("message") or error)
result_summary = _summarize_tool_result(item.get("result"))
if result_summary is not None:
detail["result_summary"] = result_summary
@@ -304,8 +297,8 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
]
if kind == "web_search":
title = str(item.get("query") or "")
detail = {"query": item.get("query")}
title = str(item["query"])
detail = {"query": item["query"]}
if phase in {"started", "updated"}:
return [
_action_event(
@@ -333,11 +326,11 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
return []
title = _format_change_summary(item)
detail = {
"changes": item.get("changes") or [],
"status": item.get("status"),
"error": item.get("error"),
"changes": item["changes"],
"status": item["status"],
"error": item["error"],
}
ok = item.get("status") != "failed"
ok = item["status"] != "failed"
return [
_action_event(
phase="completed",
@@ -351,11 +344,11 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
if kind == "note":
if item_type == "todo_list":
summary = _summarize_todo_list(item.get("items"))
summary = _summarize_todo_list(item["items"])
title = _todo_title(summary)
detail = {"done": summary.done, "total": summary.total}
else:
title = str(item.get("text") or "")
title = str(item["text"])
detail = None
if phase in {"started", "updated"}:
@@ -384,20 +377,15 @@ def _translate_item_event(etype: str, item: dict[str, Any]) -> list[TakopiEvent]
def translate_codex_event(event: dict[str, Any], *, title: str) -> list[TakopiEvent]:
etype = event.get("type")
if etype == "thread.started":
thread_id = event.get("thread_id")
if thread_id:
token = ResumeToken(engine=ENGINE, value=str(thread_id))
etype = event["type"]
match etype:
case "thread.started":
token = ResumeToken(engine=ENGINE, value=str(event["thread_id"]))
return [_started_event(token, title=title)]
logger.debug("[codex] codex thread.started missing thread_id: %r", event)
return []
if etype in {"item.started", "item.updated", "item.completed"}:
item = event.get("item") or {}
return _translate_item_event(etype, item)
return []
case "item.started" | "item.updated" | "item.completed":
return _translate_item_event(etype, event["item"])
case _:
return []
@dataclass(slots=True)
@@ -460,33 +448,6 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
def pipes_error_message(self) -> str:
return "codex exec failed to open subprocess pipes"
def handle_started_event(
self,
event: StartedEvent,
*,
expected_session: ResumeToken | None,
found_session: ResumeToken | None,
) -> tuple[ResumeToken | None, bool]:
if event.engine != ENGINE:
raise RuntimeError(
f"codex emitted session token for engine {event.engine!r}"
)
if expected_session is not None and event.resume != expected_session:
message = (
f"codex emitted session id {event.resume.value} "
f"but expected {expected_session.value}"
)
raise RuntimeError(message)
if found_session is None:
return event.resume, True
if event.resume != found_session:
message = (
f"codex emitted session id {event.resume.value} "
f"but expected {found_session.value}"
)
raise RuntimeError(message)
return found_session, False
def translate(
self,
data: dict[str, Any],
@@ -495,12 +456,33 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
resume: ResumeToken | None,
found_session: ResumeToken | None,
) -> list[TakopiEvent]:
etype = data.get("type")
if etype == "error":
message = str(data.get("message") or "codex error")
fatal_flag = data.get("fatal")
fatal = fatal_flag is True or fatal_flag is None
if fatal:
etype = data["type"]
match etype:
case "error":
message = str(data["message"])
fatal_flag = data.get("fatal")
fatal = fatal_flag is True or fatal_flag is None
if fatal:
resume_for_completed = found_session or resume
return [
_completed_event(
resume=resume_for_completed,
ok=False,
answer=state.final_answer or "",
error=message,
)
]
return [
self.note_event(
message,
state=state,
ok=False,
detail={"code": data.get("code"), "fatal": data.get("fatal")},
)
]
case "turn.failed":
error = data["error"]
message = str(error["message"])
resume_for_completed = found_session or resume
return [
_completed_event(
@@ -510,67 +492,48 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner):
error=message,
)
]
return [
self.note_event(
message,
state=state,
ok=False,
detail={"code": data.get("code"), "fatal": data.get("fatal")},
)
]
if etype == "turn.failed":
error = data.get("error") or {}
message = str(error.get("message") or "codex turn failed")
resume_for_completed = found_session or resume
return [
_completed_event(
resume=resume_for_completed,
ok=False,
answer=state.final_answer or "",
error=message,
)
]
if etype == "turn.rate_limited":
retry_ms = data.get("retry_after_ms")
message = "rate limited"
if isinstance(retry_ms, int):
message = f"rate limited (retry after {retry_ms}ms)"
return [self.note_event(message, state=state, ok=False)]
if etype == "turn.started":
action_id = f"turn_{state.turn_index}"
state.turn_index += 1
return [
_action_event(
phase="started",
action_id=action_id,
kind="turn",
title="turn started",
)
]
if etype == "turn.completed":
resume_for_completed = found_session or resume
return [
_completed_event(
resume=resume_for_completed,
ok=True,
answer=state.final_answer or "",
usage=data.get("usage"),
)
]
if data.get("type") == "item.completed":
item = data.get("item") or {}
item_type = item.get("type") or item.get("item_type")
if item_type == "assistant_message":
item_type = "agent_message"
if item_type == "agent_message" and isinstance(item.get("text"), str):
if state.final_answer is None:
state.final_answer = item["text"]
else:
logger.debug(
"[codex] emitted multiple agent messages; using the last one"
case "turn.rate_limited":
retry_ms = data.get("retry_after_ms")
message = "rate limited"
if isinstance(retry_ms, int):
message = f"rate limited (retry after {retry_ms}ms)"
return [self.note_event(message, state=state, ok=False)]
case "turn.started":
action_id = f"turn_{state.turn_index}"
state.turn_index += 1
return [
_action_event(
phase="started",
action_id=action_id,
kind="turn",
title="turn started",
)
state.final_answer = item["text"]
]
case "turn.completed":
resume_for_completed = found_session or resume
return [
_completed_event(
resume=resume_for_completed,
ok=True,
answer=state.final_answer or "",
usage=data.get("usage"),
)
]
case "item.completed":
item = data["item"]
item_type = cast(str, item.get("type") or item.get("item_type"))
if item_type == "assistant_message":
item_type = "agent_message"
if item_type == "agent_message":
if state.final_answer is None:
state.final_answer = item["text"]
else:
logger.debug(
"[codex] emitted multiple agent messages; using the last one"
)
state.final_answer = item["text"]
case _:
pass
return translate_codex_event(data, title=self.session_title)
+103
View File
@@ -0,0 +1,103 @@
from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Protocol
import anyio
from .model import ResumeToken
@dataclass(frozen=True, slots=True)
class ThreadJob:
chat_id: int
user_msg_id: int
text: str
resume_token: ResumeToken
RunJob = Callable[[ThreadJob], Awaitable[None]]
class TaskGroup(Protocol):
def start_soon(
self, func: Callable[..., Awaitable[object]], *args: Any
) -> None: ...
class ThreadScheduler:
def __init__(self, *, task_group: TaskGroup, run_job: RunJob) -> None:
self._task_group = task_group
self._run_job = run_job
self._lock = anyio.Lock()
self._pending_by_thread: dict[str, deque[ThreadJob]] = {}
self._active_threads: set[str] = set()
self._busy_until: dict[str, anyio.Event] = {}
@staticmethod
def thread_key(token: ResumeToken) -> str:
return f"{token.engine}:{token.value}"
async def note_thread_known(self, token: ResumeToken, done: anyio.Event) -> None:
key = self.thread_key(token)
async with self._lock:
current = self._busy_until.get(key)
if current is None or current.is_set():
self._busy_until[key] = done
self._task_group.start_soon(self._clear_busy, key, done)
async def enqueue(self, job: ThreadJob) -> None:
key = self.thread_key(job.resume_token)
async with self._lock:
queue = self._pending_by_thread.get(key)
if queue is None:
queue = deque()
self._pending_by_thread[key] = queue
queue.append(job)
if key in self._active_threads:
return
self._active_threads.add(key)
self._task_group.start_soon(self._thread_worker, key)
async def enqueue_resume(
self,
chat_id: int,
user_msg_id: int,
text: str,
resume_token: ResumeToken,
) -> None:
await self.enqueue(
ThreadJob(
chat_id=chat_id,
user_msg_id=user_msg_id,
text=text,
resume_token=resume_token,
)
)
async def _clear_busy(self, key: str, done: anyio.Event) -> None:
await done.wait()
async with self._lock:
if self._busy_until.get(key) is done:
self._busy_until.pop(key, None)
async def _thread_worker(self, key: str) -> None:
try:
while True:
async with self._lock:
done = self._busy_until.get(key)
queue = self._pending_by_thread.get(key)
if not queue:
self._pending_by_thread.pop(key, None)
self._active_threads.discard(key)
return
job = queue.popleft()
if done is not None and not done.is_set():
await done.wait()
await self._run_job(job)
finally:
async with self._lock:
self._active_threads.discard(key)
+5 -5
View File
@@ -13,11 +13,11 @@ def relativize_path(value: str, *, base_dir: Path | None = None) -> str:
return value
if value == base_str:
return "."
if value.startswith(base_str):
suffix = value[len(base_str) :]
if suffix.startswith((os.sep, "/")):
suffix = suffix[1:]
return suffix or "."
for sep in (os.sep, "/"):
prefix = base_str if base_str.endswith(sep) else f"{base_str}{sep}"
if value.startswith(prefix):
suffix = value[len(prefix) :]
return suffix or "."
return value
+3 -3
View File
@@ -318,8 +318,8 @@ def test_render_event_cli_ignores_turn_actions() -> None:
assert render_event_cli(event) == []
def test_progress_renderer_ignores_missing_action_id_and_titles() -> None:
renderer = ExecProgressRenderer(engine="codex", show_title=True)
def test_progress_renderer_ignores_missing_action_id() -> None:
renderer = ExecProgressRenderer(engine="codex")
resume = ResumeToken(engine="codex", value="abc")
renderer.note_event(StartedEvent(engine="codex", resume=resume, title="Session"))
@@ -332,4 +332,4 @@ def test_progress_renderer_ignores_missing_action_id_and_titles() -> None:
assert renderer.note_event(event) is False
header = assemble_markdown_parts(renderer.render_progress_parts(0.0))
assert header.startswith("working (Session) · codex · 0s")
assert header.startswith("working · codex · 0s")
+15 -1
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from pathlib import Path
from takopi.utils.paths import relativize_command
from takopi.utils.paths import relativize_command, relativize_path
def test_relativize_command_rewrites_cwd_paths(tmp_path: Path) -> None:
@@ -19,3 +19,17 @@ def test_relativize_command_rewrites_equals_paths(tmp_path: Path) -> None:
command = f'rg -n --files -g "*.py" --path={base}/src'
expected = 'rg -n --files -g "*.py" --path=src'
assert relativize_command(command, base_dir=base) == expected
def test_relativize_path_ignores_sibling_prefix(tmp_path: Path) -> None:
base = tmp_path / "repo"
base.mkdir()
value = str(tmp_path / "repo2" / "file.txt")
assert relativize_path(value, base_dir=base) == value
def test_relativize_path_inside_base(tmp_path: Path) -> None:
base = tmp_path / "repo"
base.mkdir()
value = str(base / "src" / "app.py")
assert relativize_path(value, base_dir=base) == "src/app.py"
Generated
+1 -1
View File
@@ -354,7 +354,7 @@ wheels = [
[[package]]
name = "takopi"
version = "0.4.0"
version = "0.5.0.dev0"
source = { editable = "." }
dependencies = [
{ name = "anyio" },