# Source file: takopi/tests/test_telegram_voice.py (Python, 353 lines, 9.8 KiB)

from __future__ import annotations
import pytest
from takopi.telegram.api_models import (
Chat,
ChatMember,
File,
ForumTopic,
Message,
Update,
User,
)
from takopi.telegram.client import BotClient
from takopi.telegram.types import TelegramIncomingMessage, TelegramVoice
from takopi.telegram.voice import VOICE_TRANSCRIPTION_DISABLED_HINT, transcribe_voice
class _Bot(BotClient):
    """Test double for ``BotClient`` with canned file lookup and download results.

    ``get_file`` and ``download_file`` hand back the values supplied at
    construction time, ``get_updates`` returns an empty list and ``close``
    does nothing.  Every other method raises ``AssertionError`` so that a test
    fails loudly if the code under test calls it.
    """

    def __init__(self, *, file_info: File | None, audio: bytes | None) -> None:
        """Store the values later returned by ``get_file`` / ``download_file``."""
        self._audio = audio
        self._file_info = file_info

    async def close(self) -> None:
        """Do nothing; this double holds no resources."""
        return None

    async def get_updates(
        self,
        offset: int | None,
        timeout_s: int = 50,
        allowed_updates: list[str] | None = None,
    ) -> list[Update] | None:
        """Return an empty list of updates regardless of the arguments."""
        del offset, timeout_s, allowed_updates
        return []

    async def get_file(self, file_id: str) -> File | None:
        """Return the canned file record, ignoring ``file_id``."""
        del file_id
        return self._file_info

    async def download_file(self, file_path: str) -> bytes | None:
        """Return the canned audio payload, ignoring ``file_path``."""
        del file_path
        return self._audio

    async def send_message(
        self,
        chat_id: int,
        text: str,
        reply_to_message_id: int | None = None,
        disable_notification: bool | None = False,
        message_thread_id: int | None = None,
        entities: list[dict] | None = None,
        parse_mode: str | None = None,
        reply_markup: dict | None = None,
        *,
        replace_message_id: int | None = None,
    ) -> Message | None:
        """Fail the test: sending a message is not expected here."""
        del chat_id, text, reply_to_message_id, disable_notification
        del message_thread_id, entities, parse_mode, reply_markup
        del replace_message_id
        raise AssertionError("send_message should not be called")

    async def send_document(
        self,
        chat_id: int,
        filename: str,
        content: bytes,
        reply_to_message_id: int | None = None,
        message_thread_id: int | None = None,
        disable_notification: bool | None = False,
        caption: str | None = None,
    ) -> Message | None:
        """Fail the test: sending a document is not expected here."""
        del chat_id, filename, content, reply_to_message_id
        del message_thread_id, disable_notification, caption
        raise AssertionError("send_document should not be called")

    async def edit_message_text(
        self,
        chat_id: int,
        message_id: int,
        text: str,
        entities: list[dict] | None = None,
        parse_mode: str | None = None,
        reply_markup: dict | None = None,
        *,
        wait: bool = True,
    ) -> Message | None:
        """Fail the test: editing a message is not expected here."""
        del chat_id, message_id, text, entities, parse_mode, reply_markup, wait
        raise AssertionError("edit_message_text should not be called")

    async def delete_message(self, chat_id: int, message_id: int) -> bool:
        """Fail the test: deleting a message is not expected here."""
        del chat_id, message_id
        raise AssertionError("delete_message should not be called")

    async def set_my_commands(
        self,
        commands: list[dict],
        *,
        scope: dict | None = None,
        language_code: str | None = None,
    ) -> bool:
        """Fail the test: registering commands is not expected here."""
        del commands, scope, language_code
        raise AssertionError("set_my_commands should not be called")

    async def get_me(self) -> User | None:
        """Fail the test: looking up the bot user is not expected here."""
        raise AssertionError("get_me should not be called")

    async def answer_callback_query(
        self,
        callback_query_id: str,
        text: str | None = None,
        show_alert: bool | None = None,
    ) -> bool:
        """Fail the test: answering a callback query is not expected here."""
        del callback_query_id, text, show_alert
        raise AssertionError("answer_callback_query should not be called")

    async def get_chat(self, chat_id: int) -> Chat | None:
        """Fail the test: fetching a chat is not expected here."""
        del chat_id
        raise AssertionError("get_chat should not be called")

    async def get_chat_member(self, chat_id: int, user_id: int) -> ChatMember | None:
        """Fail the test: fetching a chat member is not expected here."""
        del chat_id, user_id
        raise AssertionError("get_chat_member should not be called")

    async def create_forum_topic(self, chat_id: int, name: str) -> ForumTopic | None:
        """Fail the test: creating a forum topic is not expected here."""
        del chat_id, name
        raise AssertionError("create_forum_topic should not be called")

    async def edit_forum_topic(
        self, chat_id: int, message_thread_id: int, name: str
    ) -> bool:
        """Fail the test: editing a forum topic is not expected here."""
        del chat_id, message_thread_id, name
        raise AssertionError("edit_forum_topic should not be called")
