fix(telegram): harden file transfer handling (#84)
This commit is contained in:
@@ -54,6 +54,7 @@ from .files import (
|
||||
resolve_path_within_root,
|
||||
split_command_args,
|
||||
write_bytes_atomic,
|
||||
ZipTooLargeError,
|
||||
zip_directory,
|
||||
)
|
||||
from .types import (
|
||||
@@ -297,7 +298,9 @@ def _format_ctx_status(
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _build_bot_commands(runtime: TransportRuntime) -> list[dict[str, str]]:
|
||||
def _build_bot_commands(
|
||||
runtime: TransportRuntime, *, include_file: bool = True
|
||||
) -> list[dict[str, str]]:
|
||||
commands: list[dict[str, str]] = []
|
||||
seen: set[str] = set()
|
||||
for engine_id in runtime.available_engine_ids():
|
||||
@@ -345,7 +348,7 @@ def _build_bot_commands(runtime: TransportRuntime) -> list[dict[str, str]]:
|
||||
description = backend.description or f"command: {cmd}"
|
||||
commands.append({"command": cmd, "description": description})
|
||||
seen.add(cmd)
|
||||
if "file" not in seen:
|
||||
if include_file and "file" not in seen:
|
||||
commands.append({"command": "file", "description": "upload or fetch files"})
|
||||
seen.add("file")
|
||||
if "cancel" not in seen:
|
||||
@@ -400,7 +403,7 @@ def _diff_keys(old: dict[str, object], new: dict[str, object]) -> list[str]:
|
||||
|
||||
|
||||
async def _set_command_menu(cfg: TelegramBridgeConfig) -> None:
|
||||
commands = _build_bot_commands(cfg.runtime)
|
||||
commands = _build_bot_commands(cfg.runtime, include_file=cfg.files.enabled)
|
||||
if not commands:
|
||||
return
|
||||
try:
|
||||
@@ -1205,23 +1208,65 @@ async def _handle_file_put(
|
||||
if plan is None:
|
||||
return
|
||||
rel_path: Path | None = None
|
||||
base_dir: Path | None = None
|
||||
if plan.path_value:
|
||||
rel_path = normalize_relative_path(plan.path_value)
|
||||
if rel_path is None:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
text="invalid upload path.",
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
return
|
||||
if plan.path_value.endswith("/"):
|
||||
base_dir = normalize_relative_path(plan.path_value)
|
||||
if base_dir is None:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
text="invalid upload path.",
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
return
|
||||
deny_rule = deny_reason(base_dir, cfg.files.deny_globs)
|
||||
if deny_rule is not None:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
text=f"path denied by rule: {deny_rule}",
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
return
|
||||
base_target = resolve_path_within_root(plan.run_root, base_dir)
|
||||
if base_target is None:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
text="upload path escapes the repo root.",
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
return
|
||||
if base_target.exists() and not base_target.is_dir():
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
text="upload path is a file.",
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
return
|
||||
else:
|
||||
rel_path = normalize_relative_path(plan.path_value)
|
||||
if rel_path is None:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
text="invalid upload path.",
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
return
|
||||
result = await _save_document_payload(
|
||||
cfg,
|
||||
document=document,
|
||||
run_root=plan.run_root,
|
||||
rel_path=rel_path,
|
||||
base_dir=None,
|
||||
base_dir=base_dir,
|
||||
force=plan.force,
|
||||
)
|
||||
if result.error is not None:
|
||||
@@ -1575,11 +1620,14 @@ async def _handle_file_get(
|
||||
payload: bytes
|
||||
filename: str
|
||||
if target.is_dir():
|
||||
payload = zip_directory(run_root, rel_path, cfg.files.deny_globs)
|
||||
filename = f"{rel_path.name or 'archive'}.zip"
|
||||
else:
|
||||
size = target.stat().st_size
|
||||
if size > cfg.files.max_download_bytes:
|
||||
try:
|
||||
payload = zip_directory(
|
||||
run_root,
|
||||
rel_path,
|
||||
cfg.files.deny_globs,
|
||||
max_bytes=cfg.files.max_download_bytes,
|
||||
)
|
||||
except ZipTooLargeError:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
@@ -1588,7 +1636,38 @@ async def _handle_file_get(
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
return
|
||||
payload = target.read_bytes()
|
||||
except OSError as exc:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
text=f"failed to read directory: {exc}",
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
return
|
||||
filename = f"{rel_path.name or 'archive'}.zip"
|
||||
else:
|
||||
try:
|
||||
size = target.stat().st_size
|
||||
if size > cfg.files.max_download_bytes:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
text="file is too large to send.",
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
return
|
||||
payload = target.read_bytes()
|
||||
except OSError as exc:
|
||||
await _send_plain(
|
||||
cfg.exec_cfg.transport,
|
||||
chat_id=msg.chat_id,
|
||||
user_msg_id=msg.message_id,
|
||||
text=f"failed to read file: {exc}",
|
||||
thread_id=msg.thread_id,
|
||||
)
|
||||
return
|
||||
filename = target.name
|
||||
if len(payload) > cfg.files.max_download_bytes:
|
||||
await _send_plain(
|
||||
|
||||
@@ -101,6 +101,12 @@ def _parse_incoming_message(
|
||||
text = caption
|
||||
if text is None:
|
||||
text = ""
|
||||
file_command = False
|
||||
if isinstance(text, str):
|
||||
stripped = text.lstrip()
|
||||
if stripped.startswith("/"):
|
||||
token = stripped.split(maxsplit=1)[0]
|
||||
file_command = token.startswith("/file")
|
||||
voice_payload: TelegramVoice | None = None
|
||||
voice = msg.get("voice")
|
||||
if isinstance(voice, dict):
|
||||
@@ -159,7 +165,7 @@ def _parse_incoming_message(
|
||||
best = item
|
||||
if best is not None:
|
||||
document_payload = _parse_document_payload(best)
|
||||
if document_payload is None:
|
||||
if document_payload is None and file_command:
|
||||
sticker = msg.get("sticker")
|
||||
if isinstance(sticker, dict):
|
||||
document_payload = _parse_document_payload(sticker)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import shlex
|
||||
import tempfile
|
||||
import zipfile
|
||||
@@ -135,19 +136,35 @@ def write_bytes_atomic(path: Path, payload: bytes) -> None:
|
||||
Path(temp_name).replace(path)
|
||||
|
||||
|
||||
class ZipTooLargeError(Exception):
    """Raised by ``zip_directory`` when an archive exceeds its ``max_bytes`` cap."""
|
||||
|
||||
|
||||
def zip_directory(
    root: Path,
    rel_path: Path,
    deny_globs: Sequence[str],
    *,
    max_bytes: int | None = None,
) -> bytes:
    """Build an in-memory ZIP archive of the directory ``root / rel_path``.

    The tree is walked without following directory symlinks. Symbolic links,
    anything that is not a regular file, and every file for which
    ``deny_reason`` returns a rule are left out. Members are stored under
    their ``rel_path``-prefixed POSIX path and are added in sorted order so
    the member order is deterministic for a given tree.

    Args:
        root: Base directory that ``rel_path`` is joined onto.
        rel_path: Directory to archive, relative to ``root``.
        deny_globs: Patterns passed to ``deny_reason`` to exclude files.
        max_bytes: Optional cap on the archive size in bytes; ``None``
            disables the limit.

    Returns:
        The bytes of the finished ZIP archive.

    Raises:
        ZipTooLargeError: If the archive grows past ``max_bytes`` while it is
            being written, or exceeds it once finalised.
    """
    target = root / rel_path
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        # NOTE(review): os.walk ignores directory-listing errors by default,
        # so an unreadable subdirectory is skipped silently — confirm this
        # best-effort behaviour is intended.
        for dirpath, dirnames, filenames in os.walk(target, followlinks=False):
            # Sorting in place fixes the order in which os.walk descends.
            dirnames.sort()
            dir_path = Path(dirpath)
            for filename in sorted(filenames):
                item = dir_path / filename
                # Never store links: their targets may live outside the root.
                if item.is_symlink():
                    continue
                if not item.is_file():
                    continue
                rel_item = rel_path / item.relative_to(target)
                if deny_reason(rel_item, deny_globs) is not None:
                    continue
                archive.write(item, arcname=rel_item.as_posix())
                # Bail out early instead of compressing the rest of the tree.
                if max_bytes is not None and buffer.tell() > max_bytes:
                    raise ZipTooLargeError()
    # Closing the archive appends the central directory, so re-check the
    # final size against the cap.
    payload = buffer.getvalue()
    if max_bytes is not None and len(payload) > max_bytes:
        raise ZipTooLargeError()
    return payload
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from takopi.telegram.files import ZipTooLargeError, zip_directory
|
||||
|
||||
|
||||
def test_zip_directory_skips_symlinks(tmp_path: Path) -> None:
    """A symlink inside the archived directory must not end up in the ZIP."""
    repo_root = tmp_path / "root"
    archive_dir = repo_root / "dir"
    archive_dir.mkdir(parents=True)
    (archive_dir / "safe.txt").write_text("ok", encoding="utf-8")

    # The link target lives outside the archived root on purpose.
    secret = tmp_path / "secret.txt"
    secret.write_text("secret", encoding="utf-8")

    try:
        (archive_dir / "leak.txt").symlink_to(secret)
    except (OSError, NotImplementedError):
        pytest.skip("symlinks not supported")

    archive_bytes = zip_directory(repo_root, Path("dir"), deny_globs=())

    with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive:
        members = frozenset(archive.namelist())

    assert "dir/safe.txt" in members
    assert "dir/leak.txt" not in members
|
||||
|
||||
|
||||
def test_zip_directory_limits_size(tmp_path: Path) -> None:
    """A 1 KiB file archived under a 10-byte cap must raise ZipTooLargeError."""
    repo_root = tmp_path / "root"
    archive_dir = repo_root / "dir"
    archive_dir.mkdir(parents=True)
    (archive_dir / "data.bin").write_bytes(b"x" * 1024)

    with pytest.raises(ZipTooLargeError):
        zip_directory(repo_root, Path("dir"), deny_globs=(), max_bytes=10)
|
||||
@@ -211,6 +211,7 @@ def test_parse_incoming_update_sticker_message() -> None:
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"message_id": 10,
|
||||
"caption": "/file put incoming/sticker.webp",
|
||||
"chat": {"id": 123},
|
||||
"sticker": {
|
||||
"file_id": "sticker-id",
|
||||
@@ -223,7 +224,7 @@ def test_parse_incoming_update_sticker_message() -> None:
|
||||
msg = parse_incoming_update(update, chat_id=123)
|
||||
assert msg is not None
|
||||
assert isinstance(msg, TelegramIncomingMessage)
|
||||
assert msg.text == ""
|
||||
assert msg.text == "/file put incoming/sticker.webp"
|
||||
assert msg.document is not None
|
||||
assert msg.document.file_id == "sticker-id"
|
||||
assert msg.document.file_name is None
|
||||
|
||||
Reference in New Issue
Block a user