Fixes zenml-io#2257: Unifies materializer temporary directory/file creation onto the tempfile module
akesterson committed Mar 24, 2024
1 parent 8a4dd58 commit 972fd23
Showing 17 changed files with 208 additions and 244 deletions.
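
All of the materializer diffs below apply the same change: manual temporary-directory creation and cleanup is replaced by the `tempfile.TemporaryDirectory` context manager, so cleanup also happens when an exception interrupts the method. A minimal before/after sketch of that pattern (the `obj.dump()` call and the `artifact.bin` filename are illustrative placeholders, not taken from any single file in this commit):

```python
import os
import tempfile

from zenml.io import fileio
from zenml.utils import io_utils


# Before: cleanup only runs if nothing raises between mkdtemp() and rmtree().
def save_before(self, obj) -> None:
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    temp_file = os.path.join(temp_dir, "artifact.bin")
    obj.dump(temp_file)  # placeholder serialization call
    io_utils.copy_dir(temp_dir, self.uri)
    fileio.rmtree(temp_dir)


# After: the context manager removes the directory even on exceptions.
def save_after(self, obj) -> None:
    with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
        temp_file = os.path.join(temp_dir, "artifact.bin")
        obj.dump(temp_file)  # placeholder serialization call
        io_utils.copy_dir(temp_dir, self.uri)
```
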
@@ -248,6 +248,22 @@ The `load()` and `save()` methods define the serialization and deserialization o…

You will need to override these methods according to how you plan to serialize your objects. E.g., if you have custom PyTorch classes as `ASSOCIATED_TYPES`, then you might want to use `torch.save()` and `torch.load()` here.
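
For instance, a minimal sketch of what that might look like for a custom PyTorch class (`MyTorchModule` and the `model.pt` filename are illustrative placeholders; `fileio.open` is ZenML's filesystem-aware file handle, and the materializer attributes are assumed from ZenML's `BaseMaterializer` interface):

```python
import os
from typing import Type

import torch

from zenml.enums import ArtifactType
from zenml.io import fileio
from zenml.materializers.base_materializer import BaseMaterializer


class MyTorchModule(torch.nn.Module):
    """Placeholder for a user-defined PyTorch class."""


class MyTorchMaterializer(BaseMaterializer):
    ASSOCIATED_TYPES = (MyTorchModule,)
    ASSOCIATED_ARTIFACT_TYPE = ArtifactType.MODEL

    def load(self, data_type: Type[MyTorchModule]) -> MyTorchModule:
        """Reads the module back from the artifact store with torch.load()."""
        with fileio.open(os.path.join(self.uri, "model.pt"), "rb") as f:
            return torch.load(f)

    def save(self, model: MyTorchModule) -> None:
        """Serializes the module into the artifact store with torch.save()."""
        with fileio.open(os.path.join(self.uri, "model.pt"), "wb") as f:
            torch.save(model, f)
```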

It is very common to use temporary files and directories as an intermediate step in a materializer's `load()` or `save()` method. Materializers using this pattern must take care to clean up after themselves even when an unexpected exception is raised. The established way to do this is with the [`tempfile`](https://docs.python.org/3/library/tempfile.html) module's context managers, which are a simple and efficient way to create and clean up temporary files and directories. `tempfile` is part of Python's cross-platform standard library. For example (the serialization calls in the snippet below mirror ZenML's Hugging Face TF model materializer):

```python
import os
import tempfile

from zenml.utils import io_utils


def save(self, model: TFPreTrainedModel) -> None:
    """Writes a model to the artifact store.

    Args:
        model: The TF model to write.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Do whatever you need with the model inside the temporary directory,
        # e.g. serialize it to disk ...
        model.save_pretrained(temp_dir)
        # ... and copy the result into the artifact store.
        io_utils.copy_dir(temp_dir, self.uri)
        # When your code reaches this point, whether through normal flow or an
        # unhandled exception, the entire temporary directory is cleaned up.
```
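
The same module covers single files. A minimal sketch using `tempfile.NamedTemporaryFile` (the `MyConfig` type, its `to_json()` method, and the `data.json` filename are placeholders; `fileio.copy` is the same ZenML helper used by the materializers below):

```python
import os
import tempfile

from zenml.io import fileio


def save(self, obj: "MyConfig") -> None:
    """Writes a single serialized file to the artifact store."""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as f:
        f.write(obj.to_json())
        f.flush()
        # Copy the temporary file into the artifact store. (On Windows the
        # file cannot be reopened while it is still open; prefer a
        # TemporaryDirectory there.)
        fileio.copy(f.name, os.path.join(self.uri, "data.json"))
    # The temporary file is removed as soon as the context manager exits.
```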

#### (Optional) How to Visualize the Artifact

Optionally, you can override the `save_visualizations()` method to automatically save visualizations for all artifacts saved by your materializer. These visualizations are then shown next to your artifacts in the dashboard:
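
For orientation, a minimal sketch of such an override (the signature and the mapping of path to `VisualizationType` are assumed from ZenML's `BaseMaterializer` interface; the filename is a placeholder, and the copy logic mirrors the Pillow materializer changed further down in this commit):

```python
import os
import tempfile
from typing import Dict

from PIL import Image

from zenml.enums import VisualizationType
from zenml.utils import io_utils


def save_visualizations(self, image: Image.Image) -> Dict[str, VisualizationType]:
    """Stores the image next to the artifact and registers it as a visualization."""
    visualization_path = os.path.join(self.uri, "visualization.png")
    with tempfile.TemporaryDirectory() as temp_dir:
        # Render to a local file first, then copy it into the artifact store.
        temp_path = os.path.join(temp_dir, "visualization.png")
        image.save(temp_path)
        io_utils.copy(temp_path, visualization_path, overwrite=True)
    return {visualization_path: VisualizationType.IMAGE}
```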
3 changes: 3 additions & 0 deletions src/zenml/artifact_stores/local_artifact_store.py
@@ -63,6 +63,9 @@ def ensure_path_local(cls, path: str) -> str:
Raises:
ArtifactStoreInterfaceError: If the given path is not a local path.
"""
# TODO: This check would be unnecessary if we prefixed local paths with
# file://, and it will not catch every remote scheme anyway, so local path
# handling should be refactored to use the file:// prefix.
remote_prefixes = ["gs://", "hdfs://", "s3://", "az://", "abfs://"]
if any(path.startswith(prefix) for prefix in remote_prefixes):
raise ArtifactStoreInterfaceError(
@@ -23,7 +23,6 @@

from zenml.enums import ArtifactType
from zenml.integrations.bentoml.constants import DEFAULT_BENTO_FILENAME
from zenml.io import fileio
from zenml.logger import get_logger
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.utils import io_utils
@@ -50,22 +49,21 @@ def load(self, data_type: Type[bento.Bento]) -> bento.Bento:
A bento.Bento object.
"""
# Create a temporary directory to store the model
temp_dir = tempfile.TemporaryDirectory()

# Copy from artifact store to temporary directory
io_utils.copy_dir(self.uri, temp_dir.name)

# Load the Bento from the temporary directory
imported_bento = Bento.import_from(
os.path.join(temp_dir.name, DEFAULT_BENTO_FILENAME)
)

# Try save the Bento to the local BentoML store
try:
_ = bentoml.get(imported_bento.tag)
except BentoMLException:
imported_bento.save()
return imported_bento
with tempfile.TemporaryDirectory() as temp_dir:
# Copy from artifact store to temporary directory
io_utils.copy_dir(self.uri, temp_dir)

# Load the Bento from the temporary directory
imported_bento = Bento.import_from(
os.path.join(temp_dir, DEFAULT_BENTO_FILENAME)
)

# Try to save the Bento to the local BentoML store
try:
_ = bentoml.get(imported_bento.tag)
except BentoMLException:
imported_bento.save()
return imported_bento

def save(self, bento: bento.Bento) -> None:
"""Write to artifact store.
@@ -74,17 +72,16 @@ def save(self, bento: bento.Bento) -> None:
bento: A bento.Bento object.
"""
# Create a temporary directory to store the model
temp_dir = tempfile.TemporaryDirectory(prefix="zenml-temp-")
temp_bento_path = os.path.join(temp_dir.name, DEFAULT_BENTO_FILENAME)

# save the image in a temporary directory
bentoml.export_bento(bento.tag, temp_bento_path)
with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
temp_bento_path = os.path.join(
temp_dir, DEFAULT_BENTO_FILENAME
)

# copy the saved image to the artifact store
io_utils.copy_dir(temp_dir.name, self.uri)
# save the image in a temporary directory
bentoml.export_bento(bento.tag, temp_bento_path)

# Remove the temporary directory
fileio.rmtree(temp_dir.name)
# copy the saved image to the artifact store
io_utils.copy_dir(temp_dir, self.uri)

def extract_metadata(
self, bento: bento.Bento
@@ -22,7 +22,6 @@
from datasets.dataset_dict import DatasetDict

from zenml.enums import ArtifactType
from zenml.io import fileio
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.materializers.pandas_materializer import PandasMaterializer
from zenml.utils import io_utils
@@ -65,16 +64,13 @@ def save(self, ds: Union[Dataset, DatasetDict]) -> None:
Args:
ds: The Dataset to write.
"""
temp_dir = TemporaryDirectory()
path = os.path.join(temp_dir.name, DEFAULT_DATASET_DIR)
try:
with TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, DEFAULT_DATASET_DIR)
ds.save_to_disk(path)
io_utils.copy_dir(
path,
os.path.join(self.uri, DEFAULT_DATASET_DIR),
)
finally:
fileio.rmtree(temp_dir.name)

def extract_metadata(
self, ds: Union[Dataset, DatasetDict]
@@ -46,30 +46,30 @@ def load(self, data_type: Type[PreTrainedModel]) -> PreTrainedModel:
Returns:
The model read from the specified dir.
"""
temp_dir = TemporaryDirectory()
io_utils.copy_dir(
os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), temp_dir.name
)

config = AutoConfig.from_pretrained(temp_dir.name)
architecture = config.architectures[0]
model_cls = getattr(
importlib.import_module("transformers"), architecture
)
return model_cls.from_pretrained(temp_dir.name)
with TemporaryDirectory() as temp_dir:
io_utils.copy_dir(
os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), temp_dir
)

config = AutoConfig.from_pretrained(temp_dir)
architecture = config.architectures[0]
model_cls = getattr(
importlib.import_module("transformers"), architecture
)
return model_cls.from_pretrained(temp_dir)

def save(self, model: PreTrainedModel) -> None:
"""Writes a Model to the specified dir.
Args:
model: The Torch Model to write.
"""
temp_dir = TemporaryDirectory()
model.save_pretrained(temp_dir.name)
io_utils.copy_dir(
temp_dir.name,
os.path.join(self.uri, DEFAULT_PT_MODEL_DIR),
)
with TemporaryDirectory() as temp_dir:
model.save_pretrained(temp_dir)
io_utils.copy_dir(
temp_dir,
os.path.join(self.uri, DEFAULT_PT_MODEL_DIR),
)

def extract_metadata(
self, model: PreTrainedModel
@@ -46,30 +46,30 @@ def load(self, data_type: Type[TFPreTrainedModel]) -> TFPreTrainedModel:
Returns:
The model read from the specified dir.
"""
temp_dir = TemporaryDirectory()
io_utils.copy_dir(
os.path.join(self.uri, DEFAULT_TF_MODEL_DIR), temp_dir.name
)

config = AutoConfig.from_pretrained(temp_dir.name)
architecture = "TF" + config.architectures[0]
model_cls = getattr(
importlib.import_module("transformers"), architecture
)
return model_cls.from_pretrained(temp_dir.name)
with TemporaryDirectory() as temp_dir:
io_utils.copy_dir(
os.path.join(self.uri, DEFAULT_TF_MODEL_DIR), temp_dir
)

config = AutoConfig.from_pretrained(temp_dir)
architecture = "TF" + config.architectures[0]
model_cls = getattr(
importlib.import_module("transformers"), architecture
)
return model_cls.from_pretrained(temp_dir)

def save(self, model: TFPreTrainedModel) -> None:
"""Writes a Model to the specified dir.
Args:
model: The TF Model to write.
"""
temp_dir = TemporaryDirectory()
model.save_pretrained(temp_dir.name)
io_utils.copy_dir(
temp_dir.name,
os.path.join(self.uri, DEFAULT_TF_MODEL_DIR),
)
with TemporaryDirectory() as temp_dir:
model.save_pretrained(temp_dir)
io_utils.copy_dir(
temp_dir,
os.path.join(self.uri, DEFAULT_TF_MODEL_DIR),
)

def extract_metadata(
self, model: TFPreTrainedModel
@@ -46,22 +46,22 @@ def load(self, data_type: Type[Any]) -> PreTrainedTokenizerBase:
Returns:
The tokenizer read from the specified dir.
"""
temp_dir = TemporaryDirectory()
io_utils.copy_dir(
os.path.join(self.uri, DEFAULT_TOKENIZER_DIR), temp_dir.name
)
with TemporaryDirectory() as temp_dir:
io_utils.copy_dir(
os.path.join(self.uri, DEFAULT_TOKENIZER_DIR), temp_dir
)

return AutoTokenizer.from_pretrained(temp_dir.name)
return AutoTokenizer.from_pretrained(temp_dir)

def save(self, tokenizer: Type[Any]) -> None:
"""Writes a Tokenizer to the specified dir.
Args:
tokenizer: The HFTokenizer to write.
"""
temp_dir = TemporaryDirectory()
tokenizer.save_pretrained(temp_dir.name)
io_utils.copy_dir(
temp_dir.name,
os.path.join(self.uri, DEFAULT_TOKENIZER_DIR),
)
with TemporaryDirectory() as temp_dir:
tokenizer.save_pretrained(temp_dir)
io_utils.copy_dir(
temp_dir,
os.path.join(self.uri, DEFAULT_TOKENIZER_DIR),
)
@@ -44,16 +44,14 @@ def load(self, data_type: Type[Any]) -> lgb.Booster:
filepath = os.path.join(self.uri, DEFAULT_FILENAME)

# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)

# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
booster = lgb.Booster(model_file=temp_file)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
booster = lgb.Booster(model_file=temp_file)

# Cleanup and return
fileio.rmtree(temp_dir)
return booster
return booster

def save(self, booster: lgb.Booster) -> None:
"""Creates a JSON serialization for a lightgbm Booster model.
@@ -47,15 +47,15 @@ def load(self, data_type: Type[Any]) -> lgb.Dataset:
filepath = os.path.join(self.uri, DEFAULT_FILENAME)

# Create a temporary folder
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)

# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
matrix = lgb.Dataset(temp_file, free_raw_data=False)
# Copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
matrix = lgb.Dataset(temp_file, free_raw_data=False)

# No clean up this time because matrix is lazy loaded
return matrix
# The matrix is lazily loaded, so it is not cleaned up here; note that the
# temporary file itself is removed once the context manager exits
return matrix

def save(self, matrix: lgb.Dataset) -> None:
"""Creates a binary serialization for a lightgbm.Dataset object.
@@ -66,13 +66,12 @@ def save(self, matrix: lgb.Dataset) -> None:
filepath = os.path.join(self.uri, DEFAULT_FILENAME)

# Make a temporary phantom artifact
temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
matrix.save_binary(temp_file)
with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
matrix.save_binary(temp_file)

# Copy it into artifact store
fileio.copy(temp_file, filepath)
fileio.rmtree(temp_dir)
# Copy it into artifact store
fileio.copy(temp_file, filepath)

def extract_metadata(
self, matrix: lgb.Dataset
@@ -58,34 +58,33 @@ def load(self, data_type: Type[Image.Image]) -> Image.Image:
filepath = [file for file in files if not fileio.isdir(file)][0]

# create a temporary folder
temp_dir = tempfile.TemporaryDirectory(prefix="zenml-temp-")
temp_file = os.path.join(
temp_dir.name,
f"{DEFAULT_IMAGE_FILENAME}{os.path.splitext(filepath)[1]}",
)
with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
temp_file = os.path.join(
temp_dir,
f"{DEFAULT_IMAGE_FILENAME}{os.path.splitext(filepath)[1]}",
)

# copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
return Image.open(temp_file)
# copy from artifact store to temporary file
fileio.copy(filepath, temp_file)
return Image.open(temp_file)

def save(self, image: Image.Image) -> None:
"""Write to artifact store.
Args:
image: An Image.Image object.
"""
temp_dir = tempfile.TemporaryDirectory(prefix="zenml-temp-")
file_extension = image.format or DEFAULT_IMAGE_EXTENSION
full_filename = f"{DEFAULT_IMAGE_FILENAME}.{file_extension}"
temp_image_path = os.path.join(temp_dir.name, full_filename)
with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
file_extension = image.format or DEFAULT_IMAGE_EXTENSION
full_filename = f"{DEFAULT_IMAGE_FILENAME}.{file_extension}"
temp_image_path = os.path.join(temp_dir, full_filename)

# save the image in a temporary directory
image.save(temp_image_path)
# save the image in a temporary directory
image.save(temp_image_path)

# copy the saved image to the artifact store
artifact_store_path = os.path.join(self.uri, full_filename)
io_utils.copy(temp_image_path, artifact_store_path, overwrite=True) # type: ignore[attr-defined]
temp_dir.cleanup()
# copy the saved image to the artifact store
artifact_store_path = os.path.join(self.uri, full_filename)
io_utils.copy(temp_image_path, artifact_store_path, overwrite=True) # type: ignore[attr-defined]

def save_visualizations(
self, image: Image.Image
