refactor: drop redundant defensive checks

This commit is contained in:
banteg
2025-12-29 18:17:27 +04:00
parent f3f0a1fea3
commit 4243556d61
2 changed files with 14 additions and 17 deletions
+11 -13
View File
@@ -12,7 +12,7 @@ from collections import deque
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any, cast
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
import typer import typer
@@ -49,9 +49,7 @@ def resolve_resume_session(
return extract_session_id(text) or extract_session_id(reply_text) return extract_session_id(text) or extract_session_id(reply_text)
async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -> None: async def _drain_stderr(stderr: asyncio.StreamReader, tail: deque[str]) -> None:
if stderr is None:
return
try: try:
while True: while True:
line = await stderr.readline() line = await stderr.readline()
@@ -235,11 +233,13 @@ class CodexExecRunner:
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
) as proc: ) as proc:
assert proc.stdin and proc.stdout and proc.stderr proc_stdin = cast(asyncio.StreamWriter, proc.stdin)
proc_stdout = cast(asyncio.StreamReader, proc.stdout)
proc_stderr = cast(asyncio.StreamReader, proc.stderr)
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args) logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
stderr_tail: deque[str] = deque(maxlen=200) stderr_tail: deque[str] = deque(maxlen=200)
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail)) stderr_task = asyncio.create_task(_drain_stderr(proc_stderr, stderr_tail))
found_session: str | None = session_id found_session: str | None = session_id
last_agent_text: str | None = None last_agent_text: str | None = None
@@ -250,11 +250,11 @@ class CodexExecRunner:
rc: int | None = None rc: int | None = None
try: try:
proc.stdin.write(prompt.encode()) proc_stdin.write(prompt.encode())
await proc.stdin.drain() await proc_stdin.drain()
proc.stdin.close() proc_stdin.close()
async for raw_line in proc.stdout: async for raw_line in proc_stdout:
raw = raw_line.decode(errors="replace") raw = raw_line.decode(errors="replace")
logger.debug("[codex][jsonl] %s", raw.rstrip("\n")) logger.debug("[codex][jsonl] %s", raw.rstrip("\n"))
line = raw.strip() line = raw.strip()
@@ -292,8 +292,7 @@ class CodexExecRunner:
cancelled = True cancelled = True
finally: finally:
if cancelled: if cancelled:
task = asyncio.current_task() task = cast(asyncio.Task, asyncio.current_task())
if task is not None:
while task.cancelling(): while task.cancelling():
task.uncancel() task.uncancel()
if not cancelled: if not cancelled:
@@ -555,7 +554,6 @@ async def _handle_message(
if edit_task is not None: if edit_task is not None:
await asyncio.gather(edit_task, return_exceptions=True) await asyncio.gather(edit_task, return_exceptions=True)
answer = answer or "(No agent_message captured from JSON stream.)"
elapsed = clock() - started_at elapsed = clock() - started_at
status = "done" if saw_agent_message else "error" status = "done" if saw_agent_message else "error"
final_md = ( final_md = (
+1 -2
View File
@@ -10,7 +10,7 @@ _md = MarkdownIt("commonmark", {"html": False})
def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]: def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]:
html = _md.render(md or "") html = _md.render(md)
rendered = transform_html(html) rendered = transform_html(html)
text = re.sub(r"(?m)^(\s*)•", r"\1-", rendered.text) text = re.sub(r"(?m)^(\s*)•", r"\1-", rendered.text)
@@ -23,4 +23,3 @@ def render_markdown(md: str) -> tuple[str, list[dict[str, Any]]]:
d.pop("language", None) d.pop("language", None)
entities.append(d) entities.append(d)
return text, entities return text, entities