def _voice_message(*, file_size: int = 123) -> TelegramIncomingMessage:
    """Build an incoming Telegram message that carries a one-second OGG voice note.

    Args:
        file_size: Value stored as the voice attachment's ``file_size``.

    Returns:
        A ``TelegramIncomingMessage`` with empty text, no reply context, and a
        ``TelegramVoice`` whose ``file_id`` is ``"voice-id"``.
    """
    return TelegramIncomingMessage(
        transport="telegram",
        chat_id=1,
        message_id=1,
        text="",
        reply_to_message_id=None,
        reply_to_text=None,
        sender_id=1,
        voice=TelegramVoice(
            file_id="voice-id",
            mime_type="audio/ogg",
            file_size=file_size,
            duration=1,
            raw={},
        ),
        raw={},
    )
class _Transcriber:
def __init__(self, *, result: str | None = None, error: Exception | None = None):
self.calls: list[tuple[str, bytes]] = []
self._result = result
self._error = error
async def transcribe(self, *, model: str, audio_bytes: bytes) -> str:
self.calls.append((model, audio_bytes))
if self._error is not None:
raise self._error
assert self._result is not None
return self._result
@pytest.mark.anyio
async def test_transcribe_voice_disabled_replies_with_hint() -> None:
    """With ``enabled=False`` the hint is replied and the transcriber is never used."""
    sent: list[str] = []

    async def capture_reply(**kwargs) -> None:
        sent.append(kwargs["text"])

    stub = _Transcriber(result="should-not-run")
    bot = _Bot(file_info=None, audio=None)

    outcome = await transcribe_voice(
        bot=bot,
        msg=_voice_message(),
        enabled=False,
        model="whisper-1",
        reply=capture_reply,
        transcriber=stub,
    )

    assert outcome is None
    assert sent[-1] == VOICE_TRANSCRIPTION_DISABLED_HINT
    assert stub.calls == []
@pytest.mark.anyio
async def test_transcribe_voice_handles_missing_file() -> None:
    """When ``get_file`` yields ``None`` the user is told the fetch failed."""
    sent: list[str] = []

    async def capture_reply(**kwargs) -> None:
        sent.append(kwargs["text"])

    # No file record and no audio: the lookup step comes back empty.
    fake_bot = _Bot(file_info=None, audio=None)

    outcome = await transcribe_voice(
        bot=fake_bot,
        msg=_voice_message(),
        enabled=True,
        model="whisper-1",
        reply=capture_reply,
    )

    assert outcome is None
    assert sent[-1] == "failed to fetch voice file."
