From e54a8fe0433252808d0a60f6a08a43c9f5a42f3b Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Wed, 31 Jul 2024 14:33:46 +0200 Subject: [PATCH] fix: prevent to ingest local files (by default) (#2010) * feat: prevent to local ingestion (by default) and add white-list * docs: add local ingestion warning * docs: add missing comment * fix: update exception error * fix: black --- fern/docs/pages/manual/ingestion.mdx | 8 +++ private_gpt/settings/settings.py | 25 ++++++++ scripts/ingest_folder.py | 26 ++++++++- settings.yaml | 3 + tests/server/ingest/test_local_ingest.py | 74 ++++++++++++++++++++++++ 5 files changed, 133 insertions(+), 3 deletions(-) create mode 100644 tests/server/ingest/test_local_ingest.py diff --git a/fern/docs/pages/manual/ingestion.mdx b/fern/docs/pages/manual/ingestion.mdx index 9c7032a01..9b10b9548 100644 --- a/fern/docs/pages/manual/ingestion.mdx +++ b/fern/docs/pages/manual/ingestion.mdx @@ -8,6 +8,14 @@ The ingestion of documents can be done in different ways: ## Bulk Local Ingestion +You will need to activate `data.local_ingestion.enabled` in your setting file to use this feature. Additionally, +it is probably a good idea to set `data.local_ingestion.allow_ingest_from` to specify which folders are allowed to be ingested. + + +Be careful enabling this feature in a production environment, as it can be a security risk, as it allows users to +ingest any local file with permissions. + + When you are running PrivateGPT in a fully local setup, you can ingest a complete folder for convenience (containing pdf, text files, etc.) and optionally watch changes on it with the command: diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py index 7ca6e05ba..8ed7a5a8c 100644 --- a/private_gpt/settings/settings.py +++ b/private_gpt/settings/settings.py @@ -59,6 +59,27 @@ class AuthSettings(BaseModel): ) +class IngestionSettings(BaseModel): + """Ingestion configuration. + + This configuration is used to control the ingestion of data into the system + using non-server methods. This is useful for local development and testing; + or to ingest in bulk from a folder. + + Please note that this configuration is not secure and should be used in + a controlled environment only (setting right permissions, etc.). + """ + + enabled: bool = Field( + description="Flag indicating if local ingestion is enabled or not.", + default=False, + ) + allow_ingest_from: list[str] = Field( + description="A list of folders that should be permitted to make ingest requests.", + default=[], + ) + + class ServerSettings(BaseModel): env_name: str = Field( description="Name of the environment (prod, staging, local...)" @@ -74,6 +95,10 @@ class ServerSettings(BaseModel): class DataSettings(BaseModel): + local_ingestion: IngestionSettings = Field( + description="Ingestion configuration", + default_factory=lambda: IngestionSettings(allow_ingest_from=["*"]), + ) local_data_folder: str = Field( description="Path to local storage." "It will be treated as an absolute path if it starts with /" diff --git a/scripts/ingest_folder.py b/scripts/ingest_folder.py index ccda87cc5..b0b3a35b3 100755 --- a/scripts/ingest_folder.py +++ b/scripts/ingest_folder.py @@ -7,12 +7,13 @@ from private_gpt.di import global_injector from private_gpt.server.ingest.ingest_service import IngestService from private_gpt.server.ingest.ingest_watcher import IngestWatcher +from private_gpt.settings.settings import Settings logger = logging.getLogger(__name__) class LocalIngestWorker: - def __init__(self, ingest_service: IngestService) -> None: + def __init__(self, ingest_service: IngestService, setting: Settings) -> None: self.ingest_service = ingest_service self.total_documents = 0 @@ -20,6 +21,24 @@ def __init__(self, ingest_service: IngestService) -> None: self._files_under_root_folder: list[Path] = [] + self.is_local_ingestion_enabled = setting.data.local_ingestion.enabled + self.allowed_local_folders = setting.data.local_ingestion.allow_ingest_from + + def _validate_folder(self, folder_path: Path) -> None: + if not self.is_local_ingestion_enabled: + raise ValueError( + "Local ingestion is disabled." + "You can enable it in settings `ingestion.enabled`" + ) + + # Allow all folders if wildcard is present + if "*" in self.allowed_local_folders: + return + + for allowed_folder in self.allowed_local_folders: + if not folder_path.is_relative_to(allowed_folder): + raise ValueError(f"Folder {folder_path} is not allowed for ingestion") + def _find_all_files_in_folder(self, root_path: Path, ignored: list[str]) -> None: """Search all files under the root folder recursively. @@ -28,6 +47,7 @@ def _find_all_files_in_folder(self, root_path: Path, ignored: list[str]) -> None for file_path in root_path.iterdir(): if file_path.is_file() and file_path.name not in ignored: self.total_documents += 1 + self._validate_folder(file_path) self._files_under_root_folder.append(file_path) elif file_path.is_dir() and file_path.name not in ignored: self._find_all_files_in_folder(file_path, ignored) @@ -92,13 +112,13 @@ def _do_ingest_one(self, changed_path: Path) -> None: logger.addHandler(file_handler) if __name__ == "__main__": - root_path = Path(args.folder) if not root_path.exists(): raise ValueError(f"Path {args.folder} does not exist") ingest_service = global_injector.get(IngestService) - worker = LocalIngestWorker(ingest_service) + settings = global_injector.get(Settings) + worker = LocalIngestWorker(ingest_service, settings) worker.ingest_folder(root_path, args.ignored) if args.ignored: diff --git a/settings.yaml b/settings.yaml index cd977a0ae..6f936ddf3 100644 --- a/settings.yaml +++ b/settings.yaml @@ -17,6 +17,9 @@ server: secret: "Basic c2VjcmV0OmtleQ==" data: + local_ingestion: + enabled: ${LOCAL_INGESTION_ENABLED:false} + allow_ingest_from: ["*"] local_data_folder: local_data/private_gpt ui: diff --git a/tests/server/ingest/test_local_ingest.py b/tests/server/ingest/test_local_ingest.py new file mode 100644 index 000000000..860000efe --- /dev/null +++ b/tests/server/ingest/test_local_ingest.py @@ -0,0 +1,74 @@ +import os +import subprocess +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + + +@pytest.fixture() +def file_path() -> str: + return "test.txt" + + +def create_test_file(file_path: str) -> None: + with open(file_path, "w") as f: + f.write("test") + + +def clear_log_file(log_file_path: str) -> None: + if Path(log_file_path).exists(): + os.remove(log_file_path) + + +def read_log_file(log_file_path: str) -> str: + with open(log_file_path) as f: + return f.read() + + +def init_structure(folder: str, file_path: str) -> None: + clear_log_file(file_path) + os.makedirs(folder, exist_ok=True) + create_test_file(f"{folder}/${file_path}") + + +def test_ingest_one_file_in_allowed_folder( + file_path: str, test_client: TestClient +) -> None: + allowed_folder = "local_data/tests/allowed_folder" + init_structure(allowed_folder, file_path) + + test_env = os.environ.copy() + test_env["PGPT_PROFILES"] = "test" + test_env["LOCAL_INGESTION_ENABLED"] = "True" + + result = subprocess.run( + ["python", "scripts/ingest_folder.py", allowed_folder], + capture_output=True, + text=True, + env=test_env, + ) + + assert result.returncode == 0, f"Script failed with error: {result.stderr}" + response_after = test_client.get("/v1/ingest/list") + + count_ingest_after = len(response_after.json()["data"]) + assert count_ingest_after > 0, "No documents were ingested" + + +def test_ingest_disabled(file_path: str) -> None: + allowed_folder = "local_data/tests/allowed_folder" + init_structure(allowed_folder, file_path) + + test_env = os.environ.copy() + test_env["PGPT_PROFILES"] = "test" + test_env["LOCAL_INGESTION_ENABLED"] = "False" + + result = subprocess.run( + ["python", "scripts/ingest_folder.py", allowed_folder], + capture_output=True, + text=True, + env=test_env, + ) + + assert result.returncode != 0, f"Script failed with error: {result.stderr}"