fix(telegram): harden file transfer handling (#84)

This commit is contained in:
banteg
2026-01-11 06:05:38 +04:00
committed by GitHub
parent 1c6b7d7b21
commit 2380b3e5e9
5 changed files with 176 additions and 30 deletions
+84 -5
View File
@@ -54,6 +54,7 @@ from .files import (
resolve_path_within_root, resolve_path_within_root,
split_command_args, split_command_args,
write_bytes_atomic, write_bytes_atomic,
ZipTooLargeError,
zip_directory, zip_directory,
) )
from .types import ( from .types import (
@@ -297,7 +298,9 @@ def _format_ctx_status(
return "\n".join(lines) return "\n".join(lines)
def _build_bot_commands(runtime: TransportRuntime) -> list[dict[str, str]]: def _build_bot_commands(
runtime: TransportRuntime, *, include_file: bool = True
) -> list[dict[str, str]]:
commands: list[dict[str, str]] = [] commands: list[dict[str, str]] = []
seen: set[str] = set() seen: set[str] = set()
for engine_id in runtime.available_engine_ids(): for engine_id in runtime.available_engine_ids():
@@ -345,7 +348,7 @@ def _build_bot_commands(runtime: TransportRuntime) -> list[dict[str, str]]:
description = backend.description or f"command: {cmd}" description = backend.description or f"command: {cmd}"
commands.append({"command": cmd, "description": description}) commands.append({"command": cmd, "description": description})
seen.add(cmd) seen.add(cmd)
if "file" not in seen: if include_file and "file" not in seen:
commands.append({"command": "file", "description": "upload or fetch files"}) commands.append({"command": "file", "description": "upload or fetch files"})
seen.add("file") seen.add("file")
if "cancel" not in seen: if "cancel" not in seen:
@@ -400,7 +403,7 @@ def _diff_keys(old: dict[str, object], new: dict[str, object]) -> list[str]:
async def _set_command_menu(cfg: TelegramBridgeConfig) -> None: async def _set_command_menu(cfg: TelegramBridgeConfig) -> None:
commands = _build_bot_commands(cfg.runtime) commands = _build_bot_commands(cfg.runtime, include_file=cfg.files.enabled)
if not commands: if not commands:
return return
try: try:
@@ -1205,7 +1208,49 @@ async def _handle_file_put(
if plan is None: if plan is None:
return return
rel_path: Path | None = None rel_path: Path | None = None
base_dir: Path | None = None
if plan.path_value: if plan.path_value:
if plan.path_value.endswith("/"):
base_dir = normalize_relative_path(plan.path_value)
if base_dir is None:
await _send_plain(
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
text="invalid upload path.",
thread_id=msg.thread_id,
)
return
deny_rule = deny_reason(base_dir, cfg.files.deny_globs)
if deny_rule is not None:
await _send_plain(
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
text=f"path denied by rule: {deny_rule}",
thread_id=msg.thread_id,
)
return
base_target = resolve_path_within_root(plan.run_root, base_dir)
if base_target is None:
await _send_plain(
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
text="upload path escapes the repo root.",
thread_id=msg.thread_id,
)
return
if base_target.exists() and not base_target.is_dir():
await _send_plain(
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
text="upload path is a file.",
thread_id=msg.thread_id,
)
return
else:
rel_path = normalize_relative_path(plan.path_value) rel_path = normalize_relative_path(plan.path_value)
if rel_path is None: if rel_path is None:
await _send_plain( await _send_plain(
@@ -1221,7 +1266,7 @@ async def _handle_file_put(
document=document, document=document,
run_root=plan.run_root, run_root=plan.run_root,
rel_path=rel_path, rel_path=rel_path,
base_dir=None, base_dir=base_dir,
force=plan.force, force=plan.force,
) )
if result.error is not None: if result.error is not None:
@@ -1575,9 +1620,34 @@ async def _handle_file_get(
payload: bytes payload: bytes
filename: str filename: str
if target.is_dir(): if target.is_dir():
payload = zip_directory(run_root, rel_path, cfg.files.deny_globs) try:
payload = zip_directory(
run_root,
rel_path,
cfg.files.deny_globs,
max_bytes=cfg.files.max_download_bytes,
)
except ZipTooLargeError:
await _send_plain(
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
text="file is too large to send.",
thread_id=msg.thread_id,
)
return
except OSError as exc:
await _send_plain(
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
text=f"failed to read directory: {exc}",
thread_id=msg.thread_id,
)
return
filename = f"{rel_path.name or 'archive'}.zip" filename = f"{rel_path.name or 'archive'}.zip"
else: else:
try:
size = target.stat().st_size size = target.stat().st_size
if size > cfg.files.max_download_bytes: if size > cfg.files.max_download_bytes:
await _send_plain( await _send_plain(
@@ -1589,6 +1659,15 @@ async def _handle_file_get(
) )
return return
payload = target.read_bytes() payload = target.read_bytes()
except OSError as exc:
await _send_plain(
cfg.exec_cfg.transport,
chat_id=msg.chat_id,
user_msg_id=msg.message_id,
text=f"failed to read file: {exc}",
thread_id=msg.thread_id,
)
return
filename = target.name filename = target.name
if len(payload) > cfg.files.max_download_bytes: if len(payload) > cfg.files.max_download_bytes:
await _send_plain( await _send_plain(
+7 -1
View File
@@ -101,6 +101,12 @@ def _parse_incoming_message(
text = caption text = caption
if text is None: if text is None:
text = "" text = ""
file_command = False
if isinstance(text, str):
stripped = text.lstrip()
if stripped.startswith("/"):
token = stripped.split(maxsplit=1)[0]
file_command = token.startswith("/file")
voice_payload: TelegramVoice | None = None voice_payload: TelegramVoice | None = None
voice = msg.get("voice") voice = msg.get("voice")
if isinstance(voice, dict): if isinstance(voice, dict):
@@ -159,7 +165,7 @@ def _parse_incoming_message(
best = item best = item
if best is not None: if best is not None:
document_payload = _parse_document_payload(best) document_payload = _parse_document_payload(best)
if document_payload is None: if document_payload is None and file_command:
sticker = msg.get("sticker") sticker = msg.get("sticker")
if isinstance(sticker, dict): if isinstance(sticker, dict):
document_payload = _parse_document_payload(sticker) document_payload = _parse_document_payload(sticker)
+20 -3
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import io import io
import os
import shlex import shlex
import tempfile import tempfile
import zipfile import zipfile
@@ -135,19 +136,35 @@ def write_bytes_atomic(path: Path, payload: bytes) -> None:
Path(temp_name).replace(path) Path(temp_name).replace(path)
class ZipTooLargeError(Exception):
pass
def zip_directory( def zip_directory(
root: Path, root: Path,
rel_path: Path, rel_path: Path,
deny_globs: Sequence[str], deny_globs: Sequence[str],
*,
max_bytes: int | None = None,
) -> bytes: ) -> bytes:
target = root / rel_path target = root / rel_path
buffer = io.BytesIO() buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as archive: with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as archive:
for item in sorted(target.rglob("*")): for dirpath, _, filenames in os.walk(target, followlinks=False):
if item.is_dir(): dir_path = Path(dirpath)
for filename in filenames:
item = dir_path / filename
if item.is_symlink():
continue
if not item.is_file():
continue continue
rel_item = rel_path / item.relative_to(target) rel_item = rel_path / item.relative_to(target)
if deny_reason(rel_item, deny_globs) is not None: if deny_reason(rel_item, deny_globs) is not None:
continue continue
archive.write(item, arcname=rel_item.as_posix()) archive.write(item, arcname=rel_item.as_posix())
return buffer.getvalue() if max_bytes is not None and buffer.tell() > max_bytes:
raise ZipTooLargeError()
payload = buffer.getvalue()
if max_bytes is not None and len(payload) > max_bytes:
raise ZipTooLargeError()
return payload
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
import io
import zipfile
from pathlib import Path
import pytest
from takopi.telegram.files import ZipTooLargeError, zip_directory
def test_zip_directory_skips_symlinks(tmp_path: Path) -> None:
root = tmp_path / "root"
root.mkdir()
target = root / "dir"
target.mkdir()
(target / "safe.txt").write_text("ok", encoding="utf-8")
outside = tmp_path / "secret.txt"
outside.write_text("secret", encoding="utf-8")
link_path = target / "leak.txt"
try:
link_path.symlink_to(outside)
except (OSError, NotImplementedError):
pytest.skip("symlinks not supported")
payload = zip_directory(root, Path("dir"), deny_globs=())
with zipfile.ZipFile(io.BytesIO(payload)) as archive:
names = set(archive.namelist())
assert "dir/safe.txt" in names
assert "dir/leak.txt" not in names
def test_zip_directory_limits_size(tmp_path: Path) -> None:
root = tmp_path / "root"
root.mkdir()
target = root / "dir"
target.mkdir()
(target / "data.bin").write_bytes(b"x" * 1024)
with pytest.raises(ZipTooLargeError):
zip_directory(root, Path("dir"), deny_globs=(), max_bytes=10)
+2 -1
View File
@@ -211,6 +211,7 @@ def test_parse_incoming_update_sticker_message() -> None:
"update_id": 1, "update_id": 1,
"message": { "message": {
"message_id": 10, "message_id": 10,
"caption": "/file put incoming/sticker.webp",
"chat": {"id": 123}, "chat": {"id": 123},
"sticker": { "sticker": {
"file_id": "sticker-id", "file_id": "sticker-id",
@@ -223,7 +224,7 @@ def test_parse_incoming_update_sticker_message() -> None:
msg = parse_incoming_update(update, chat_id=123) msg = parse_incoming_update(update, chat_id=123)
assert msg is not None assert msg is not None
assert isinstance(msg, TelegramIncomingMessage) assert isinstance(msg, TelegramIncomingMessage)
assert msg.text == "" assert msg.text == "/file put incoming/sticker.webp"
assert msg.document is not None assert msg.document is not None
assert msg.document.file_id == "sticker-id" assert msg.document.file_id == "sticker-id"
assert msg.document.file_name is None assert msg.document.file_name is None