refactor: migrate exec bridge to anyio and harden cancellation (#6)

This commit is contained in:
banteg
2025-12-31 01:51:46 +04:00
committed by GitHub
parent 6687a435c9
commit 8eda3f5e84
9 changed files with 492 additions and 310 deletions
+152 -111
View File
@@ -1,5 +1,4 @@
import asyncio
import anyio
import pytest
from takopi.exec_bridge import (
@@ -187,7 +186,7 @@ class _FakeClock:
def __init__(self, start: float = 0.0) -> None:
self._now = start
self._sleep_until: float | None = None
self._sleep_event: asyncio.Event | None = None
self._sleep_event: anyio.Event | None = None
self.sleep_calls = 0
def __call__(self) -> float:
@@ -205,10 +204,10 @@ class _FakeClock:
async def sleep(self, delay: float) -> None:
self.sleep_calls += 1
if delay <= 0:
await asyncio.sleep(0)
await anyio.sleep(0)
return
self._sleep_until = self._now + delay
self._sleep_event = asyncio.Event()
self._sleep_event = anyio.Event()
await self._sleep_event.wait()
@@ -222,7 +221,7 @@ class _FakeRunnerWithEvents:
answer: str = "ok",
session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2",
advance_after: float | None = None,
hold: asyncio.Event | None = None,
hold: anyio.Event | None = None,
) -> None:
self._events = events
self._times = times
@@ -238,16 +237,17 @@ class _FakeRunnerWithEvents:
for when, event in zip(self._times, self._events, strict=False):
self._clock.set(when)
await on_event(event)
await asyncio.sleep(0)
await anyio.sleep(0)
if self._advance_after is not None:
self._clock.set(self._advance_after)
await asyncio.sleep(0)
await anyio.sleep(0)
if self._hold is not None:
await self._hold.wait()
return (self._session_id, self._answer, True)
def test_final_notify_sends_loud_final_message() -> None:
@pytest.mark.anyio
async def test_final_notify_sends_loud_final_message() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
@@ -261,14 +261,12 @@ def test_final_notify_sends_loud_final_message() -> None:
max_concurrency=1,
)
asyncio.run(
handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
)
await handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
)
assert len(bot.send_calls) == 2
@@ -276,7 +274,8 @@ def test_final_notify_sends_loud_final_message() -> None:
assert bot.send_calls[1]["disable_notification"] is False
def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
@pytest.mark.anyio
async def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
@@ -290,14 +289,12 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
max_concurrency=1,
)
asyncio.run(
handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
)
await handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
)
assert len(bot.send_calls) == 2
@@ -305,7 +302,8 @@ def test_new_final_message_forces_notification_when_too_long_to_edit() -> None:
assert bot.send_calls[1]["disable_notification"] is False
def test_progress_edits_are_rate_limited() -> None:
@pytest.mark.anyio
async def test_progress_edits_are_rate_limited() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
@@ -345,29 +343,28 @@ def test_progress_edits_are_rate_limited() -> None:
max_concurrency=1,
)
asyncio.run(
handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
clock=clock,
sleep=clock.sleep,
progress_edit_every=1.0,
)
await handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
clock=clock,
sleep=clock.sleep,
progress_edit_every=1.0,
)
assert len(bot.edit_calls) == 1
assert "echo 2" in bot.edit_calls[0]["text"]
def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
@pytest.mark.anyio
async def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
clock = _FakeClock()
hold = asyncio.Event()
hold = anyio.Event()
events = [
{
"type": "item.started",
@@ -404,24 +401,25 @@ def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
max_concurrency=1,
)
async def run_test() -> None:
task = asyncio.create_task(
handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
clock=clock,
sleep=clock.sleep,
progress_edit_every=1.0,
)
async def run_handle_message() -> None:
await handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="hi",
resume_session=None,
clock=clock,
sleep=clock.sleep,
progress_edit_every=1.0,
)
async with anyio.create_task_group() as tg:
tg.start_soon(run_handle_message)
for _ in range(100):
if clock._sleep_until is not None:
break
await asyncio.sleep(0)
await anyio.sleep(0)
assert clock._sleep_until == pytest.approx(1.0)
@@ -430,23 +428,21 @@ def test_progress_edits_do_not_sleep_again_without_new_events() -> None:
for _ in range(100):
if bot.edit_calls:
break
await asyncio.sleep(0)
await anyio.sleep(0)
assert len(bot.edit_calls) == 1
for _ in range(5):
await asyncio.sleep(0)
await anyio.sleep(0)
assert clock.sleep_calls == 1
assert clock._sleep_until is None
hold.set()
await task
asyncio.run(run_test())
def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
@pytest.mark.anyio
async def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
@@ -489,17 +485,15 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
max_concurrency=1,
)
asyncio.run(
handle_message(
cfg,
chat_id=123,
user_msg_id=42,
text="do it",
resume_session=None,
clock=clock,
sleep=clock.sleep,
progress_edit_every=1.0,
)
await handle_message(
cfg,
chat_id=123,
user_msg_id=42,
text="do it",
resume_session=None,
clock=clock,
sleep=clock.sleep,
progress_edit_every=1.0,
)
assert bot.send_calls[0]["reply_to_message_id"] == 42
@@ -510,7 +504,8 @@ def test_bridge_flow_sends_progress_edits_and_final_resume() -> None:
assert len(bot.delete_calls) == 1
def test_handle_cancel_without_reply_prompts_user() -> None:
@pytest.mark.anyio
async def test_handle_cancel_without_reply_prompts_user() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
@@ -526,13 +521,14 @@ def test_handle_cancel_without_reply_prompts_user() -> None:
msg = {"chat": {"id": 123}, "message_id": 10}
running_tasks: dict = {}
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
await _handle_cancel(cfg, msg, running_tasks)
assert len(bot.send_calls) == 1
assert "reply to the progress message" in bot.send_calls[0]["text"]
def test_handle_cancel_with_no_session_id_says_nothing_running() -> None:
@pytest.mark.anyio
async def test_handle_cancel_with_no_progress_message_says_nothing_running() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
@@ -548,17 +544,18 @@ def test_handle_cancel_with_no_session_id_says_nothing_running() -> None:
msg = {
"chat": {"id": 123},
"message_id": 10,
"reply_to_message": {"text": "no uuid here"},
"reply_to_message": {"text": "no message id"},
}
running_tasks: dict = {}
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
await _handle_cancel(cfg, msg, running_tasks)
assert len(bot.send_calls) == 1
assert "nothing is currently running" in bot.send_calls[0]["text"]
def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
@pytest.mark.anyio
async def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
@@ -571,21 +568,22 @@ def test_handle_cancel_with_finished_task_says_nothing_running() -> None:
startup_msg="",
max_concurrency=1,
)
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
progress_id = 99
msg = {
"chat": {"id": 123},
"message_id": 10,
"reply_to_message": {"text": f"resume: `{session_id}`"},
"reply_to_message": {"message_id": progress_id},
}
running_tasks: dict = {} # Session not in running_tasks
running_tasks: dict = {} # Progress message not in running_tasks
asyncio.run(_handle_cancel(cfg, msg, running_tasks))
await _handle_cancel(cfg, msg, running_tasks)
assert len(bot.send_calls) == 1
assert "nothing is currently running" in bot.send_calls[0]["text"]
def test_handle_cancel_cancels_running_task() -> None:
@pytest.mark.anyio
async def test_handle_cancel_cancels_running_task() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
@@ -598,29 +596,70 @@ def test_handle_cancel_cancels_running_task() -> None:
startup_msg="",
max_concurrency=1,
)
session_id = "019b66fc-64c2-7a71-81cd-081c504cfeb2"
progress_id = 42
msg = {
"chat": {"id": 123},
"message_id": 10,
"reply_to_message": {"text": f"resume: `{session_id}`"},
"reply_to_message": {"message_id": progress_id},
}
async def run_test():
task = asyncio.create_task(asyncio.sleep(10))
running_tasks = {session_id: task}
from takopi.exec_bridge import RunningTask
cancelled_event = anyio.Event()
cancel_scope = anyio.CancelScope()
running_task = RunningTask(scope=cancel_scope)
async def sleeper() -> None:
with cancel_scope:
try:
await anyio.sleep(10)
except anyio.get_cancelled_exc_class():
cancelled_event.set()
return
async with anyio.create_task_group() as tg:
tg.start_soon(sleeper)
running_tasks = {progress_id: running_task}
await _handle_cancel(cfg, msg, running_tasks)
try:
await task
except asyncio.CancelledError:
return True
return False
await cancelled_event.wait()
cancelled = asyncio.run(run_test())
assert cancelled is True
assert len(bot.send_calls) == 0 # No error message sent
@pytest.mark.anyio
async def test_handle_cancel_only_cancels_matching_progress_message() -> None:
from takopi.exec_bridge import BridgeConfig, _handle_cancel
bot = _FakeBot()
runner = _FakeRunner(answer="ok")
cfg = BridgeConfig(
bot=bot, # type: ignore[arg-type]
runner=runner, # type: ignore[arg-type]
chat_id=123,
final_notify=True,
startup_msg="",
max_concurrency=1,
)
from takopi.exec_bridge import RunningTask
scope_first = anyio.CancelScope()
scope_second = anyio.CancelScope()
task_first = RunningTask(scope=scope_first)
task_second = RunningTask(scope=scope_second)
msg = {
"chat": {"id": 123},
"message_id": 10,
"reply_to_message": {"message_id": 1},
}
running_tasks = {1: task_first, 2: task_second}
await _handle_cancel(cfg, msg, running_tasks)
assert scope_first.cancel_called is True
assert scope_second.cancel_called is False
assert len(bot.send_calls) == 0
class _FakeRunnerCancellable:
def __init__(self, session_id: str = "019b66fc-64c2-7a71-81cd-081c504cfeb2"):
self._session_id = session_id
@@ -629,11 +668,12 @@ class _FakeRunnerCancellable:
on_event = kwargs.get("on_event")
if on_event:
await on_event({"type": "thread.started", "thread_id": self._session_id})
await asyncio.sleep(10) # Will be cancelled
await anyio.sleep(10) # Will be cancelled
return (self._session_id, "ok", True)
def test_handle_message_cancelled_renders_cancelled_state() -> None:
@pytest.mark.anyio
async def test_handle_message_cancelled_renders_cancelled_state() -> None:
from takopi.exec_bridge import BridgeConfig, handle_message
bot = _FakeBot()
@@ -649,23 +689,24 @@ def test_handle_message_cancelled_renders_cancelled_state() -> None:
)
running_tasks: dict = {}
async def run_test():
task = asyncio.create_task(
handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="do something",
resume_session=None,
running_tasks=running_tasks,
)
async def run_handle_message() -> None:
await handle_message(
cfg,
chat_id=123,
user_msg_id=10,
text="do something",
resume_session=None,
running_tasks=running_tasks,
)
await asyncio.sleep(0.01) # Let task start and register
assert session_id in running_tasks
running_tasks[session_id].cancel()
await task
asyncio.run(run_test())
async with anyio.create_task_group() as tg:
tg.start_soon(run_handle_message)
for _ in range(100):
if running_tasks:
break
await anyio.sleep(0)
assert running_tasks
running_tasks[next(iter(running_tasks))].scope.cancel()
assert len(bot.send_calls) == 1 # Progress message
assert len(bot.edit_calls) >= 1