refactor: use pattern matching in exec render

This commit is contained in:
banteg
2025-12-29 03:00:14 +04:00
parent deb040b219
commit e4bcfe7f88
@@ -93,38 +93,36 @@ def attach_id(item_id: Optional[int], line: str) -> str:
def format_item_action_line(etype: str, item: dict[str, Any]) -> str | None: def format_item_action_line(etype: str, item: dict[str, Any]) -> str | None:
itype = item["type"] match (item["type"], etype):
if itype == "command_execution": case ("command_execution", "item.started"):
command = format_command(item["command"]) command = format_command(item["command"])
if etype == "item.started":
return f"{STATUS_RUNNING} running: {command}" return f"{STATUS_RUNNING} running: {command}"
if etype == "item.completed": case ("command_execution", "item.completed"):
command = format_command(item["command"])
exit_code = item["exit_code"] exit_code = item["exit_code"]
exit_part = f" (exit {exit_code})" if exit_code is not None else "" exit_part = f" (exit {exit_code})" if exit_code is not None else ""
return f"{STATUS_DONE} ran: {command}{exit_part}" return f"{STATUS_DONE} ran: {command}{exit_part}"
return None case ("mcp_tool_call", "item.started"):
if itype == "mcp_tool_call":
name = format_tool_call(item["server"], item["tool"]) name = format_tool_call(item["server"], item["tool"])
if etype == "item.started":
return f"{STATUS_RUNNING} tool: {name}" return f"{STATUS_RUNNING} tool: {name}"
if etype == "item.completed": case ("mcp_tool_call", "item.completed"):
name = format_tool_call(item["server"], item["tool"])
return f"{STATUS_DONE} tool: {name}" return f"{STATUS_DONE} tool: {name}"
return None case _:
return None return None
def format_item_completed_line(item: dict[str, Any]) -> str | None: def format_item_completed_line(item: dict[str, Any]) -> str | None:
itype = item["type"] match item["type"]:
if itype == "web_search": case "web_search":
query = format_query(item["query"]) query = format_query(item["query"])
return f"{STATUS_DONE} searched: {query}" return f"{STATUS_DONE} searched: {query}"
if itype == "file_change": case "file_change":
return f"{STATUS_DONE} {format_file_change(item['changes'])}" return f"{STATUS_DONE} {format_file_change(item['changes'])}"
if itype == "error": case "error":
warning = truncate(item["message"], 120) warning = truncate(item["message"], 120)
return f"{STATUS_DONE} warning: {warning}" return f"{STATUS_DONE} warning: {warning}"
case _:
return None return None
@@ -144,37 +142,30 @@ def render_event_cli(
event: dict[str, Any], event: dict[str, Any],
state: ExecRenderState, state: ExecRenderState,
) -> list[str]: ) -> list[str]:
etype = event["type"]
lines: list[str] = [] lines: list[str] = []
if etype == "thread.started": etype = event["type"]
match etype:
case "thread.started":
return ["thread started"] return ["thread started"]
case "turn.started":
if etype == "turn.started":
return ["turn started"] return ["turn started"]
case "turn.completed":
if etype == "turn.completed":
return ["turn completed"] return ["turn completed"]
case "turn.failed":
if etype == "turn.failed": return [f"turn failed: {event['error']['message']}"]
error = event["error"]["message"] case "error":
return [f"turn failed: {error}"]
if etype == "error":
return [f"stream error: {event['message']}"] return [f"stream error: {event['message']}"]
case "item.started" | "item.updated" | "item.completed":
if etype in {"item.started", "item.updated", "item.completed"}:
item = event["item"] item = event["item"]
record_item(state, item) record_item(state, item)
itype = item["type"]
item_num = extract_numeric_id(item["id"], state.last_turn) item_num = extract_numeric_id(item["id"], state.last_turn)
match (item["type"], etype):
if itype == "agent_message" and etype == "item.completed": case ("agent_message", "item.completed"):
lines.append("assistant:") lines.append("assistant:")
lines.extend(indent(item["text"], " ").splitlines()) lines.extend(indent(item["text"], " ").splitlines())
case _:
else:
action_line = format_item_action_line(etype, item) action_line = format_item_action_line(etype, item)
if action_line is not None: if action_line is not None:
lines.append(attach_id(item_num, action_line)) lines.append(attach_id(item_num, action_line))
@@ -182,7 +173,8 @@ def render_event_cli(
completed_line = format_item_completed_line(item) completed_line = format_item_completed_line(item)
if completed_line is not None: if completed_line is not None:
lines.append(attach_id(item_num, completed_line)) lines.append(attach_id(item_num, completed_line))
return lines
case _:
return lines return lines
@@ -194,30 +186,28 @@ class ExecProgressRenderer:
def note_event(self, event: dict[str, Any]) -> bool: def note_event(self, event: dict[str, Any]) -> bool:
etype = event["type"] etype = event["type"]
match etype:
if etype in {"thread.started", "turn.started"}: case "thread.started" | "turn.started":
return True return True
case "item.started" | "item.updated" | "item.completed":
if etype in {"item.started", "item.updated", "item.completed"}:
item = event["item"] item = event["item"]
record_item(self.state, item) record_item(self.state, item)
itype = item["type"]
item_id = extract_numeric_id(item["id"], self.state.last_turn) item_id = extract_numeric_id(item["id"], self.state.last_turn)
match item["type"]:
if itype == "agent_message": case "agent_message":
return False return False
case _:
action_line = format_item_action_line(etype, item) action_line = format_item_action_line(etype, item)
if action_line is not None: if action_line is not None:
self.state.recent_actions.append(attach_id(item_id, action_line)) self.state.recent_actions.append(attach_id(item_id, action_line))
return True return True
if etype == "item.completed": if etype == "item.completed":
completed_line = format_item_completed_line(item) completed_line = format_item_completed_line(item)
if completed_line is not None: if completed_line is not None:
self.state.recent_actions.append(attach_id(item_id, completed_line)) self.state.recent_actions.append(attach_id(item_id, completed_line))
return True return True
return False
case _:
return False return False
def render_progress(self, elapsed_s: float) -> str: def render_progress(self, elapsed_s: float) -> str: