Check blob hash #942

Open · wants to merge 10 commits into main
Changes from 8 commits
2 changes: 1 addition & 1 deletion docs/data_ingestion.md
@@ -27,7 +27,7 @@ If needed, you can modify the chunking algorithm in `scripts/prepdocslib/textspl

To upload more PDFs, put them in the data/ folder and run `./scripts/prepdocs.sh` or `./scripts/prepdocs.ps1`.

A [recent change](https://github.com/Azure-Samples/azure-search-openai-demo/pull/835) added checks to avoid re-uploading unchanged files. The prepdocs script writes an .md5 file containing the MD5 hash of each file it uploads; whenever the script is re-run, that stored hash is compared against the file's current hash, and the file is skipped if it hasn't changed.
The script checks for existing documents by comparing the hash of the local file to the hash of the corresponding file in blob storage. If the hashes differ, it uploads the new file to blob storage and updates the index; if they match, it skips the file.

## Removing documents

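As a rough sketch of the comparison flow described in the updated docs (the standalone helper `needs_upload` is illustrative, not part of this PR):

```python
import hashlib
import os

from prepdocslib.blobmanager import BlobManager


async def needs_upload(local_path: str, blob_manager: BlobManager) -> bool:
    """Return True when the local file is new or changed relative to blob storage."""
    with open(local_path, "rb") as f:
        local_hash = hashlib.md5(f.read()).hexdigest()
    # get_blob_hash (added in this PR) returns None when the blob does not exist
    blob_hash = await blob_manager.get_blob_hash(os.path.basename(local_path))
    return blob_hash is None or blob_hash != local_hash
```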
2 changes: 1 addition & 1 deletion scripts/prepdocs.ps1
@@ -37,5 +37,5 @@ $argumentList = "./scripts/prepdocs.py `"$cwd/data/*`" $adlsGen2StorageAccountAr
"--openaiservice `"$env:AZURE_OPENAI_SERVICE`" --openaikey `"$env:OPENAI_API_KEY`" " + `
"--openaiorg `"$env:OPENAI_ORGANIZATION`" --openaideployment `"$env:AZURE_OPENAI_EMB_DEPLOYMENT`" " + `
"--openaimodelname `"$env:AZURE_OPENAI_EMB_MODEL_NAME`" --index $env:AZURE_SEARCH_INDEX " + `
"--formrecognizerservice $env:AZURE_FORMRECOGNIZER_SERVICE --tenantid $env:AZURE_TENANT_ID -v"
"--formrecognizerservice $env:AZURE_FORMRECOGNIZER_SERVICE --tenantid $env:AZURE_TENANT_ID --blobstoragehashcheck -v"
Start-Process -FilePath $venvPythonPath -ArgumentList $argumentList -Wait -NoNewWindow
12 changes: 12 additions & 0 deletions scripts/prepdocs.py
@@ -15,6 +15,7 @@
from prepdocslib.filestrategy import DocumentAction, FileStrategy
from prepdocslib.listfilestrategy import (
    ADLSGen2ListFileStrategy,
    BlobListFileStrategy,
    ListFileStrategy,
    LocalListFileStrategy,
)
@@ -90,6 +91,11 @@ def setup_file_strategy(credential: AsyncTokenCredential, args: Any) -> FileStra
            credential=adls_gen2_creds,
            verbose=args.verbose,
        )
    elif args.blobstoragehashcheck:
        print("Using Blob Storage Account files to get hashes of existing files")
        list_file_strategy = BlobListFileStrategy(
            path_pattern=args.files, blob_manager=blob_manager, verbose=args.verbose
        )
    else:
        print(f"Using local files in {args.files}")
        list_file_strategy = LocalListFileStrategy(path_pattern=args.files, verbose=args.verbose)
@@ -140,6 +146,12 @@ async def main(strategy: Strategy, credential: AsyncTokenCredential, args: Any):
    parser.add_argument(
        "--datalakestorageaccount", required=False, help="Optional. Azure Data Lake Storage Gen2 Account name"
    )
    parser.add_argument(
        "--blobstoragehashcheck",
        action="store_true",
        required=False,
        help="Optional. Compare hashes of local files against the blobs already uploaded to Azure Blob Storage, rather than using local .md5 files.",
    )
    parser.add_argument(
        "--datalakefilesystem",
        required=False,
2 changes: 1 addition & 1 deletion scripts/prepdocs.sh
@@ -29,4 +29,4 @@ $aclArg --storageaccount "$AZURE_STORAGE_ACCOUNT" \
--openaimodelname "$AZURE_OPENAI_EMB_MODEL_NAME" --index "$AZURE_SEARCH_INDEX" \
--formrecognizerservice "$AZURE_FORMRECOGNIZER_SERVICE" --openaimodelname "$AZURE_OPENAI_EMB_MODEL_NAME" \
--tenantid "$AZURE_TENANT_ID" --openaihost "$OPENAI_HOST" \
--openaikey "$OPENAI_API_KEY" -v
--openaikey "$OPENAI_API_KEY" --blobstoragehashcheck -v
15 changes: 14 additions & 1 deletion scripts/prepdocslib/blobmanager.py
@@ -1,11 +1,12 @@
import binascii
import os
import re
from typing import Optional, Union

from azure.core.credentials_async import AsyncTokenCredential
from azure.storage.blob.aio import BlobServiceClient

from .listfilestrategy import File
from .file import File


class BlobManager:
@@ -60,6 +61,18 @@ async def remove_blob(self, path: Optional[str] = None):
print(f"\tRemoving blob {blob_path}")
await container_client.delete_blob(blob_path)

async def get_blob_hash(self, blob_name: str):
async with BlobServiceClient(
account_url=self.endpoint, credential=self.credential
) as service_client, service_client.get_blob_client(self.container, blob_name) as blob_client:
if not await blob_client.exists():
return None

blob_properties = await blob_client.get_blob_properties()
blob_hash_raw_bytes = blob_properties.content_settings.content_md5
hex_hash = binascii.hexlify(blob_hash_raw_bytes)
return hex_hash.decode("utf-8")

    @classmethod
    def sourcepage_from_file_page(cls, filename, page=0) -> str:
        if os.path.splitext(filename)[1].lower() == ".pdf":
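A hedged usage sketch of the new `get_blob_hash` helper (the blob name is made up; assumes an initialized `BlobManager`):

```python
from prepdocslib.blobmanager import BlobManager


async def print_stored_hash(blob_manager: BlobManager) -> None:
    blob_hash = await blob_manager.get_blob_hash("employee_handbook.pdf")
    if blob_hash is None:
        print("No blob found (or no Content-MD5 stored); treat the file as changed")
    else:
        print(f"Stored MD5: {blob_hash}")  # 32-character lowercase hex digest
```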
33 changes: 33 additions & 0 deletions scripts/prepdocslib/file.py
@@ -0,0 +1,33 @@
import base64
import os
import re
from typing import IO, Optional


class File:
    """
    Represents a file stored either locally or in a data lake storage account
    This file might contain access control information about which users or groups can access it
    """

    def __init__(self, content: IO, acls: Optional[dict[str, list]] = None):
        self.content = content
        self.acls = acls or {}

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def filename(self):
        return os.path.basename(self.content.name)

    def filename_to_id(self):
        filename_ascii = re.sub("[^0-9a-zA-Z_-]", "_", self.filename())
        filename_hash = base64.b16encode(self.filename().encode("utf-8")).decode("ascii")
        return f"file-{filename_ascii}-{filename_hash}"

    def close(self):
        if self.content:
            self.content.close()
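For illustration, a minimal sketch of the relocated `File` class used as a context manager (the path is hypothetical):

```python
from prepdocslib.file import File

with File(content=open("data/employee_handbook.pdf", "rb")) as f:
    print(f.filename())        # employee_handbook.pdf
    print(f.filename_to_id())  # file-employee_handbook_pdf-<base16 of the filename>
# the underlying stream is closed automatically on exit
```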
77 changes: 52 additions & 25 deletions scripts/prepdocslib/listfilestrategy.py
@@ -1,39 +1,17 @@
import base64
import hashlib
import os
import re
import tempfile
from abc import ABC
from glob import glob
from typing import IO, AsyncGenerator, Dict, List, Optional, Union
from typing import AsyncGenerator, Dict, List, Union

from azure.core.credentials_async import AsyncTokenCredential
from azure.storage.filedatalake.aio import (
    DataLakeServiceClient,
)


class File:
    """
    Represents a file stored either locally or in a data lake storage account
    This file might contain access control information about which users or groups can access it
    """

    def __init__(self, content: IO, acls: Optional[dict[str, list]] = None):
        self.content = content
        self.acls = acls or {}

    def filename(self):
        return os.path.basename(self.content.name)

    def filename_to_id(self):
        filename_ascii = re.sub("[^0-9a-zA-Z_-]", "_", self.filename())
        filename_hash = base64.b16encode(self.filename().encode("utf-8")).decode("ascii")
        return f"file-{filename_ascii}-{filename_hash}"

    def close(self):
        if self.content:
            self.content.close()
from .blobmanager import BlobManager
from .file import File


class ListFileStrategy(ABC):
@@ -103,6 +81,55 @@ def check_md5(self, path: str) -> bool:
        return False


class BlobListFileStrategy(ListFileStrategy):
    """
    Concrete strategy for listing local files and skipping any whose MD5 hash matches the corresponding blob in a blob storage account
    """

    def __init__(self, path_pattern: str, blob_manager: BlobManager, verbose: bool = False):
        self.path_pattern = path_pattern
        self.blob_manager = blob_manager
        self.verbose = verbose

    async def list_paths(self) -> AsyncGenerator[str, None]:
        async for p in self._list_paths(self.path_pattern):
            yield p

    async def _list_paths(self, path_pattern: str) -> AsyncGenerator[str, None]:
        for path in glob(path_pattern):
            if os.path.isdir(path):
                async for p in self._list_paths(f"{path}/*"):
                    yield p
            else:
                # Only yield files, not directories
                yield path

    async def list(self) -> AsyncGenerator[File, None]:
        async for path in self.list_paths():
            if not await self.check_md5(path):
                yield File(content=open(path, mode="rb"))

    async def check_md5(self, path: str) -> bool:
        # Skip the .md5 companion files themselves
        if path.endswith(".md5"):
            return True

        # Hash the local file
        with open(path, "rb") as file:
            existing_hash = hashlib.md5(file.read()).hexdigest()

        # Fetch the stored hash from blob storage
        blob_hash = await self.blob_manager.get_blob_hash(os.path.basename(path))

        # Compare the local and remote hashes
        if blob_hash and blob_hash.strip() == existing_hash.strip():
            if self.verbose:
                print(f"\tSkipping {path}, no changes detected.")
            return True

        return False


class ADLSGen2ListFileStrategy(ListFileStrategy):
    """
    Concrete strategy for listing files that are located in a data lake storage account
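A sketch of how the new strategy might be consumed, mirroring the `setup_file_strategy` branch in `prepdocs.py` (the wrapper function is illustrative):

```python
from prepdocslib.blobmanager import BlobManager
from prepdocslib.listfilestrategy import BlobListFileStrategy


async def ingest_changed_files(blob_manager: BlobManager) -> None:
    # list() yields only files whose MD5 differs from (or is missing in) blob storage
    strategy = BlobListFileStrategy(path_pattern="data/*", blob_manager=blob_manager, verbose=True)
    async for file in strategy.list():
        with file:
            ...  # parse, chunk, embed, and upload the changed file
```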
5 changes: 4 additions & 1 deletion scripts/prepdocslib/searchmanager.py
@@ -143,11 +143,14 @@ async def update_content(self, sections: List[Section]):
                    for i, document in enumerate(documents):
                        document["embedding"] = embeddings[i]

                # Remove any existing documents with the same sourcefile before uploading new ones;
                # this ensures we don't leave outdated documents in the index
                await self.remove_content(path=batch[0].content.filename())
                await search_client.upload_documents(documents)

    async def remove_content(self, path: Optional[str] = None):
        if self.search_info.verbose:
            print(f"Removing sections from '{path or '<all>'}' from search index '{self.search_info.index_name}'")
            print(f"\tRemoving sections from '{path or '<all>'}' from search index '{self.search_info.index_name}'")
        async with self.search_info.create_search_client() as search_client:
            while True:
                filter = None if path is None else f"sourcefile eq '{os.path.basename(path)}'"
29 changes: 29 additions & 0 deletions tests/test_blob_manager.py
@@ -120,6 +120,35 @@ async def mock_delete_blob(self, name, *args, **kwargs):
    await blob_manager.remove_blob()


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info.minor < 10, reason="requires Python 3.10 or higher")
async def test_get_blob_hash(monkeypatch, mock_env, blob_manager):
    blob_name = "test_blob"

    # Set up mocks used by get_blob_hash
    async def mock_exists(*args, **kwargs):
        return True

    monkeypatch.setattr("azure.storage.blob.aio.BlobClient.exists", mock_exists)

    async def mock_get_blob_properties(*args, **kwargs):
        class MockBlobProperties:
            class MockContentSettings:
                content_md5 = b"\x14\x0c\xdd\x8f\xd2\x74\x3d\x3b\xf1\xd1\xe2\x43\x01\xe4\xa0\x11"

            content_settings = MockContentSettings()

        return MockBlobProperties()

    monkeypatch.setattr("azure.storage.blob.aio.BlobClient.get_blob_properties", mock_get_blob_properties)

    blob_hash = await blob_manager.get_blob_hash(blob_name)

    # The expected hash is the hex encoding of the mock content MD5
    expected_hash = "140cdd8fd2743d3bf1d1e24301e4a011"
    assert blob_hash == expected_hash


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info.minor < 10, reason="requires Python 3.10 or higher")
async def test_create_container_upon_upload(monkeypatch, mock_env, blob_manager):