fix(exec-bridge): clean up subprocess on cancellation
This commit is contained in:
@@ -24,7 +24,7 @@ from .exec_render import ExecProgressRenderer, render_event_cli
|
||||
from .rendering import render_markdown
|
||||
from .telegram_client import TelegramClient
|
||||
|
||||
logger = logging.getLogger("exec_bridge")
|
||||
logger = logging.getLogger(__name__)
|
||||
UUID_PATTERN = re.compile(
|
||||
r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
|
||||
)
|
||||
@@ -54,16 +54,18 @@ async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -
|
||||
|
||||
|
||||
def setup_logging(log_file: str | None) -> None:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.handlers.clear()
|
||||
logger.propagate = False
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.DEBUG)
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
handler.close()
|
||||
|
||||
fmt = logging.Formatter("%(asctime)s %(message)s")
|
||||
fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
console.setFormatter(fmt)
|
||||
logger.addHandler(console)
|
||||
root_logger.addHandler(console)
|
||||
|
||||
if log_file:
|
||||
file_handler = RotatingFileHandler(
|
||||
@@ -74,7 +76,7 @@ def setup_logging(log_file: str | None) -> None:
|
||||
)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setFormatter(fmt)
|
||||
logger.addHandler(file_handler)
|
||||
root_logger.addHandler(file_handler)
|
||||
logger.debug("[debug] file logger initialized path=%r", log_file)
|
||||
|
||||
|
||||
@@ -127,6 +129,16 @@ def truncate_for_telegram(text: str, limit: int) -> str:
|
||||
return (head + sep + tail)[:limit]
|
||||
|
||||
|
||||
def render_for_telegram(
|
||||
md: str, *, limit: int
|
||||
) -> tuple[str, list[dict[str, Any]] | None]:
|
||||
rendered, entities = render_markdown(md)
|
||||
if len(rendered) > limit:
|
||||
rendered = truncate_for_telegram(rendered, limit)
|
||||
return rendered, None
|
||||
return rendered, entities or None
|
||||
|
||||
|
||||
async def _send_markdown(
|
||||
bot: TelegramClient,
|
||||
*,
|
||||
@@ -135,15 +147,12 @@ async def _send_markdown(
|
||||
reply_to_message_id: int | None = None,
|
||||
disable_notification: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
rendered, entities = render_markdown(text)
|
||||
if len(rendered) > TELEGRAM_MARKDOWN_LIMIT:
|
||||
rendered = truncate_for_telegram(rendered, TELEGRAM_MARKDOWN_LIMIT)
|
||||
entities = []
|
||||
rendered, entities = render_for_telegram(text, limit=TELEGRAM_MARKDOWN_LIMIT)
|
||||
|
||||
return await bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=rendered,
|
||||
entities=entities or None,
|
||||
entities=entities,
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
disable_notification=disable_notification,
|
||||
)
|
||||
@@ -207,10 +216,6 @@ class CodexExecRunner:
|
||||
assert proc.stdin and proc.stdout and proc.stderr
|
||||
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
||||
|
||||
proc.stdin.write(prompt.encode())
|
||||
await proc.stdin.drain()
|
||||
proc.stdin.close()
|
||||
|
||||
stderr_tail: deque[str] = deque(maxlen=200)
|
||||
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
|
||||
|
||||
@@ -219,7 +224,14 @@ class CodexExecRunner:
|
||||
saw_agent_message = False
|
||||
cli_last_turn: int | None = None
|
||||
|
||||
cancelled = False
|
||||
rc: int | None = None
|
||||
|
||||
try:
|
||||
proc.stdin.write(prompt.encode())
|
||||
await proc.stdin.drain()
|
||||
proc.stdin.close()
|
||||
|
||||
async for raw_line in proc.stdout:
|
||||
line = raw_line.decode(errors="replace").strip()
|
||||
if not line:
|
||||
@@ -252,11 +264,32 @@ class CodexExecRunner:
|
||||
last_agent_text = item["text"]
|
||||
saw_agent_message = True
|
||||
except asyncio.CancelledError:
|
||||
proc.terminate()
|
||||
raise
|
||||
cancelled = True
|
||||
if proc.returncode is None:
|
||||
proc.terminate()
|
||||
finally:
|
||||
rc = await proc.wait()
|
||||
await stderr_task
|
||||
if cancelled:
|
||||
task = asyncio.current_task()
|
||||
if task is not None:
|
||||
while task.cancelling():
|
||||
task.uncancel()
|
||||
|
||||
try:
|
||||
rc = await asyncio.wait_for(proc.wait(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug(
|
||||
"[codex] terminate timed out pid=%s, sending kill", proc.pid
|
||||
)
|
||||
if proc.returncode is None:
|
||||
proc.kill()
|
||||
rc = await proc.wait()
|
||||
else:
|
||||
rc = await proc.wait()
|
||||
|
||||
await asyncio.gather(stderr_task, return_exceptions=True)
|
||||
|
||||
if cancelled:
|
||||
raise asyncio.CancelledError
|
||||
|
||||
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
||||
if rc != 0:
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from codex_telegram_bridge.exec_bridge import extract_session_id, truncate_for_telegram
|
||||
|
||||
@@ -148,3 +151,54 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
||||
assert len(bot.send_calls) == 2
|
||||
assert bot.send_calls[0]["disable_notification"] is True
|
||||
assert bot.send_calls[1]["disable_notification"] is False
|
||||
|
||||
|
||||
def test_codex_runner_cancellation_terminates_subprocess(tmp_path, monkeypatch) -> None:
|
||||
from codex_telegram_bridge.exec_bridge import CodexExecRunner
|
||||
|
||||
pid_file = tmp_path / "pid"
|
||||
codex_path = tmp_path / "codex"
|
||||
codex_path.write_text(
|
||||
"#!/usr/bin/env python3\n"
|
||||
"import os\n"
|
||||
"import time\n"
|
||||
"\n"
|
||||
"pid_file = os.environ.get('CODEX_FAKE_PID_FILE')\n"
|
||||
"if pid_file:\n"
|
||||
" with open(pid_file, 'w', encoding='utf-8') as f:\n"
|
||||
" f.write(str(os.getpid()))\n"
|
||||
" f.flush()\n"
|
||||
"\n"
|
||||
"time.sleep(60)\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
codex_path.chmod(0o755)
|
||||
monkeypatch.setenv("CODEX_FAKE_PID_FILE", str(pid_file))
|
||||
|
||||
runner = CodexExecRunner(codex_cmd=str(codex_path), workspace=None, extra_args=[])
|
||||
|
||||
async def run_and_cancel() -> None:
|
||||
task = asyncio.create_task(runner.run("hello", session_id=None))
|
||||
|
||||
for _ in range(100):
|
||||
if pid_file.exists():
|
||||
break
|
||||
await asyncio.sleep(0.01)
|
||||
assert pid_file.exists()
|
||||
|
||||
pid = int(pid_file.read_text(encoding="utf-8").strip())
|
||||
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
for _ in range(200):
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except ProcessLookupError:
|
||||
return
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
raise AssertionError("cancelled codex subprocess is still running")
|
||||
|
||||
asyncio.run(run_and_cancel())
|
||||
|
||||
Reference in New Issue
Block a user