Fix websocket replay tests
This commit is contained in:
@@ -366,11 +366,13 @@ class LocalServer:
|
|||||||
self._websocket_connections[route_key] = ws
|
self._websocket_connections[route_key] = ws
|
||||||
|
|
||||||
session_id = self.session_manager.routes.get(RouteKey(route_key))
|
session_id = self.session_manager.routes.get(RouteKey(route_key))
|
||||||
|
session = None
|
||||||
if session_id is not None:
|
if session_id is not None:
|
||||||
session = self.session_manager.get_session(session_id)
|
session = self.session_manager.get_session(session_id)
|
||||||
if session is None or not session.is_running():
|
if session is None or not session.is_running():
|
||||||
self.session_manager.on_session_end(session_id)
|
self.session_manager.on_session_end(session_id)
|
||||||
session_id = None
|
session_id = None
|
||||||
|
session = None
|
||||||
else:
|
else:
|
||||||
# Force terminal redraw on reconnect to avoid blank screen
|
# Force terminal redraw on reconnect to avoid blank screen
|
||||||
if hasattr(session, 'force_redraw'):
|
if hasattr(session, 'force_redraw'):
|
||||||
@@ -378,8 +380,14 @@ class LocalServer:
|
|||||||
if hasattr(session, 'send_bytes'):
|
if hasattr(session, 'send_bytes'):
|
||||||
await session.send_bytes(CLEAR_AND_REDRAW_SEQ.encode('utf-8'))
|
await session.send_bytes(CLEAR_AND_REDRAW_SEQ.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
session_created = session_id is not None
|
session_created = session_id is not None
|
||||||
|
|
||||||
|
if session_created and session is not None and hasattr(session, 'get_replay_buffer'):
|
||||||
|
replay = await session.get_replay_buffer()
|
||||||
|
if replay:
|
||||||
|
await ws.send_bytes(replay)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for msg in ws:
|
async for msg in ws:
|
||||||
if msg.type == WSMsgType.TEXT:
|
if msg.type == WSMsgType.TEXT:
|
||||||
@@ -853,12 +861,14 @@ class LocalServer:
|
|||||||
return web.Response(text=html_content, content_type="text/html")
|
return web.Response(text=html_content, content_type="text/html")
|
||||||
|
|
||||||
async def handle_session_data(self, route_key: RouteKey, data: bytes) -> None:
|
async def handle_session_data(self, route_key: RouteKey, data: bytes) -> None:
|
||||||
|
self.mark_route_activity(str(route_key))
|
||||||
ws = self._websocket_connections.get(route_key)
|
ws = self._websocket_connections.get(route_key)
|
||||||
if ws is None:
|
if ws is None:
|
||||||
return
|
return
|
||||||
await ws.send_bytes(data)
|
await ws.send_bytes(data)
|
||||||
|
|
||||||
async def handle_binary_message(self, route_key: RouteKey, payload: bytes) -> None:
|
async def handle_binary_message(self, route_key: RouteKey, payload: bytes) -> None:
|
||||||
|
self.mark_route_activity(str(route_key))
|
||||||
ws = self._websocket_connections.get(route_key)
|
ws = self._websocket_connections.get(route_key)
|
||||||
if ws is None:
|
if ws is None:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -694,7 +694,17 @@ class TestLocalServerMoreCoverage:
|
|||||||
assert queue.get_nowait() == "existing"
|
assert queue.get_nowait() == "existing"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mark_route_activity_triggers_notification(self, server_with_no_apps):
|
async def test_handle_session_data_marks_activity(self, server_with_no_apps, monkeypatch):
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_bytes = AsyncMock()
|
||||||
|
server_with_no_apps._websocket_connections["rk"] = ws
|
||||||
|
server_with_no_apps._route_last_activity["rk"] = 0.0
|
||||||
|
|
||||||
|
await server_with_no_apps.handle_session_data("rk", b"data")
|
||||||
|
assert server_with_no_apps._route_last_activity["rk"] > 0.0
|
||||||
|
ws.send_bytes.assert_awaited_once_with(b"data")
|
||||||
|
|
||||||
|
def test_mark_route_activity_triggers_notification(self, server_with_no_apps):
|
||||||
"""Test that mark_route_activity triggers SSE notification."""
|
"""Test that mark_route_activity triggers SSE notification."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from aiohttp import WSMsgType, web
|
from aiohttp import WSMsgType, web
|
||||||
@@ -59,9 +60,16 @@ async def test_websocket_creates_session_on_resize(tmp_path):
|
|||||||
server.session_manager.routes["test"] = "sid"
|
server.session_manager.routes["test"] = "sid"
|
||||||
server.session_manager.sessions["sid"] = DummySession()
|
server.session_manager.sessions["sid"] = DummySession()
|
||||||
|
|
||||||
|
# Replay buffer should be sent on reconnect
|
||||||
|
replay_session = server.session_manager.sessions["sid"]
|
||||||
|
replay_session.get_replay_buffer = AsyncMock(return_value=b"replay")
|
||||||
|
|
||||||
client = await _make_client(server)
|
client = await _make_client(server)
|
||||||
try:
|
try:
|
||||||
ws = await client.ws_connect("/ws/test")
|
ws = await client.ws_connect("/ws/test")
|
||||||
|
msg = await ws.receive(timeout=1)
|
||||||
|
assert msg.type == WSMsgType.BINARY
|
||||||
|
assert msg.data == b"replay"
|
||||||
await ws.close()
|
await ws.close()
|
||||||
finally:
|
finally:
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user