@pytest.mark.anyio
async def test_transcribe_voice_handles_missing_download() -> None:
    """When the download yields ``None`` the user is told the download failed."""
    sent: list[str] = []

    async def capture_reply(**kwargs) -> None:
        sent.append(kwargs["text"])

    # The file record exists, but fetching its bytes comes back empty.
    fake_bot = _Bot(file_info=File(file_path="voice.ogg"), audio=None)

    outcome = await transcribe_voice(
        bot=fake_bot,
        msg=_voice_message(),
        enabled=True,
        model="whisper-1",
        reply=capture_reply,
    )

    assert outcome is None
    assert sent[-1] == "failed to download voice file."
@pytest.mark.anyio
async def test_transcribe_voice_rejects_large_voice_without_downloading() -> None:
    """A declared size above ``max_bytes`` is rejected before any file access."""
    sent: list[str] = []

    async def capture_reply(**kwargs) -> None:
        sent.append(kwargs["text"])

    class _NoFetchBot(_Bot):
        """Bot double whose file lookup and download both fail the test."""

        async def get_file(self, file_id: str) -> File | None:  # type: ignore[override]
            del file_id
            raise AssertionError("get_file should not be called")

        async def download_file(self, file_path: str) -> bytes | None:  # type: ignore[override]
            del file_path
            raise AssertionError("download_file should not be called")

    strict_bot = _NoFetchBot(file_info=None, audio=None)

    outcome = await transcribe_voice(
        bot=strict_bot,
        msg=_voice_message(file_size=10_000),
        enabled=True,
        model="whisper-1",
        max_bytes=100,
        reply=capture_reply,
    )

    assert outcome is None
    assert sent[-1] == "voice message is too large to transcribe."
@pytest.mark.anyio
async def test_transcribe_voice_rejects_large_download() -> None:
    """A download larger than ``max_bytes`` is rejected even if the declared size is small."""
    sent: list[str] = []

    async def capture_reply(**kwargs) -> None:
        sent.append(kwargs["text"])

    stub = _Transcriber(result="should-not-run")
    # Declared size (10) is under the limit, but the payload is 200 bytes.
    oversized_audio = b"x" * 200
    fake_bot = _Bot(file_info=File(file_path="voice.ogg"), audio=oversized_audio)

    outcome = await transcribe_voice(
        bot=fake_bot,
        msg=_voice_message(file_size=10),
        enabled=True,
        model="whisper-1",
        max_bytes=100,
        reply=capture_reply,
        transcriber=stub,
    )

    assert outcome is None
    assert sent[-1] == "voice message is too large to transcribe."
    assert stub.calls == []
@pytest.mark.anyio
async def test_transcribe_voice_handles_transcriber_error() -> None:
    """An exception from the transcriber is replied as its message; the result is ``None``."""
    sent: list[str] = []

    async def capture_reply(**kwargs) -> None:
        sent.append(kwargs["text"])

    failing_stub = _Transcriber(error=RuntimeError("boom"))
    fake_bot = _Bot(file_info=File(file_path="voice.ogg"), audio=b"ok")

    outcome = await transcribe_voice(
        bot=fake_bot,
        msg=_voice_message(file_size=2),
        enabled=True,
        model="whisper-1",
        reply=capture_reply,
        transcriber=failing_stub,
    )

    assert outcome is None
    assert sent[-1] == "boom"
    # The transcriber must have been reached at least once.
    assert failing_stub.calls
@pytest.mark.anyio
async def test_transcribe_voice_success() -> None:
    """On success the transcript is returned and no reply is sent."""
    sent: list[str] = []

    async def capture_reply(**kwargs) -> None:
        sent.append(kwargs["text"])

    stub = _Transcriber(result="transcribed")
    fake_bot = _Bot(file_info=File(file_path="voice.ogg"), audio=b"ok")

    outcome = await transcribe_voice(
        bot=fake_bot,
        msg=_voice_message(file_size=2),
        enabled=True,
        model="whisper-1",
        reply=capture_reply,
        transcriber=stub,
    )

    assert outcome == "transcribed"
    assert sent == []
    # The transcriber must have been reached at least once.
    assert stub.calls