refactor: simplify exec render architecture

This commit is contained in:
banteg
2025-12-29 03:13:13 +04:00
parent 3fa21c8d3c
commit 695f882a22
3 changed files with 91 additions and 144 deletions
@@ -20,7 +20,7 @@ import typer
from .config import load_telegram_config from .config import load_telegram_config
from .constants import TELEGRAM_HARD_LIMIT from .constants import TELEGRAM_HARD_LIMIT
from .exec_render import ExecProgressRenderer, ExecRenderState, render_event_cli from .exec_render import ExecProgressRenderer, render_event_cli
from .rendering import render_markdown from .rendering import render_markdown
from .telegram_client import TelegramClient from .telegram_client import TelegramClient
@@ -224,7 +224,7 @@ class CodexExecRunner:
last_agent_text: Optional[str] = None last_agent_text: Optional[str] = None
saw_agent_message = False saw_agent_message = False
cli_state = ExecRenderState() cli_last_turn = None
for line in proc.stdout: for line in proc.stdout:
line = line.strip() line = line.strip()
@@ -234,7 +234,8 @@ class CodexExecRunner:
evt = json.loads(line) evt = json.loads(line)
except json.JSONDecodeError: except json.JSONDecodeError:
continue continue
for out in render_event_cli(evt, cli_state): cli_last_turn, out_lines = render_event_cli(evt, cli_last_turn)
for out in out_lines:
log(f"[codex] {out}") log(f"[codex] {out}")
if on_event is not None: if on_event is not None:
try: try:
@@ -3,9 +3,8 @@ from __future__ import annotations
import re import re
import textwrap import textwrap
from collections import deque from collections import deque
from dataclasses import dataclass, field
from textwrap import indent from textwrap import indent
from typing import Any, Optional from typing import Any
STATUS_RUNNING = "" STATUS_RUNNING = ""
STATUS_DONE = "" STATUS_DONE = ""
@@ -18,17 +17,6 @@ MAX_PATH_LEN = 40
MAX_PROGRESS_CHARS = 300 MAX_PROGRESS_CHARS = 300
def one_line(text: str) -> str:
return " ".join(text.split())
def shorten_field(text: str, max_len: int) -> str:
return textwrap.shorten(one_line(text), width=max_len, placeholder="")
def truncate(text: str, max_len: int) -> str:
return one_line(text)[:max_len]
def format_elapsed(elapsed_s: float) -> str: def format_elapsed(elapsed_s: float) -> str:
total = max(0, int(elapsed_s)) total = max(0, int(elapsed_s))
minutes, seconds = divmod(total, 60) minutes, seconds = divmod(total, 60)
@@ -40,48 +28,18 @@ def format_elapsed(elapsed_s: float) -> str:
return f"{seconds}s" return f"{seconds}s"
def format_header(elapsed_s: float, turn: Optional[int], label: str) -> str: def format_header(elapsed_s: float, turn: int | None, label: str) -> str:
elapsed = format_elapsed(elapsed_s) elapsed = format_elapsed(elapsed_s)
if turn is not None: if turn is not None:
return f"{label}{HEADER_SEP}{elapsed}{HEADER_SEP}turn {turn}" return f"{label}{HEADER_SEP}{elapsed}{HEADER_SEP}turn {turn}"
return f"{label}{HEADER_SEP}{elapsed}" return f"{label}{HEADER_SEP}{elapsed}"
def format_command(command: str) -> str:
command = shorten_field(command, MAX_CMD_LEN)
return f"`{command}`"
def format_query(query: str) -> str:
return truncate(query, MAX_QUERY_LEN)
def format_paths(paths: list[str]) -> str:
rendered = []
for path in paths:
rendered.append(f"`{shorten_field(path, MAX_PATH_LEN)}`")
return ", ".join(rendered)
def format_file_change(changes: list[dict[str, Any]]) -> str:
paths = [change.get("path") for change in changes if change.get("path")]
if not paths:
total = len(changes)
return "updated files" if total == 0 else f"updated {total} files"
if len(paths) <= 3:
return f"updated {format_paths(paths)}"
return f"updated {len(paths)} files"
def format_tool_call(server: str, tool: str) -> str:
name = ".".join(part for part in (server, tool) if part)
return name or "tool"
def is_command_log_line(line: str) -> bool: def is_command_log_line(line: str) -> bool:
return f"{STATUS_RUNNING} running:" in line or f"{STATUS_DONE} ran:" in line return f"{STATUS_RUNNING} running:" in line or f"{STATUS_DONE} ran:" in line
def extract_numeric_id(item_id: Optional[object], fallback: Optional[int] = None) -> Optional[int]: def extract_numeric_id(item_id: object, fallback: int | None = None) -> int | None:
if isinstance(item_id, int): if isinstance(item_id, int):
return item_id return item_id
if isinstance(item_id, str): if isinstance(item_id, str):
@@ -91,143 +49,130 @@ def extract_numeric_id(item_id: Optional[object], fallback: Optional[int] = None
return fallback return fallback
def attach_id(item_id: Optional[int], line: str) -> str: def _shorten(text: str, width: int) -> str:
return f"[{item_id if item_id is not None else '?'}] {line}" return textwrap.shorten(text, width=width, placeholder="")
def format_reasoning(text: str) -> str:
return text
def format_item_line(etype: str, item: dict[str, Any]) -> str | None: def _shorten_path(path: str, width: int) -> str:
match (item["type"], etype): # Encourage word-boundary truncation for paths (since they may have no spaces).
case ("reasoning", "item.completed"): return _shorten(path.replace("/", " /"), width).replace(" /", "/")
return format_reasoning(item["text"])
case ("command_execution", "item.started"):
command = format_command(item["command"])
return f"{STATUS_RUNNING} running: {command}"
case ("command_execution", "item.completed"):
command = format_command(item["command"])
exit_code = item["exit_code"]
exit_part = f" (exit {exit_code})" if exit_code is not None else ""
return f"{STATUS_DONE} ran: {command}{exit_part}"
case ("mcp_tool_call", "item.started"):
name = format_tool_call(item["server"], item["tool"])
return f"{STATUS_RUNNING} tool: {name}"
case ("mcp_tool_call", "item.completed"):
name = format_tool_call(item["server"], item["tool"])
return f"{STATUS_DONE} tool: {name}"
case ("web_search", "item.completed"):
query = format_query(item["query"])
return f"{STATUS_DONE} searched: {query}"
case ("file_change", "item.completed"):
return f"{STATUS_DONE} {format_file_change(item['changes'])}"
case ("error", "item.completed"):
warning = truncate(item["message"], 120)
return f"{STATUS_DONE} warning: {warning}"
case _:
return None
@dataclass def render_event_cli(event: dict[str, Any], last_turn: int | None = None) -> tuple[int | None, list[str]]:
class ExecRenderState:
recent_actions: deque[str] = field(default_factory=lambda: deque(maxlen=5))
last_turn: Optional[int] = None
def record_item(state: ExecRenderState, item: dict[str, Any]) -> None:
numeric_id = extract_numeric_id(item["id"])
if numeric_id is not None:
state.last_turn = numeric_id
def render_event_cli(
event: dict[str, Any],
state: ExecRenderState,
) -> list[str]:
lines: list[str] = [] lines: list[str] = []
etype = event["type"] match event["type"]:
match etype:
case "thread.started": case "thread.started":
return ["thread started"] return last_turn, ["thread started"]
case "turn.started": case "turn.started":
return ["turn started"] return last_turn, ["turn started"]
case "turn.completed": case "turn.completed":
return ["turn completed"] return last_turn, ["turn completed"]
case "turn.failed": case "turn.failed":
return [f"turn failed: {event['error']['message']}"] return last_turn, [f"turn failed: {event['error']['message']}"]
case "error": case "error":
return [f"stream error: {event['message']}"] return last_turn, [f"stream error: {event['message']}"]
case "item.started" | "item.updated" | "item.completed": case "item.started" | "item.updated" | "item.completed" as etype:
item = event["item"] item = event["item"]
record_item(state, item) item_num = extract_numeric_id(item["id"], last_turn)
last_turn = item_num if item_num is not None else last_turn
prefix = f"[{item_num if item_num is not None else '?'}] "
item_num = extract_numeric_id(item["id"], state.last_turn)
match (item["type"], etype): match (item["type"], etype):
case ("agent_message", "item.completed"): case ("agent_message", "item.completed"):
lines.append("assistant:") lines.append("assistant:")
lines.extend(indent(item["text"], " ").splitlines()) lines.extend(indent(item["text"], " ").splitlines())
case ("reasoning", "item.completed"):
lines.append(prefix + item["text"])
case ("command_execution", "item.started"):
command = f"`{_shorten(item['command'], MAX_CMD_LEN)}`"
lines.append(prefix + f"{STATUS_RUNNING} running: {command}")
case ("command_execution", "item.completed"):
command = f"`{_shorten(item['command'], MAX_CMD_LEN)}`"
exit_code = item["exit_code"]
exit_part = f" (exit {exit_code})" if exit_code is not None else ""
lines.append(prefix + f"{STATUS_DONE} ran: {command}{exit_part}")
case ("mcp_tool_call", "item.started"):
name = ".".join(part for part in (item["server"], item["tool"]) if part) or "tool"
lines.append(prefix + f"{STATUS_RUNNING} tool: {name}")
case ("mcp_tool_call", "item.completed"):
name = ".".join(part for part in (item["server"], item["tool"]) if part) or "tool"
lines.append(prefix + f"{STATUS_DONE} tool: {name}")
case ("web_search", "item.completed"):
query = _shorten(item["query"], MAX_QUERY_LEN)
lines.append(prefix + f"{STATUS_DONE} searched: {query}")
case ("file_change", "item.completed"):
paths = [change["path"] for change in item["changes"] if change.get("path")]
if not paths:
total = len(item["changes"])
desc = "updated files" if total == 0 else f"updated {total} files"
elif len(paths) <= 3:
desc = "updated " + ", ".join(f"`{_shorten_path(p, MAX_PATH_LEN)}`" for p in paths)
else:
desc = f"updated {len(paths)} files"
lines.append(prefix + f"{STATUS_DONE} {desc}")
case ("error", "item.completed"):
warning = _shorten(item["message"], 120)
lines.append(prefix + f"{STATUS_DONE} warning: {warning}")
case _: case _:
line = format_item_line(etype, item) pass
if line is not None: return last_turn, lines
lines.append(attach_id(item_num, line))
return lines
case _: case _:
return lines return last_turn, lines
class ExecProgressRenderer: class ExecProgressRenderer:
def __init__(self, max_actions: int = 5, max_chars: int = MAX_PROGRESS_CHARS) -> None: def __init__(self, max_actions: int = 5, max_chars: int = MAX_PROGRESS_CHARS) -> None:
self.max_actions = max_actions self.max_actions = max_actions
self.state = ExecRenderState(recent_actions=deque(maxlen=max_actions))
self.max_chars = max_chars self.max_chars = max_chars
self.recent_actions: deque[str] = deque(maxlen=max_actions)
self.last_turn: int | None = None
def note_event(self, event: dict[str, Any]) -> bool: def note_event(self, event: dict[str, Any]) -> bool:
etype = event["type"] match event["type"]:
match etype:
case "thread.started" | "turn.started": case "thread.started" | "turn.started":
return True return True
case "item.started" | "item.updated" | "item.completed": case "item.started" | "item.updated" | "item.completed" as etype:
item = event["item"] item = event["item"]
record_item(self.state, item) item_id = extract_numeric_id(item["id"], self.last_turn)
item_id = extract_numeric_id(item["id"], self.state.last_turn) self.last_turn = item_id if item_id is not None else self.last_turn
prefix = f"[{item_id if item_id is not None else '?'}] "
match item["type"]: match item["type"]:
case "agent_message": case "agent_message":
return False return False
case _: case _:
line = format_item_line(etype, item) _, lines = render_event_cli(event, self.last_turn)
if line is not None: if not lines:
full = attach_id(item_id, line) return False
if etype == "item.completed" and self.state.recent_actions: line = lines[0]
last = self.state.recent_actions[-1]
if last.startswith(f"[{item_id}] {STATUS_RUNNING} "): # Replace the preceding "running" line for the same item on completion.
self.state.recent_actions.pop() if etype == "item.completed" and self.recent_actions:
self.state.recent_actions.append(full) last = self.recent_actions[-1]
return True if last.startswith(prefix + f"{STATUS_RUNNING} "):
return False self.recent_actions.pop()
self.recent_actions.append(line)
return True
case _: case _:
return False return False
def render_progress(self, elapsed_s: float) -> str: def render_progress(self, elapsed_s: float) -> str:
header = format_header(elapsed_s, self.state.last_turn, label="working") header = format_header(elapsed_s, self.last_turn, label="working")
message = self._assemble(header, list(self.state.recent_actions)) message = self._assemble(header, list(self.recent_actions))
if len(message) <= self.max_chars: return message if len(message) <= self.max_chars else header
return message
return header
def render_final(self, elapsed_s: float, answer: str, status: str = "done") -> str: def render_final(self, elapsed_s: float, answer: str, status: str = "done") -> str:
header = format_header(elapsed_s, self.state.last_turn, label=status) header = format_header(elapsed_s, self.last_turn, label=status)
lines = list(self.state.recent_actions) lines = list(self.recent_actions)
if status == "done": if status == "done":
lines = [line for line in lines if not is_command_log_line(line)] lines = [line for line in lines if not is_command_log_line(line)]
body = self._assemble(header, lines) body = self._assemble(header, lines)
answer = (answer or "").strip() answer = (answer or "").strip()
if answer: return body + ("\n\n" + answer if answer else "")
body = body + "\n\n" + answer
return body
@staticmethod @staticmethod
def _assemble(header: str, lines: list[str]) -> str: def _assemble(header: str, lines: list[str]) -> str:
if not lines: return header if not lines else header + "\n\n" + HARD_BREAK.join(lines)
return header
return header + "\n\n" + HARD_BREAK.join(lines)
@@ -1,6 +1,6 @@
import json import json
from codex_telegram_bridge.exec_render import ExecProgressRenderer, ExecRenderState, render_event_cli from codex_telegram_bridge.exec_render import ExecProgressRenderer, render_event_cli
def _loads(lines: str) -> list[dict]: def _loads(lines: str) -> list[dict]:
@@ -20,10 +20,11 @@ SAMPLE_STREAM = """
def test_render_event_cli_sample_stream() -> None: def test_render_event_cli_sample_stream() -> None:
state = ExecRenderState() last_turn = None
out: list[str] = [] out: list[str] = []
for evt in _loads(SAMPLE_STREAM): for evt in _loads(SAMPLE_STREAM):
out.extend(render_event_cli(evt, state)) last_turn, lines = render_event_cli(evt, last_turn)
out.extend(lines)
assert out == [ assert out == [
"thread started", "thread started",