diff --git a/src/takopi/runners/codex.py b/src/takopi/runners/codex.py index 32774d4..4527d42 100644 --- a/src/takopi/runners/codex.py +++ b/src/takopi/runners/codex.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -192,10 +192,53 @@ def _todo_title(summary: _TodoSummary) -> str: return f"todo {summary.done}/{summary.total}: done" +@dataclass(frozen=True, slots=True) +class _AgentMessageSummary: + text: str + phase: str | None + + +def _select_final_answer(agent_messages: list[_AgentMessageSummary]) -> str | None: + for message in reversed(agent_messages): + if message.phase == "final_answer": + return message.text + for message in reversed(agent_messages): + if message.phase in {None, ""}: + return message.text + return None + + def _translate_item_event( phase: ActionPhase, item: codex_schema.ThreadItem, *, factory: EventFactory ) -> list[TakopiEvent]: match item: + case codex_schema.AgentMessageItem( + id=action_id, + text=text, + phase="commentary", + ): + detail = {"phase": "commentary"} + if phase in {"started", "updated"}: + return [ + factory.action( + phase=phase, + action_id=action_id, + kind="note", + title=text, + detail=detail, + ) + ] + if phase == "completed": + return [ + factory.action_completed( + action_id=action_id, + kind="note", + title=text, + detail=detail, + ok=True, + ) + ] + return [] case codex_schema.AgentMessageItem(): return [] case codex_schema.ErrorItem(id=action_id, message=message): @@ -398,6 +441,7 @@ class CodexRunState: factory: EventFactory note_seq: int = 0 final_answer: str | None = None + turn_agent_messages: list[_AgentMessageSummary] = field(default_factory=list) turn_index: int = 0 @@ -532,6 +576,8 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner): case codex_schema.TurnStarted(): action_id = f"turn_{state.turn_index}" state.turn_index += 1 + state.final_answer = None + state.turn_agent_messages.clear() return [ factory.action_started( action_id=action_id, @@ -549,13 +595,16 @@ class CodexRunner(ResumeTokenMixin, JsonlSubprocessRunner): ) ] case codex_schema.ItemCompleted( - item=codex_schema.AgentMessageItem(text=text) + item=codex_schema.AgentMessageItem(text=text, phase=message_phase) ): - if state.final_answer is None: - state.final_answer = text - else: + state.turn_agent_messages.append( + _AgentMessageSummary(text=text, phase=message_phase) + ) + selected = _select_final_answer(state.turn_agent_messages) + if selected is not None: + state.final_answer = selected + if len(state.turn_agent_messages) > 1: logger.debug("codex.multiple_agent_messages") - state.final_answer = text case _: pass diff --git a/src/takopi/schemas/codex.py b/src/takopi/schemas/codex.py index e3f10ce..00a9b08 100644 --- a/src/takopi/schemas/codex.py +++ b/src/takopi/schemas/codex.py @@ -27,6 +27,11 @@ type McpToolCallStatus = Literal[ "completed", "failed", ] +type CollabToolCallStatus = Literal[ + "in_progress", + "completed", + "failed", +] class Usage(msgspec.Struct, kw_only=True): @@ -62,6 +67,7 @@ class StreamError(msgspec.Struct, tag="error", kw_only=True): class AgentMessageItem(msgspec.Struct, tag="agent_message", kw_only=True): id: str text: str + phase: str | None = None class ReasoningItem(msgspec.Struct, tag="reasoning", kw_only=True): @@ -107,6 +113,21 @@ class McpToolCallItem(msgspec.Struct, tag="mcp_tool_call", kw_only=True): status: McpToolCallStatus +class CollabAgentState(msgspec.Struct, kw_only=True): + status: str + message: str | None = None + + +class CollabToolCallItem(msgspec.Struct, tag="collab_tool_call", kw_only=True): + id: str + tool: str | None = None + sender_thread_id: str | None = None + receiver_thread_ids: list[str] = msgspec.field(default_factory=list) + prompt: str | None = None + agents_states: dict[str, CollabAgentState] = msgspec.field(default_factory=dict) + status: CollabToolCallStatus = "in_progress" + + class WebSearchItem(msgspec.Struct, tag="web_search", kw_only=True): id: str query: str @@ -127,15 +148,23 @@ class TodoListItem(msgspec.Struct, tag="todo_list", kw_only=True): items: list[TodoItem] +class UnknownItem(msgspec.Struct, tag="unknown_item", kw_only=True): + id: str + item_type: str + payload: dict[str, Any] = msgspec.field(default_factory=dict) + + type ThreadItem = ( AgentMessageItem | ReasoningItem | CommandExecutionItem | FileChangeItem | McpToolCallItem + | CollabToolCallItem | WebSearchItem | TodoListItem | ErrorItem + | UnknownItem ) @@ -151,6 +180,9 @@ class ItemCompleted(msgspec.Struct, tag="item.completed", kw_only=True): item: ThreadItem +type ItemEvent = ItemStarted | ItemUpdated | ItemCompleted + + type ThreadEvent = ( ThreadStarted | TurnStarted @@ -163,7 +195,57 @@ type ThreadEvent = ( ) _DECODER = msgspec.json.Decoder(ThreadEvent) +_RAW_OBJECT_DECODER = msgspec.json.Decoder(dict[str, Any]) +_KNOWN_ITEM_TYPES = { + "agent_message", + "reasoning", + "command_execution", + "file_change", + "mcp_tool_call", + "collab_tool_call", + "web_search", + "todo_list", + "error", +} + + +def _decode_unknown_item_fallback(data: bytes | str) -> ItemEvent | None: + payload = _RAW_OBJECT_DECODER.decode(data) + event_type = payload.get("type") + if event_type not in {"item.started", "item.updated", "item.completed"}: + return None + + item = payload.get("item") + if not isinstance(item, dict): + return None + + item_type = item.get("type") + if not isinstance(item_type, str) or item_type in _KNOWN_ITEM_TYPES: + return None + + item_id = item.get("id") + if not isinstance(item_id, str): + return None + + unknown_item = UnknownItem( + id=item_id, + item_type=item_type, + payload={ + str(key): value for key, value in item.items() if key not in {"id", "type"} + }, + ) + if event_type == "item.started": + return ItemStarted(item=unknown_item) + if event_type == "item.updated": + return ItemUpdated(item=unknown_item) + return ItemCompleted(item=unknown_item) def decode_event(data: bytes | str) -> ThreadEvent: - return _DECODER.decode(data) + try: + return _DECODER.decode(data) + except msgspec.DecodeError: + fallback = _decode_unknown_item_fallback(data) + if fallback is not None: + return fallback + raise diff --git a/tests/fixtures/codex_exec_json_phase_and_unknown.jsonl b/tests/fixtures/codex_exec_json_phase_and_unknown.jsonl new file mode 100644 index 0000000..1f61c31 --- /dev/null +++ b/tests/fixtures/codex_exec_json_phase_and_unknown.jsonl @@ -0,0 +1,8 @@ +{"type":"thread.started","thread_id":"0199a213-81c0-7800-8aa1-bbab2a035a53"} +{"type":"turn.started"} +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","phase":"commentary","text":"Inspecting repository state."}} +{"type":"item.completed","item":{"id":"item_2","type":"agent_message","phase":"final_answer","text":"Implemented the requested changes."}} +{"type":"item.started","item":{"id":"item_3","type":"collab_tool_call","tool":"spawn_agent","sender_thread_id":"main-thread","receiver_thread_ids":["worker-thread"],"prompt":"Find failing tests","agents_states":{},"status":"in_progress"}} +{"type":"item.completed","item":{"id":"item_3","type":"collab_tool_call","tool":"spawn_agent","sender_thread_id":"main-thread","receiver_thread_ids":["worker-thread"],"prompt":"Find failing tests","agents_states":{"worker-thread":{"status":"completed","message":"done"}},"status":"completed"}} +{"type":"item.completed","item":{"id":"item_4","type":"future_item","foo":"bar","count":2}} +{"type":"turn.completed","usage":{"input_tokens":10,"cached_input_tokens":0,"output_tokens":5}} diff --git a/tests/test_codex_runner_helpers.py b/tests/test_codex_runner_helpers.py index 83af546..920ee46 100644 --- a/tests/test_codex_runner_helpers.py +++ b/tests/test_codex_runner_helpers.py @@ -9,10 +9,12 @@ from takopi.config import ConfigError from takopi.events import EventFactory from takopi.model import ActionEvent, CompletedEvent, StartedEvent from takopi.runners.codex import ( + _AgentMessageSummary, CodexRunner, _format_change_summary, _normalize_change_list, _parse_reconnect_message, + _select_final_answer, _short_tool_name, _summarize_todo_list, _summarize_tool_result, @@ -73,6 +75,39 @@ def test_summarize_todo_list_and_title() -> None: assert _todo_title(_summarize_todo_list("nope")) == "todo" +def test_select_final_answer() -> None: + assert ( + _select_final_answer( + [ + _AgentMessageSummary(text="working", phase="commentary"), + _AgentMessageSummary(text="done", phase="final_answer"), + ] + ) + == "done" + ) + + assert ( + _select_final_answer( + [ + _AgentMessageSummary(text="first", phase=None), + _AgentMessageSummary(text="second", phase=None), + ] + ) + == "second" + ) + + assert ( + _select_final_answer([_AgentMessageSummary(text="working", phase="commentary")]) + is None + ) + assert ( + _select_final_answer( + [_AgentMessageSummary(text="intermediate", phase="foobar")] + ) + is None + ) + + def test_translate_codex_events_for_items() -> None: factory = EventFactory("codex") event = codex_schema.ItemStarted( @@ -100,6 +135,20 @@ def test_translate_codex_events_for_items() -> None: assert out[0].action.kind == "note" assert out[0].action.title == "thinking" + event = codex_schema.ItemCompleted( + item=codex_schema.AgentMessageItem( + id="m1", + text="working", + phase="commentary", + ) + ) + out = translate_codex_event(event, title="Codex", factory=factory) + assert isinstance(out[0], ActionEvent) + assert out[0].action.kind == "note" + assert out[0].action.title == "working" + assert out[0].phase == "completed" + assert out[0].ok is True + event = codex_schema.ItemUpdated( item=codex_schema.TodoListItem( id="t1", diff --git a/tests/test_codex_schema.py b/tests/test_codex_schema.py index ad0c7fa..2456097 100644 --- a/tests/test_codex_schema.py +++ b/tests/test_codex_schema.py @@ -36,9 +36,21 @@ def _decode_fixture(name: str) -> list[str]: "fixture", [ "codex_exec_json_all_formats.jsonl", + "codex_exec_json_phase_and_unknown.jsonl", ], ) def test_codex_schema_parses_fixture(fixture: str) -> None: errors = _decode_fixture(fixture) assert not errors, f"{fixture} had {len(errors)} errors: " + "; ".join(errors[:5]) + + +def test_codex_schema_decodes_unknown_item_type() -> None: + event = codex_schema.decode_event( + '{"type":"item.completed","item":{"id":"item_99","type":"future_item",' + '"foo":"bar","count":2}}' + ) + assert isinstance(event, codex_schema.ItemCompleted) + assert isinstance(event.item, codex_schema.UnknownItem) + assert event.item.item_type == "future_item" + assert event.item.payload == {"foo": "bar", "count": 2} diff --git a/tests/test_exec_runner.py b/tests/test_exec_runner.py index bff08ae..82bb09d 100644 --- a/tests/test_exec_runner.py +++ b/tests/test_exec_runner.py @@ -347,6 +347,125 @@ async def test_codex_runner_reconnect_notice_updates_phase(tmp_path) -> None: assert isinstance(seen[3], CompletedEvent) +@pytest.mark.anyio +async def test_codex_runner_prefers_final_answer_phase(tmp_path) -> None: + thread_id = "019b73c4-0c3f-7701-a0bb-aac6b4d8a3bc" + + codex_path = tmp_path / "codex" + codex_path.write_text( + "#!/usr/bin/env python3\n" + "import json\n" + "import sys\n" + "\n" + "sys.stdin.read()\n" + f"print(json.dumps({{'type': 'thread.started', 'thread_id': '{thread_id}'}}), flush=True)\n" + "print(json.dumps({'type': 'turn.started'}), flush=True)\n" + "print(json.dumps({'type': 'item.completed', 'item': {'id': 'item_0', 'type': 'agent_message', 'phase': 'commentary', 'text': 'Working through the task.'}}), flush=True)\n" + "print(json.dumps({'type': 'item.completed', 'item': {'id': 'item_1', 'type': 'agent_message', 'phase': 'final_answer', 'text': 'Done.'}}), flush=True)\n" + "print(json.dumps({'type': 'turn.completed', 'usage': {'input_tokens': 1, 'cached_input_tokens': 0, 'output_tokens': 1}}), flush=True)\n", + encoding="utf-8", + ) + codex_path.chmod(0o755) + + runner = CodexRunner(codex_cmd=str(codex_path), extra_args=[]) + seen = [evt async for evt in runner.run("hi", None)] + + assert len(seen) == 4 + assert isinstance(seen[0], StartedEvent) + assert isinstance(seen[1], ActionEvent) + assert seen[1].action.kind == "turn" + assert isinstance(seen[2], ActionEvent) + assert seen[2].action.kind == "note" + assert seen[2].action.title == "Working through the task." + assert seen[2].phase == "completed" + assert seen[2].ok is True + assert isinstance(seen[3], CompletedEvent) + assert seen[3].answer == "Done." + + +@pytest.mark.anyio +async def test_codex_runner_legacy_agent_message_no_phase(tmp_path) -> None: + thread_id = "019b73c4-0c3f-7701-a0bb-aac6b4d8a3bc" + + codex_path = tmp_path / "codex" + codex_path.write_text( + "#!/usr/bin/env python3\n" + "import json\n" + "import sys\n" + "\n" + "sys.stdin.read()\n" + f"print(json.dumps({{'type': 'thread.started', 'thread_id': '{thread_id}'}}), flush=True)\n" + "print(json.dumps({'type': 'turn.started'}), flush=True)\n" + "print(json.dumps({'type': 'item.completed', 'item': {'id': 'item_0', 'type': 'agent_message', 'text': 'first'}}), flush=True)\n" + "print(json.dumps({'type': 'item.completed', 'item': {'id': 'item_1', 'type': 'agent_message', 'text': 'second'}}), flush=True)\n" + "print(json.dumps({'type': 'turn.completed', 'usage': {'input_tokens': 1, 'cached_input_tokens': 0, 'output_tokens': 1}}), flush=True)\n", + encoding="utf-8", + ) + codex_path.chmod(0o755) + + runner = CodexRunner(codex_cmd=str(codex_path), extra_args=[]) + seen = [evt async for evt in runner.run("hi", None)] + + completed = next(evt for evt in seen if isinstance(evt, CompletedEvent)) + assert completed.answer == "second" + + +@pytest.mark.anyio +async def test_codex_runner_collab_tool_call_does_not_break_stream(tmp_path) -> None: + thread_id = "019b73c4-0c3f-7701-a0bb-aac6b4d8a3bc" + + codex_path = tmp_path / "codex" + codex_path.write_text( + "#!/usr/bin/env python3\n" + "import json\n" + "import sys\n" + "\n" + "sys.stdin.read()\n" + f"print(json.dumps({{'type': 'thread.started', 'thread_id': '{thread_id}'}}), flush=True)\n" + "print(json.dumps({'type': 'turn.started'}), flush=True)\n" + "print(json.dumps({'type': 'item.started', 'item': {'id': 'item_0', 'type': 'collab_tool_call', 'tool': 'spawn_agent', 'sender_thread_id': 'main', 'receiver_thread_ids': ['worker'], 'prompt': 'check tests', 'agents_states': {}, 'status': 'in_progress'}}), flush=True)\n" + "print(json.dumps({'type': 'item.completed', 'item': {'id': 'item_0', 'type': 'collab_tool_call', 'tool': 'spawn_agent', 'sender_thread_id': 'main', 'receiver_thread_ids': ['worker'], 'prompt': 'check tests', 'agents_states': {'worker': {'status': 'completed', 'message': 'ok'}}, 'status': 'completed'}}), flush=True)\n" + "print(json.dumps({'type': 'item.completed', 'item': {'id': 'item_1', 'type': 'agent_message', 'text': 'ok'}}), flush=True)\n" + "print(json.dumps({'type': 'turn.completed', 'usage': {'input_tokens': 1, 'cached_input_tokens': 0, 'output_tokens': 1}}), flush=True)\n", + encoding="utf-8", + ) + codex_path.chmod(0o755) + + runner = CodexRunner(codex_cmd=str(codex_path), extra_args=[]) + seen = [evt async for evt in runner.run("hi", None)] + + completed = next(evt for evt in seen if isinstance(evt, CompletedEvent)) + assert completed.answer == "ok" + + +@pytest.mark.anyio +async def test_codex_runner_unknown_item_type_does_not_break_stream(tmp_path) -> None: + thread_id = "019b73c4-0c3f-7701-a0bb-aac6b4d8a3bc" + + codex_path = tmp_path / "codex" + codex_path.write_text( + "#!/usr/bin/env python3\n" + "import json\n" + "import sys\n" + "\n" + "sys.stdin.read()\n" + f"print(json.dumps({{'type': 'thread.started', 'thread_id': '{thread_id}'}}), flush=True)\n" + "print(json.dumps({'type': 'turn.started'}), flush=True)\n" + "print(json.dumps({'type': 'item.completed', 'item': {'id': 'item_0', 'type': 'future_item', 'foo': 'bar'}}), flush=True)\n" + "print(json.dumps({'type': 'item.completed', 'item': {'id': 'item_1', 'type': 'agent_message', 'text': 'ok'}}), flush=True)\n" + "print(json.dumps({'type': 'turn.completed', 'usage': {'input_tokens': 1, 'cached_input_tokens': 0, 'output_tokens': 1}}), flush=True)\n", + encoding="utf-8", + ) + codex_path.chmod(0o755) + + runner = CodexRunner(codex_cmd=str(codex_path), extra_args=[]) + seen = [evt async for evt in runner.run("hi", None)] + + completed = next(evt for evt in seen if isinstance(evt, CompletedEvent)) + assert completed.ok is True + assert completed.answer == "ok" + + @pytest.mark.anyio async def test_codex_runner_includes_stderr_reason(tmp_path) -> None: codex_path = tmp_path / "codex"