fix(exec-bridge): clean up subprocess on cancellation
This commit is contained in:
@@ -24,7 +24,7 @@ from .exec_render import ExecProgressRenderer, render_event_cli
|
|||||||
from .rendering import render_markdown
|
from .rendering import render_markdown
|
||||||
from .telegram_client import TelegramClient
|
from .telegram_client import TelegramClient
|
||||||
|
|
||||||
logger = logging.getLogger("exec_bridge")
|
logger = logging.getLogger(__name__)
|
||||||
UUID_PATTERN = re.compile(
|
UUID_PATTERN = re.compile(
|
||||||
r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
|
r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
|
||||||
)
|
)
|
||||||
@@ -54,16 +54,18 @@ async def _drain_stderr(stderr: asyncio.StreamReader | None, tail: deque[str]) -
|
|||||||
|
|
||||||
|
|
||||||
def setup_logging(log_file: str | None) -> None:
|
def setup_logging(log_file: str | None) -> None:
|
||||||
logger.setLevel(logging.DEBUG)
|
root_logger = logging.getLogger()
|
||||||
logger.handlers.clear()
|
root_logger.setLevel(logging.DEBUG)
|
||||||
logger.propagate = False
|
for handler in root_logger.handlers[:]:
|
||||||
|
root_logger.removeHandler(handler)
|
||||||
|
handler.close()
|
||||||
|
|
||||||
fmt = logging.Formatter("%(asctime)s %(message)s")
|
fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
|
||||||
console = logging.StreamHandler(sys.stdout)
|
console = logging.StreamHandler(sys.stdout)
|
||||||
console.setLevel(logging.INFO)
|
console.setLevel(logging.INFO)
|
||||||
console.setFormatter(fmt)
|
console.setFormatter(fmt)
|
||||||
logger.addHandler(console)
|
root_logger.addHandler(console)
|
||||||
|
|
||||||
if log_file:
|
if log_file:
|
||||||
file_handler = RotatingFileHandler(
|
file_handler = RotatingFileHandler(
|
||||||
@@ -74,7 +76,7 @@ def setup_logging(log_file: str | None) -> None:
|
|||||||
)
|
)
|
||||||
file_handler.setLevel(logging.DEBUG)
|
file_handler.setLevel(logging.DEBUG)
|
||||||
file_handler.setFormatter(fmt)
|
file_handler.setFormatter(fmt)
|
||||||
logger.addHandler(file_handler)
|
root_logger.addHandler(file_handler)
|
||||||
logger.debug("[debug] file logger initialized path=%r", log_file)
|
logger.debug("[debug] file logger initialized path=%r", log_file)
|
||||||
|
|
||||||
|
|
||||||
@@ -127,6 +129,16 @@ def truncate_for_telegram(text: str, limit: int) -> str:
|
|||||||
return (head + sep + tail)[:limit]
|
return (head + sep + tail)[:limit]
|
||||||
|
|
||||||
|
|
||||||
|
def render_for_telegram(
|
||||||
|
md: str, *, limit: int
|
||||||
|
) -> tuple[str, list[dict[str, Any]] | None]:
|
||||||
|
rendered, entities = render_markdown(md)
|
||||||
|
if len(rendered) > limit:
|
||||||
|
rendered = truncate_for_telegram(rendered, limit)
|
||||||
|
return rendered, None
|
||||||
|
return rendered, entities or None
|
||||||
|
|
||||||
|
|
||||||
async def _send_markdown(
|
async def _send_markdown(
|
||||||
bot: TelegramClient,
|
bot: TelegramClient,
|
||||||
*,
|
*,
|
||||||
@@ -135,15 +147,12 @@ async def _send_markdown(
|
|||||||
reply_to_message_id: int | None = None,
|
reply_to_message_id: int | None = None,
|
||||||
disable_notification: bool = False,
|
disable_notification: bool = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
rendered, entities = render_markdown(text)
|
rendered, entities = render_for_telegram(text, limit=TELEGRAM_MARKDOWN_LIMIT)
|
||||||
if len(rendered) > TELEGRAM_MARKDOWN_LIMIT:
|
|
||||||
rendered = truncate_for_telegram(rendered, TELEGRAM_MARKDOWN_LIMIT)
|
|
||||||
entities = []
|
|
||||||
|
|
||||||
return await bot.send_message(
|
return await bot.send_message(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
text=rendered,
|
text=rendered,
|
||||||
entities=entities or None,
|
entities=entities,
|
||||||
reply_to_message_id=reply_to_message_id,
|
reply_to_message_id=reply_to_message_id,
|
||||||
disable_notification=disable_notification,
|
disable_notification=disable_notification,
|
||||||
)
|
)
|
||||||
@@ -207,10 +216,6 @@ class CodexExecRunner:
|
|||||||
assert proc.stdin and proc.stdout and proc.stderr
|
assert proc.stdin and proc.stdout and proc.stderr
|
||||||
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
logger.debug("[codex] spawn pid=%s args=%r", proc.pid, args)
|
||||||
|
|
||||||
proc.stdin.write(prompt.encode())
|
|
||||||
await proc.stdin.drain()
|
|
||||||
proc.stdin.close()
|
|
||||||
|
|
||||||
stderr_tail: deque[str] = deque(maxlen=200)
|
stderr_tail: deque[str] = deque(maxlen=200)
|
||||||
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
|
stderr_task = asyncio.create_task(_drain_stderr(proc.stderr, stderr_tail))
|
||||||
|
|
||||||
@@ -219,7 +224,14 @@ class CodexExecRunner:
|
|||||||
saw_agent_message = False
|
saw_agent_message = False
|
||||||
cli_last_turn: int | None = None
|
cli_last_turn: int | None = None
|
||||||
|
|
||||||
|
cancelled = False
|
||||||
|
rc: int | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
proc.stdin.write(prompt.encode())
|
||||||
|
await proc.stdin.drain()
|
||||||
|
proc.stdin.close()
|
||||||
|
|
||||||
async for raw_line in proc.stdout:
|
async for raw_line in proc.stdout:
|
||||||
line = raw_line.decode(errors="replace").strip()
|
line = raw_line.decode(errors="replace").strip()
|
||||||
if not line:
|
if not line:
|
||||||
@@ -252,11 +264,32 @@ class CodexExecRunner:
|
|||||||
last_agent_text = item["text"]
|
last_agent_text = item["text"]
|
||||||
saw_agent_message = True
|
saw_agent_message = True
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
proc.terminate()
|
cancelled = True
|
||||||
raise
|
if proc.returncode is None:
|
||||||
|
proc.terminate()
|
||||||
finally:
|
finally:
|
||||||
rc = await proc.wait()
|
if cancelled:
|
||||||
await stderr_task
|
task = asyncio.current_task()
|
||||||
|
if task is not None:
|
||||||
|
while task.cancelling():
|
||||||
|
task.uncancel()
|
||||||
|
|
||||||
|
try:
|
||||||
|
rc = await asyncio.wait_for(proc.wait(), timeout=2.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.debug(
|
||||||
|
"[codex] terminate timed out pid=%s, sending kill", proc.pid
|
||||||
|
)
|
||||||
|
if proc.returncode is None:
|
||||||
|
proc.kill()
|
||||||
|
rc = await proc.wait()
|
||||||
|
else:
|
||||||
|
rc = await proc.wait()
|
||||||
|
|
||||||
|
await asyncio.gather(stderr_task, return_exceptions=True)
|
||||||
|
|
||||||
|
if cancelled:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
logger.debug("[codex] process exit pid=%s rc=%s", proc.pid, rc)
|
||||||
if rc != 0:
|
if rc != 0:
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from codex_telegram_bridge.exec_bridge import extract_session_id, truncate_for_telegram
|
from codex_telegram_bridge.exec_bridge import extract_session_id, truncate_for_telegram
|
||||||
|
|
||||||
@@ -148,3 +151,54 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
|
|||||||
assert len(bot.send_calls) == 2
|
assert len(bot.send_calls) == 2
|
||||||
assert bot.send_calls[0]["disable_notification"] is True
|
assert bot.send_calls[0]["disable_notification"] is True
|
||||||
assert bot.send_calls[1]["disable_notification"] is False
|
assert bot.send_calls[1]["disable_notification"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_runner_cancellation_terminates_subprocess(tmp_path, monkeypatch) -> None:
|
||||||
|
from codex_telegram_bridge.exec_bridge import CodexExecRunner
|
||||||
|
|
||||||
|
pid_file = tmp_path / "pid"
|
||||||
|
codex_path = tmp_path / "codex"
|
||||||
|
codex_path.write_text(
|
||||||
|
"#!/usr/bin/env python3\n"
|
||||||
|
"import os\n"
|
||||||
|
"import time\n"
|
||||||
|
"\n"
|
||||||
|
"pid_file = os.environ.get('CODEX_FAKE_PID_FILE')\n"
|
||||||
|
"if pid_file:\n"
|
||||||
|
" with open(pid_file, 'w', encoding='utf-8') as f:\n"
|
||||||
|
" f.write(str(os.getpid()))\n"
|
||||||
|
" f.flush()\n"
|
||||||
|
"\n"
|
||||||
|
"time.sleep(60)\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
codex_path.chmod(0o755)
|
||||||
|
monkeypatch.setenv("CODEX_FAKE_PID_FILE", str(pid_file))
|
||||||
|
|
||||||
|
runner = CodexExecRunner(codex_cmd=str(codex_path), workspace=None, extra_args=[])
|
||||||
|
|
||||||
|
async def run_and_cancel() -> None:
|
||||||
|
task = asyncio.create_task(runner.run("hello", session_id=None))
|
||||||
|
|
||||||
|
for _ in range(100):
|
||||||
|
if pid_file.exists():
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert pid_file.exists()
|
||||||
|
|
||||||
|
pid = int(pid_file.read_text(encoding="utf-8").strip())
|
||||||
|
|
||||||
|
task.cancel()
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
|
|
||||||
|
for _ in range(200):
|
||||||
|
try:
|
||||||
|
os.kill(pid, 0)
|
||||||
|
except ProcessLookupError:
|
||||||
|
return
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
raise AssertionError("cancelled codex subprocess is still running")
|
||||||
|
|
||||||
|
asyncio.run(run_and_cancel())
|
||||||
|
|||||||
Reference in New Issue
Block a user