refactor: simplify exec render architecture

This commit is contained in:
banteg
2025-12-29 03:13:13 +04:00
parent 3fa21c8d3c
commit 695f882a22
3 changed files with 91 additions and 144 deletions
@@ -20,7 +20,7 @@ import typer
from .config import load_telegram_config
from .constants import TELEGRAM_HARD_LIMIT
from .exec_render import ExecProgressRenderer, ExecRenderState, render_event_cli
from .exec_render import ExecProgressRenderer, render_event_cli
from .rendering import render_markdown
from .telegram_client import TelegramClient
@@ -224,7 +224,7 @@ class CodexExecRunner:
last_agent_text: Optional[str] = None
saw_agent_message = False
cli_state = ExecRenderState()
cli_last_turn = None
for line in proc.stdout:
line = line.strip()
@@ -234,7 +234,8 @@ class CodexExecRunner:
evt = json.loads(line)
except json.JSONDecodeError:
continue
for out in render_event_cli(evt, cli_state):
cli_last_turn, out_lines = render_event_cli(evt, cli_last_turn)
for out in out_lines:
log(f"[codex] {out}")
if on_event is not None:
try:
@@ -3,9 +3,8 @@ from __future__ import annotations
import re
import textwrap
from collections import deque
from dataclasses import dataclass, field
from textwrap import indent
from typing import Any, Optional
from typing import Any
STATUS_RUNNING = ""
STATUS_DONE = ""
@@ -18,17 +17,6 @@ MAX_PATH_LEN = 40
MAX_PROGRESS_CHARS = 300
def one_line(text: str) -> str:
return " ".join(text.split())
def shorten_field(text: str, max_len: int) -> str:
return textwrap.shorten(one_line(text), width=max_len, placeholder="")
def truncate(text: str, max_len: int) -> str:
return one_line(text)[:max_len]
def format_elapsed(elapsed_s: float) -> str:
total = max(0, int(elapsed_s))
minutes, seconds = divmod(total, 60)
@@ -40,48 +28,18 @@ def format_elapsed(elapsed_s: float) -> str:
return f"{seconds}s"
def format_header(elapsed_s: float, turn: Optional[int], label: str) -> str:
def format_header(elapsed_s: float, turn: int | None, label: str) -> str:
elapsed = format_elapsed(elapsed_s)
if turn is not None:
return f"{label}{HEADER_SEP}{elapsed}{HEADER_SEP}turn {turn}"
return f"{label}{HEADER_SEP}{elapsed}"
def format_command(command: str) -> str:
command = shorten_field(command, MAX_CMD_LEN)
return f"`{command}`"
def format_query(query: str) -> str:
return truncate(query, MAX_QUERY_LEN)
def format_paths(paths: list[str]) -> str:
rendered = []
for path in paths:
rendered.append(f"`{shorten_field(path, MAX_PATH_LEN)}`")
return ", ".join(rendered)
def format_file_change(changes: list[dict[str, Any]]) -> str:
paths = [change.get("path") for change in changes if change.get("path")]
if not paths:
total = len(changes)
return "updated files" if total == 0 else f"updated {total} files"
if len(paths) <= 3:
return f"updated {format_paths(paths)}"
return f"updated {len(paths)} files"
def format_tool_call(server: str, tool: str) -> str:
name = ".".join(part for part in (server, tool) if part)
return name or "tool"
def is_command_log_line(line: str) -> bool:
return f"{STATUS_RUNNING} running:" in line or f"{STATUS_DONE} ran:" in line
def extract_numeric_id(item_id: Optional[object], fallback: Optional[int] = None) -> Optional[int]:
def extract_numeric_id(item_id: object, fallback: int | None = None) -> int | None:
if isinstance(item_id, int):
return item_id
if isinstance(item_id, str):
@@ -91,143 +49,130 @@ def extract_numeric_id(item_id: Optional[object], fallback: Optional[int] = None
return fallback
def attach_id(item_id: Optional[int], line: str) -> str:
return f"[{item_id if item_id is not None else '?'}] {line}"
def format_reasoning(text: str) -> str:
return text
def _shorten(text: str, width: int) -> str:
return textwrap.shorten(text, width=width, placeholder="")
def format_item_line(etype: str, item: dict[str, Any]) -> str | None:
match (item["type"], etype):
case ("reasoning", "item.completed"):
return format_reasoning(item["text"])
case ("command_execution", "item.started"):
command = format_command(item["command"])
return f"{STATUS_RUNNING} running: {command}"
case ("command_execution", "item.completed"):
command = format_command(item["command"])
exit_code = item["exit_code"]
exit_part = f" (exit {exit_code})" if exit_code is not None else ""
return f"{STATUS_DONE} ran: {command}{exit_part}"
case ("mcp_tool_call", "item.started"):
name = format_tool_call(item["server"], item["tool"])
return f"{STATUS_RUNNING} tool: {name}"
case ("mcp_tool_call", "item.completed"):
name = format_tool_call(item["server"], item["tool"])
return f"{STATUS_DONE} tool: {name}"
case ("web_search", "item.completed"):
query = format_query(item["query"])
return f"{STATUS_DONE} searched: {query}"
case ("file_change", "item.completed"):
return f"{STATUS_DONE} {format_file_change(item['changes'])}"
case ("error", "item.completed"):
warning = truncate(item["message"], 120)
return f"{STATUS_DONE} warning: {warning}"
case _:
return None
def _shorten_path(path: str, width: int) -> str:
# Encourage word-boundary truncation for paths (since they may have no spaces).
return _shorten(path.replace("/", " /"), width).replace(" /", "/")
@dataclass
class ExecRenderState:
recent_actions: deque[str] = field(default_factory=lambda: deque(maxlen=5))
last_turn: Optional[int] = None
def record_item(state: ExecRenderState, item: dict[str, Any]) -> None:
numeric_id = extract_numeric_id(item["id"])
if numeric_id is not None:
state.last_turn = numeric_id
def render_event_cli(
event: dict[str, Any],
state: ExecRenderState,
) -> list[str]:
def render_event_cli(event: dict[str, Any], last_turn: int | None = None) -> tuple[int | None, list[str]]:
lines: list[str] = []
etype = event["type"]
match etype:
match event["type"]:
case "thread.started":
return ["thread started"]
return last_turn, ["thread started"]
case "turn.started":
return ["turn started"]
return last_turn, ["turn started"]
case "turn.completed":
return ["turn completed"]
return last_turn, ["turn completed"]
case "turn.failed":
return [f"turn failed: {event['error']['message']}"]
return last_turn, [f"turn failed: {event['error']['message']}"]
case "error":
return [f"stream error: {event['message']}"]
case "item.started" | "item.updated" | "item.completed":
return last_turn, [f"stream error: {event['message']}"]
case "item.started" | "item.updated" | "item.completed" as etype:
item = event["item"]
record_item(state, item)
item_num = extract_numeric_id(item["id"], last_turn)
last_turn = item_num if item_num is not None else last_turn
prefix = f"[{item_num if item_num is not None else '?'}] "
item_num = extract_numeric_id(item["id"], state.last_turn)
match (item["type"], etype):
case ("agent_message", "item.completed"):
lines.append("assistant:")
lines.extend(indent(item["text"], " ").splitlines())
case ("reasoning", "item.completed"):
lines.append(prefix + item["text"])
case ("command_execution", "item.started"):
command = f"`{_shorten(item['command'], MAX_CMD_LEN)}`"
lines.append(prefix + f"{STATUS_RUNNING} running: {command}")
case ("command_execution", "item.completed"):
command = f"`{_shorten(item['command'], MAX_CMD_LEN)}`"
exit_code = item["exit_code"]
exit_part = f" (exit {exit_code})" if exit_code is not None else ""
lines.append(prefix + f"{STATUS_DONE} ran: {command}{exit_part}")
case ("mcp_tool_call", "item.started"):
name = ".".join(part for part in (item["server"], item["tool"]) if part) or "tool"
lines.append(prefix + f"{STATUS_RUNNING} tool: {name}")
case ("mcp_tool_call", "item.completed"):
name = ".".join(part for part in (item["server"], item["tool"]) if part) or "tool"
lines.append(prefix + f"{STATUS_DONE} tool: {name}")
case ("web_search", "item.completed"):
query = _shorten(item["query"], MAX_QUERY_LEN)
lines.append(prefix + f"{STATUS_DONE} searched: {query}")
case ("file_change", "item.completed"):
paths = [change["path"] for change in item["changes"] if change.get("path")]
if not paths:
total = len(item["changes"])
desc = "updated files" if total == 0 else f"updated {total} files"
elif len(paths) <= 3:
desc = "updated " + ", ".join(f"`{_shorten_path(p, MAX_PATH_LEN)}`" for p in paths)
else:
desc = f"updated {len(paths)} files"
lines.append(prefix + f"{STATUS_DONE} {desc}")
case ("error", "item.completed"):
warning = _shorten(item["message"], 120)
lines.append(prefix + f"{STATUS_DONE} warning: {warning}")
case _:
line = format_item_line(etype, item)
if line is not None:
lines.append(attach_id(item_num, line))
return lines
pass
return last_turn, lines
case _:
return lines
return last_turn, lines
class ExecProgressRenderer:
def __init__(self, max_actions: int = 5, max_chars: int = MAX_PROGRESS_CHARS) -> None:
self.max_actions = max_actions
self.state = ExecRenderState(recent_actions=deque(maxlen=max_actions))
self.max_chars = max_chars
self.recent_actions: deque[str] = deque(maxlen=max_actions)
self.last_turn: int | None = None
def note_event(self, event: dict[str, Any]) -> bool:
etype = event["type"]
match etype:
match event["type"]:
case "thread.started" | "turn.started":
return True
case "item.started" | "item.updated" | "item.completed":
case "item.started" | "item.updated" | "item.completed" as etype:
item = event["item"]
record_item(self.state, item)
item_id = extract_numeric_id(item["id"], self.state.last_turn)
item_id = extract_numeric_id(item["id"], self.last_turn)
self.last_turn = item_id if item_id is not None else self.last_turn
prefix = f"[{item_id if item_id is not None else '?'}] "
match item["type"]:
case "agent_message":
return False
case _:
line = format_item_line(etype, item)
if line is not None:
full = attach_id(item_id, line)
if etype == "item.completed" and self.state.recent_actions:
last = self.state.recent_actions[-1]
if last.startswith(f"[{item_id}] {STATUS_RUNNING} "):
self.state.recent_actions.pop()
self.state.recent_actions.append(full)
return True
return False
_, lines = render_event_cli(event, self.last_turn)
if not lines:
return False
line = lines[0]
# Replace the preceding "running" line for the same item on completion.
if etype == "item.completed" and self.recent_actions:
last = self.recent_actions[-1]
if last.startswith(prefix + f"{STATUS_RUNNING} "):
self.recent_actions.pop()
self.recent_actions.append(line)
return True
case _:
return False
def render_progress(self, elapsed_s: float) -> str:
header = format_header(elapsed_s, self.state.last_turn, label="working")
message = self._assemble(header, list(self.state.recent_actions))
if len(message) <= self.max_chars:
return message
return header
header = format_header(elapsed_s, self.last_turn, label="working")
message = self._assemble(header, list(self.recent_actions))
return message if len(message) <= self.max_chars else header
def render_final(self, elapsed_s: float, answer: str, status: str = "done") -> str:
header = format_header(elapsed_s, self.state.last_turn, label=status)
lines = list(self.state.recent_actions)
header = format_header(elapsed_s, self.last_turn, label=status)
lines = list(self.recent_actions)
if status == "done":
lines = [line for line in lines if not is_command_log_line(line)]
body = self._assemble(header, lines)
answer = (answer or "").strip()
if answer:
body = body + "\n\n" + answer
return body
return body + ("\n\n" + answer if answer else "")
@staticmethod
def _assemble(header: str, lines: list[str]) -> str:
if not lines:
return header
return header + "\n\n" + HARD_BREAK.join(lines)
return header if not lines else header + "\n\n" + HARD_BREAK.join(lines)
@@ -1,6 +1,6 @@
import json
from codex_telegram_bridge.exec_render import ExecProgressRenderer, ExecRenderState, render_event_cli
from codex_telegram_bridge.exec_render import ExecProgressRenderer, render_event_cli
def _loads(lines: str) -> list[dict]:
@@ -20,10 +20,11 @@ SAMPLE_STREAM = """
def test_render_event_cli_sample_stream() -> None:
state = ExecRenderState()
last_turn = None
out: list[str] = []
for evt in _loads(SAMPLE_STREAM):
out.extend(render_event_cli(evt, state))
last_turn, lines = render_event_cli(evt, last_turn)
out.extend(lines)
assert out == [
"thread started",