Refactor test fixtures and parametrization
This commit is contained in:
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -48,6 +49,46 @@ def tmp_config_path(tmp_path: Path) -> Path:
|
|||||||
return tmp_path / "config"
|
return tmp_path / "config"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_request() -> MagicMock:
|
||||||
|
"""Create a mock request with common attributes."""
|
||||||
|
request = MagicMock()
|
||||||
|
request.headers = {}
|
||||||
|
request.secure = False
|
||||||
|
request.query = {}
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def screen_buffer_factory():
|
||||||
|
def _make(rows: list[str], width: int = 80):
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"data": c,
|
||||||
|
"fg": "default",
|
||||||
|
"bg": "default",
|
||||||
|
"bold": False,
|
||||||
|
"italics": False,
|
||||||
|
"underscore": False,
|
||||||
|
"reverse": False,
|
||||||
|
}
|
||||||
|
for c in (row + " " * width)[:width]
|
||||||
|
]
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
return _make
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session():
|
||||||
|
session = MagicMock()
|
||||||
|
session.get_screen_has_changes = AsyncMock(return_value=False)
|
||||||
|
session.get_screen_state = AsyncMock(return_value=(80, 24, [], True))
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def poller() -> Poller:
|
def poller() -> Poller:
|
||||||
"""Create a Poller instance."""
|
"""Create a Poller instance."""
|
||||||
|
|||||||
+16
-10
@@ -66,6 +66,20 @@ class TestRenderSparklineSvg:
|
|||||||
class TestDockerStatsCollector:
|
class TestDockerStatsCollector:
|
||||||
"""Tests for Docker stats collector."""
|
"""Tests for Docker stats collector."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cpu_stats_pair(self):
|
||||||
|
return (
|
||||||
|
{
|
||||||
|
"cpu_usage": {"total_usage": 1000000000},
|
||||||
|
"system_cpu_usage": 10000000000,
|
||||||
|
"online_cpus": 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cpu_usage": {"total_usage": 500000000},
|
||||||
|
"system_cpu_usage": 5000000000,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_available_checks_socket(self, tmp_path):
|
def test_available_checks_socket(self, tmp_path):
|
||||||
"""available property checks socket existence and connectivity."""
|
"""available property checks socket existence and connectivity."""
|
||||||
socket_path = tmp_path / "docker.sock"
|
socket_path = tmp_path / "docker.sock"
|
||||||
@@ -94,19 +108,11 @@ class TestDockerStatsCollector:
|
|||||||
history = collector.get_cpu_history("test")
|
history = collector.get_cpu_history("test")
|
||||||
assert history == [10.0, 20.0, 30.0]
|
assert history == [10.0, 20.0, 30.0]
|
||||||
|
|
||||||
def test_calculate_cpu_percent(self):
|
def test_calculate_cpu_percent(self, cpu_stats_pair):
|
||||||
"""CPU percentage calculation."""
|
"""CPU percentage calculation."""
|
||||||
collector = DockerStatsCollector("/nonexistent")
|
collector = DockerStatsCollector("/nonexistent")
|
||||||
|
|
||||||
cpu_stats = {
|
cpu_stats, precpu_stats = cpu_stats_pair
|
||||||
"cpu_usage": {"total_usage": 1000000000},
|
|
||||||
"system_cpu_usage": 10000000000,
|
|
||||||
"online_cpus": 4,
|
|
||||||
}
|
|
||||||
precpu_stats = {
|
|
||||||
"cpu_usage": {"total_usage": 500000000},
|
|
||||||
"system_cpu_usage": 5000000000,
|
|
||||||
}
|
|
||||||
|
|
||||||
result = collector._calculate_cpu_percent("test", cpu_stats, precpu_stats)
|
result = collector._calculate_cpu_percent("test", cpu_stats, precpu_stats)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|||||||
+100
-132
@@ -174,20 +174,16 @@ class TestLocalServerHelpers:
|
|||||||
ws.send_json.assert_awaited_once_with(["error", "No app configured"])
|
ws.send_json.assert_awaited_once_with(["error", "No app configured"])
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_screenshot_svg_handler_returns_svg(self, server, monkeypatch, capsys):
|
async def test_screenshot_svg_handler_returns_svg(
|
||||||
request = MagicMock()
|
self, server, monkeypatch, capsys, screen_buffer_factory, mock_session, mock_request
|
||||||
|
):
|
||||||
|
request = mock_request
|
||||||
request.query = {"route_key": "rk"}
|
request.query = {"route_key": "rk"}
|
||||||
|
|
||||||
# Mock screen state: width=80, height=2, buffer with "hello" on first line
|
screen_buffer = screen_buffer_factory(["hello", ""])
|
||||||
screen_buffer = [
|
mock_session.get_screen_state = AsyncMock(return_value=(80, 2, screen_buffer, True))
|
||||||
[{"data": c, "fg": "default", "bg": "default", "bold": False, "italics": False, "underscore": False, "reverse": False} for c in "hello" + " " * 75],
|
|
||||||
[{"data": " ", "fg": "default", "bg": "default", "bold": False, "italics": False, "underscore": False, "reverse": False}] * 80,
|
|
||||||
]
|
|
||||||
session = MagicMock()
|
|
||||||
session.get_screen_has_changes = AsyncMock(return_value=False)
|
|
||||||
session.get_screen_state = AsyncMock(return_value=(80, 2, screen_buffer, True))
|
|
||||||
|
|
||||||
monkeypatch.setattr(server.session_manager, "get_session_by_route_key", lambda _rk: session)
|
monkeypatch.setattr(server.session_manager, "get_session_by_route_key", lambda _rk: mock_session)
|
||||||
|
|
||||||
response = await server._handle_screenshot(request)
|
response = await server._handle_screenshot(request)
|
||||||
assert response.content_type == "image/svg+xml"
|
assert response.content_type == "image/svg+xml"
|
||||||
@@ -198,18 +194,14 @@ class TestLocalServerHelpers:
|
|||||||
assert out.err == ""
|
assert out.err == ""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_screenshot_creates_session_for_known_slug(self, server, monkeypatch):
|
async def test_screenshot_creates_session_for_known_slug(
|
||||||
request = MagicMock()
|
self, server, monkeypatch, screen_buffer_factory, mock_session, mock_request
|
||||||
|
):
|
||||||
|
request = mock_request
|
||||||
request.query = {"route_key": "known"}
|
request.query = {"route_key": "known"}
|
||||||
|
|
||||||
# Mock screen state
|
screen_buffer = screen_buffer_factory(["world", ""])
|
||||||
screen_buffer = [
|
mock_session.get_screen_state = AsyncMock(return_value=(80, 2, screen_buffer, True))
|
||||||
[{"data": c, "fg": "default", "bg": "default", "bold": False, "italics": False, "underscore": False, "reverse": False} for c in "world" + " " * 75],
|
|
||||||
[{"data": " ", "fg": "default", "bg": "default", "bold": False, "italics": False, "underscore": False, "reverse": False}] * 80,
|
|
||||||
]
|
|
||||||
session = MagicMock()
|
|
||||||
session.get_screen_has_changes = AsyncMock(return_value=False)
|
|
||||||
session.get_screen_state = AsyncMock(return_value=(80, 2, screen_buffer, True))
|
|
||||||
|
|
||||||
# Pretend app exists for slug "known"
|
# Pretend app exists for slug "known"
|
||||||
server.session_manager.apps_by_slug["known"] = App(
|
server.session_manager.apps_by_slug["known"] = App(
|
||||||
@@ -230,7 +222,7 @@ class TestLocalServerHelpers:
|
|||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
server.session_manager,
|
server.session_manager,
|
||||||
"get_session_by_route_key",
|
"get_session_by_route_key",
|
||||||
lambda _rk: session if created else None,
|
lambda _rk: mock_session if created else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await server._handle_screenshot(request)
|
response = await server._handle_screenshot(request)
|
||||||
@@ -241,8 +233,8 @@ class TestLocalServerHelpers:
|
|||||||
assert created["called"][1:] == (132, 45)
|
assert created["called"][1:] == (132, 45)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_screenshot_returns_404_for_unknown_slug(self, server, monkeypatch):
|
async def test_screenshot_returns_404_for_unknown_slug(self, server, monkeypatch, mock_request):
|
||||||
request = MagicMock()
|
request = mock_request
|
||||||
request.query = {"route_key": "unknown"}
|
request.query = {"route_key": "unknown"}
|
||||||
|
|
||||||
monkeypatch.setattr(server.session_manager, "get_session_by_route_key", lambda _rk: None)
|
monkeypatch.setattr(server.session_manager, "get_session_by_route_key", lambda _rk: None)
|
||||||
@@ -282,89 +274,72 @@ class TestLocalServerHelpers:
|
|||||||
port=8080,
|
port=8080,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_ws_url_basic(self, server):
|
@pytest.mark.parametrize(
|
||||||
"""Test basic WebSocket URL generation."""
|
("headers", "secure", "expected_parts", "forbidden_parts"),
|
||||||
request = MagicMock()
|
[
|
||||||
request.headers = {"Host": "localhost:8080"}
|
({"Host": "localhost:8080"}, False, ("ws://", "test-route"), ()),
|
||||||
request.secure = False
|
({"Host": "localhost:8080", "X-Forwarded-Proto": "https"}, True, ("wss://",), ()),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"Host": "localhost:8080",
|
||||||
|
"X-Forwarded-Host": "example.com",
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
},
|
||||||
|
False,
|
||||||
|
("example.com",),
|
||||||
|
(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"Host": "localhost:8080",
|
||||||
|
"X-Forwarded-Host": "example.com",
|
||||||
|
"X-Forwarded-Port": "9000",
|
||||||
|
},
|
||||||
|
False,
|
||||||
|
("9000",),
|
||||||
|
(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"Host": "example.com",
|
||||||
|
"X-Forwarded-Port": "443",
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
("wss://example.com/ws/test-route",),
|
||||||
|
(":443",),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_ws_url_variants(self, server, mock_request, headers, secure, expected_parts, forbidden_parts):
|
||||||
|
"""Test WebSocket URL generation variants."""
|
||||||
|
request = mock_request
|
||||||
|
request.headers = headers
|
||||||
|
request.secure = secure
|
||||||
|
|
||||||
url = server._get_ws_url_from_request(request, "test-route")
|
url = server._get_ws_url_from_request(request, "test-route")
|
||||||
assert "ws://" in url
|
for part in expected_parts:
|
||||||
assert "test-route" in url
|
assert part in url
|
||||||
|
for part in forbidden_parts:
|
||||||
def test_get_ws_url_secure(self, server):
|
assert part not in url
|
||||||
"""Test secure WebSocket URL generation."""
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers = {"Host": "localhost:8080", "X-Forwarded-Proto": "https"}
|
|
||||||
request.secure = True
|
|
||||||
|
|
||||||
url = server._get_ws_url_from_request(request, "test-route")
|
|
||||||
assert "wss://" in url
|
|
||||||
|
|
||||||
def test_get_ws_url_forwarded_host(self, server):
|
|
||||||
"""Test WebSocket URL with forwarded host."""
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers = {
|
|
||||||
"Host": "localhost:8080",
|
|
||||||
"X-Forwarded-Host": "example.com",
|
|
||||||
"X-Forwarded-Proto": "https",
|
|
||||||
}
|
|
||||||
request.secure = False
|
|
||||||
|
|
||||||
url = server._get_ws_url_from_request(request, "test-route")
|
|
||||||
assert "example.com" in url
|
|
||||||
|
|
||||||
def test_get_ws_url_forwarded_port(self, server):
|
|
||||||
"""Test WebSocket URL with forwarded port."""
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers = {
|
|
||||||
"Host": "localhost:8080",
|
|
||||||
"X-Forwarded-Host": "example.com",
|
|
||||||
"X-Forwarded-Port": "9000",
|
|
||||||
}
|
|
||||||
request.secure = False
|
|
||||||
|
|
||||||
url = server._get_ws_url_from_request(request, "test-route")
|
|
||||||
assert "9000" in url
|
|
||||||
|
|
||||||
def test_get_ws_url_standard_port_omitted(self, server):
|
|
||||||
"""Test that standard ports are omitted from URL."""
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers = {
|
|
||||||
"Host": "example.com",
|
|
||||||
"X-Forwarded-Port": "443",
|
|
||||||
"X-Forwarded-Proto": "https",
|
|
||||||
}
|
|
||||||
request.secure = True
|
|
||||||
|
|
||||||
url = server._get_ws_url_from_request(request, "test-route")
|
|
||||||
# Port 443 should be omitted
|
|
||||||
assert ":443" not in url or url == "wss://example.com/ws/test-route"
|
|
||||||
|
|
||||||
|
|
||||||
class TestWebSocketProtocol:
|
class TestWebSocketProtocol:
|
||||||
"""Tests for WebSocket protocol message formats."""
|
"""Tests for WebSocket protocol message formats."""
|
||||||
|
|
||||||
def test_stdin_message_format(self):
|
@pytest.mark.parametrize(
|
||||||
"""Test stdin message format."""
|
("msg_type", "payload", "assertions"),
|
||||||
msg = ["stdin", "hello"]
|
[
|
||||||
assert msg[0] == "stdin"
|
("stdin", "hello", lambda msg: msg[1] == "hello"),
|
||||||
assert msg[1] == "hello"
|
("resize", {"width": 80, "height": 24}, lambda msg: msg[1]["width"] == 80),
|
||||||
|
("ping", "1234567890", lambda msg: msg[0] == "ping"),
|
||||||
def test_resize_message_format(self):
|
],
|
||||||
"""Test resize message format."""
|
)
|
||||||
msg = ["resize", {"width": 80, "height": 24}]
|
def test_message_format(self, msg_type, payload, assertions):
|
||||||
assert msg[0] == "resize"
|
"""Test message formats."""
|
||||||
assert msg[1]["width"] == 80
|
msg = [msg_type, payload]
|
||||||
assert msg[1]["height"] == 24
|
assert msg[0] == msg_type
|
||||||
|
assert assertions(msg)
|
||||||
def test_ping_pong_format(self):
|
|
||||||
"""Test ping/pong message format."""
|
|
||||||
ping = ["ping", "1234567890"]
|
|
||||||
pong = ["pong", "1234567890"]
|
|
||||||
assert ping[0] == "ping"
|
|
||||||
assert pong[0] == "pong"
|
|
||||||
assert ping[1] == pong[1]
|
|
||||||
|
|
||||||
|
|
||||||
class TestLocalServerMoreCoverage:
|
class TestLocalServerMoreCoverage:
|
||||||
@@ -390,20 +365,20 @@ class TestLocalServerMoreCoverage:
|
|||||||
await server_with_no_apps.handle_session_data("rk", b"data")
|
await server_with_no_apps.handle_session_data("rk", b"data")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_session_data_sends_bytes(self, server_with_no_apps):
|
|
||||||
ws = MagicMock()
|
|
||||||
ws.send_bytes = AsyncMock()
|
|
||||||
server_with_no_apps._websocket_connections["rk"] = ws
|
|
||||||
await server_with_no_apps.handle_session_data("rk", b"data")
|
|
||||||
ws.send_bytes.assert_awaited_once_with(b"data")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_binary_message_sends_bytes(self, server_with_no_apps):
|
@pytest.mark.parametrize(
|
||||||
|
("handler", "payload"),
|
||||||
|
[
|
||||||
|
("handle_session_data", b"data"),
|
||||||
|
("handle_binary_message", b"bin"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_handle_message_sends_bytes(self, server_with_no_apps, handler, payload):
|
||||||
ws = MagicMock()
|
ws = MagicMock()
|
||||||
ws.send_bytes = AsyncMock()
|
ws.send_bytes = AsyncMock()
|
||||||
server_with_no_apps._websocket_connections["rk"] = ws
|
server_with_no_apps._websocket_connections["rk"] = ws
|
||||||
await server_with_no_apps.handle_binary_message("rk", b"bin")
|
await getattr(server_with_no_apps, handler)("rk", payload)
|
||||||
ws.send_bytes.assert_awaited_once_with(b"bin")
|
ws.send_bytes.assert_awaited_once_with(payload)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_session_close_ends_session_and_closes_ws(self, server_with_no_apps, monkeypatch):
|
async def test_handle_session_close_ends_session_and_closes_ws(self, server_with_no_apps, monkeypatch):
|
||||||
@@ -624,16 +599,14 @@ class TestLocalServerMoreCoverage:
|
|||||||
assert created is True
|
assert created is True
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_screenshot_uses_cached_when_no_changes(self, server_with_no_apps, monkeypatch):
|
async def test_handle_screenshot_uses_cached_when_no_changes(
|
||||||
session = MagicMock()
|
self, server_with_no_apps, monkeypatch, mock_request, mock_session
|
||||||
session.get_screen_has_changes = AsyncMock(return_value=False)
|
):
|
||||||
session.get_screen_state = AsyncMock(return_value=(80, 24, [], False))
|
mock_session.get_screen_state = AsyncMock(return_value=(80, 24, [], False))
|
||||||
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session)
|
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: mock_session)
|
||||||
|
|
||||||
request = MagicMock()
|
request = mock_request
|
||||||
request.query = {"route_key": "rk"}
|
request.query = {"route_key": "rk"}
|
||||||
request.headers = {}
|
|
||||||
request.secure = False
|
|
||||||
|
|
||||||
# Seed cache
|
# Seed cache
|
||||||
server_with_no_apps._screenshot_cache["rk"] = (0.0, "<svg></svg>")
|
server_with_no_apps._screenshot_cache["rk"] = (0.0, "<svg></svg>")
|
||||||
@@ -641,31 +614,26 @@ class TestLocalServerMoreCoverage:
|
|||||||
|
|
||||||
resp = await server_with_no_apps._handle_screenshot(request)
|
resp = await server_with_no_apps._handle_screenshot(request)
|
||||||
assert resp.text == "<svg></svg>"
|
assert resp.text == "<svg></svg>"
|
||||||
session.get_screen_state.assert_not_awaited()
|
mock_session.get_screen_state.assert_not_awaited()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_screenshot_uses_screen_state(self, server_with_no_apps, monkeypatch):
|
async def test_handle_screenshot_uses_screen_state(
|
||||||
|
self, server_with_no_apps, monkeypatch, screen_buffer_factory, mock_request, mock_session
|
||||||
|
):
|
||||||
"""Test that screenshot uses get_screen_state for rendering."""
|
"""Test that screenshot uses get_screen_state for rendering."""
|
||||||
request = MagicMock()
|
request = mock_request
|
||||||
request.query = {"route_key": "rk"}
|
request.query = {"route_key": "rk"}
|
||||||
request.headers = {}
|
|
||||||
|
|
||||||
# Mock screen state
|
screen_buffer = screen_buffer_factory(["line1", "line2"])
|
||||||
screen_buffer = [
|
mock_session.get_screen_state = AsyncMock(return_value=(80, 2, screen_buffer, True))
|
||||||
[{"data": c, "fg": "default", "bg": "default", "bold": False, "italics": False, "underscore": False, "reverse": False} for c in "line1" + " " * 75],
|
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: mock_session)
|
||||||
[{"data": c, "fg": "default", "bg": "default", "bold": False, "italics": False, "underscore": False, "reverse": False} for c in "line2" + " " * 75],
|
|
||||||
]
|
|
||||||
session = MagicMock()
|
|
||||||
session.get_screen_has_changes = AsyncMock(return_value=False)
|
|
||||||
session.get_screen_state = AsyncMock(return_value=(80, 2, screen_buffer, True))
|
|
||||||
monkeypatch.setattr(server_with_no_apps.session_manager, "get_session_by_route_key", lambda _rk: session)
|
|
||||||
|
|
||||||
server_with_no_apps._route_last_activity["rk"] = 1.0
|
server_with_no_apps._route_last_activity["rk"] = 1.0
|
||||||
|
|
||||||
resp = await server_with_no_apps._handle_screenshot(request)
|
resp = await server_with_no_apps._handle_screenshot(request)
|
||||||
assert resp.content_type == "image/svg+xml"
|
assert resp.content_type == "image/svg+xml"
|
||||||
assert "<svg" in resp.text
|
assert "<svg" in resp.text
|
||||||
session.get_screen_state.assert_awaited_once()
|
mock_session.get_screen_state.assert_awaited_once()
|
||||||
|
|
||||||
def test_notify_activity_pushes_to_subscribers(self, server_with_no_apps):
|
def test_notify_activity_pushes_to_subscribers(self, server_with_no_apps):
|
||||||
"""Test that activity notifications are pushed to SSE subscribers."""
|
"""Test that activity notifications are pushed to SSE subscribers."""
|
||||||
|
|||||||
+58
-91
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from textual_webterm.svg_exporter import (
|
from textual_webterm.svg_exporter import (
|
||||||
ANSI_COLORS,
|
ANSI_COLORS,
|
||||||
DEFAULT_BG,
|
DEFAULT_BG,
|
||||||
@@ -16,102 +18,67 @@ from textual_webterm.svg_exporter import (
|
|||||||
class TestColorToHex:
|
class TestColorToHex:
|
||||||
"""Tests for _color_to_hex function."""
|
"""Tests for _color_to_hex function."""
|
||||||
|
|
||||||
def test_default_foreground(self) -> None:
|
@pytest.mark.parametrize(
|
||||||
"""Default color returns DEFAULT_FG for foreground."""
|
("color", "is_foreground", "expected"),
|
||||||
assert _color_to_hex("default", is_foreground=True) == DEFAULT_FG
|
[
|
||||||
|
("default", True, DEFAULT_FG),
|
||||||
def test_default_background(self) -> None:
|
("default", False, DEFAULT_BG),
|
||||||
"""Default color returns DEFAULT_BG for background."""
|
("#ff0000", True, "#ff0000"),
|
||||||
assert _color_to_hex("default", is_foreground=False) == DEFAULT_BG
|
("#123456", True, "#123456"),
|
||||||
|
("#AABBCC", True, "#AABBCC"),
|
||||||
def test_hex_color_passthrough(self) -> None:
|
("ff0000", True, "#ff0000"),
|
||||||
"""Hex colors pass through unchanged."""
|
("123456", True, "#123456"),
|
||||||
assert _color_to_hex("#ff0000") == "#ff0000"
|
("AABBCC", True, "#AABBCC"),
|
||||||
assert _color_to_hex("#123456") == "#123456"
|
("ff8700", True, "#ff8700"),
|
||||||
assert _color_to_hex("#AABBCC") == "#AABBCC"
|
("red", True, ANSI_COLORS["red"]),
|
||||||
|
("green", True, ANSI_COLORS["green"]),
|
||||||
def test_hex_color_without_hash(self) -> None:
|
("blue", True, ANSI_COLORS["blue"]),
|
||||||
"""Hex colors without # prefix (pyte's 256-color/truecolor) get # added."""
|
("white", True, ANSI_COLORS["white"]),
|
||||||
assert _color_to_hex("ff0000") == "#ff0000"
|
("black", True, ANSI_COLORS["black"]),
|
||||||
assert _color_to_hex("123456") == "#123456"
|
("brightred", True, ANSI_COLORS["brightred"]),
|
||||||
assert _color_to_hex("AABBCC") == "#AABBCC"
|
("brightgreen", True, ANSI_COLORS["brightgreen"]),
|
||||||
assert _color_to_hex("ff8700") == "#ff8700" # Common 256-color orange
|
("brightblue", True, ANSI_COLORS["brightblue"]),
|
||||||
|
("RED", True, ANSI_COLORS["red"]),
|
||||||
def test_named_colors(self) -> None:
|
("Green", True, ANSI_COLORS["green"]),
|
||||||
"""Named ANSI colors map correctly."""
|
("BRIGHTBLUE", True, ANSI_COLORS["brightblue"]),
|
||||||
assert _color_to_hex("red") == ANSI_COLORS["red"]
|
("unknowncolor", True, DEFAULT_FG),
|
||||||
assert _color_to_hex("green") == ANSI_COLORS["green"]
|
("unknowncolor", False, DEFAULT_BG),
|
||||||
assert _color_to_hex("blue") == ANSI_COLORS["blue"]
|
("rgb(255,0,0)", True, DEFAULT_FG),
|
||||||
assert _color_to_hex("white") == ANSI_COLORS["white"]
|
("rgb(0,255,0)", False, DEFAULT_BG),
|
||||||
assert _color_to_hex("black") == ANSI_COLORS["black"]
|
("gray", True, ANSI_COLORS["gray"]),
|
||||||
|
("grey", True, ANSI_COLORS["grey"]),
|
||||||
def test_bright_colors(self) -> None:
|
("lightgray", True, ANSI_COLORS["lightgray"]),
|
||||||
"""Bright color variants map correctly."""
|
("lightgrey", True, ANSI_COLORS["lightgrey"]),
|
||||||
assert _color_to_hex("brightred") == ANSI_COLORS["brightred"]
|
],
|
||||||
assert _color_to_hex("brightgreen") == ANSI_COLORS["brightgreen"]
|
)
|
||||||
assert _color_to_hex("brightblue") == ANSI_COLORS["brightblue"]
|
def test_color_to_hex(self, color: str, is_foreground: bool, expected: str) -> None:
|
||||||
|
"""Color conversion covers named/hex/default cases."""
|
||||||
def test_case_insensitive(self) -> None:
|
assert _color_to_hex(color, is_foreground=is_foreground) == expected
|
||||||
"""Color names are case-insensitive."""
|
|
||||||
assert _color_to_hex("RED") == ANSI_COLORS["red"]
|
|
||||||
assert _color_to_hex("Green") == ANSI_COLORS["green"]
|
|
||||||
assert _color_to_hex("BRIGHTBLUE") == ANSI_COLORS["brightblue"]
|
|
||||||
|
|
||||||
def test_unknown_color_returns_default(self) -> None:
|
|
||||||
"""Unknown color names return default."""
|
|
||||||
assert _color_to_hex("unknowncolor", is_foreground=True) == DEFAULT_FG
|
|
||||||
assert _color_to_hex("unknowncolor", is_foreground=False) == DEFAULT_BG
|
|
||||||
|
|
||||||
def test_rgb_format_returns_default(self) -> None:
|
|
||||||
"""RGB format falls back to default (not commonly used in terminals)."""
|
|
||||||
assert _color_to_hex("rgb(255,0,0)", is_foreground=True) == DEFAULT_FG
|
|
||||||
assert _color_to_hex("rgb(0,255,0)", is_foreground=False) == DEFAULT_BG
|
|
||||||
|
|
||||||
def test_gray_aliases(self) -> None:
|
|
||||||
"""Gray/grey aliases work."""
|
|
||||||
assert _color_to_hex("gray") == ANSI_COLORS["gray"]
|
|
||||||
assert _color_to_hex("grey") == ANSI_COLORS["grey"]
|
|
||||||
assert _color_to_hex("lightgray") == ANSI_COLORS["lightgray"]
|
|
||||||
assert _color_to_hex("lightgrey") == ANSI_COLORS["lightgrey"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestEscapeXml:
|
class TestEscapeXml:
|
||||||
"""Tests for XML escaping."""
|
"""Tests for XML escaping."""
|
||||||
|
|
||||||
def test_no_special_chars(self) -> None:
|
@pytest.mark.parametrize(
|
||||||
"""Plain text passes through unchanged."""
|
("input_str", "expected"),
|
||||||
assert _escape_xml("hello world") == "hello world"
|
[
|
||||||
|
("hello world", "hello world"),
|
||||||
def test_less_than(self) -> None:
|
("<", "<"),
|
||||||
"""Less than is escaped."""
|
("a < b", "a < b"),
|
||||||
assert _escape_xml("<") == "<"
|
(">", ">"),
|
||||||
assert _escape_xml("a < b") == "a < b"
|
("a > b", "a > b"),
|
||||||
|
("&", "&"),
|
||||||
def test_greater_than(self) -> None:
|
("a & b", "a & b"),
|
||||||
"""Greater than is escaped."""
|
('"', """),
|
||||||
assert _escape_xml(">") == ">"
|
("'", "'"),
|
||||||
assert _escape_xml("a > b") == "a > b"
|
('<script>"alert"</script>', "<script>"alert"</script>"),
|
||||||
|
("你好世界", "你好世界"),
|
||||||
def test_ampersand(self) -> None:
|
("🎉🚀", "🎉🚀"),
|
||||||
"""Ampersand is escaped."""
|
],
|
||||||
assert _escape_xml("&") == "&"
|
)
|
||||||
assert _escape_xml("a & b") == "a & b"
|
def test_escape_xml(self, input_str: str, expected: str) -> None:
|
||||||
|
"""Escape XML special chars and preserve unicode."""
|
||||||
def test_quotes(self) -> None:
|
assert _escape_xml(input_str) == expected
|
||||||
"""Quotes are escaped."""
|
|
||||||
assert _escape_xml('"') == """
|
|
||||||
assert _escape_xml("'") == "'"
|
|
||||||
|
|
||||||
def test_mixed_special_chars(self) -> None:
|
|
||||||
"""Multiple special chars are all escaped."""
|
|
||||||
assert _escape_xml('<script>"alert"</script>') == (
|
|
||||||
"<script>"alert"</script>"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_unicode_preserved(self) -> None:
|
|
||||||
"""Unicode characters are preserved."""
|
|
||||||
assert _escape_xml("你好世界") == "你好世界"
|
|
||||||
assert _escape_xml("🎉🚀") == "🎉🚀"
|
|
||||||
|
|
||||||
|
|
||||||
class TestRenderTerminalSvg:
|
class TestRenderTerminalSvg:
|
||||||
|
|||||||
Reference in New Issue
Block a user