diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3682d81c7..423d94f4a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,15 @@ repos: rev: 'v0.0.254' hooks: - id: ruff - files: "^express/" + files: | + (?x)^( + express/.*| + examples/pipelines/hf_dataset_pipeline| + examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component| + examples/pipelines/finetune_stable_diffusion/components/image_filter_component| + examples/pipelines/finetune_stable_diffusion/components/embedding_component| + examples/pipelines/finetune_stable_diffusion/dataset_creation_pipeline.py| + )$ args: [--fix, --exit-non-zero-on-fix] @@ -24,4 +32,12 @@ repos: hooks: - id: black name: black - files: "^express/" + files: | + (?x)^( + express/.*| + examples/pipelines/hf_dataset_pipeline| + examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component| + examples/pipelines/finetune_stable_diffusion/components/image_filter_component| + examples/pipelines/finetune_stable_diffusion/components/embedding_component| + examples/pipelines/finetune_stable_diffusion/dataset_creation_pipeline.py| + )$ diff --git a/docs/README.md b/docs/README.md index 4fed20e4f..0305f732d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -99,4 +99,62 @@ After transforming the input data (see below), an **ExpressDatasetDraft** create ### 1.b) Transforms and Loaders The most common type of component in Express is an **ExpressTransformComponent**, which takes an `ExpressDataset` and an optional dict of arguments as input and returns an `ExpressDatasetDraft` of transformed output data. -However, at the start of a new pipeline, you won't yet have any express datasets to transform. Instead, an express pipeline can use an **ExpressLoaderComponent** as entry-point, which only takes the optional dict of arguments to construct an ExpressDatasetDraft. For example, the arguments could specify an external data location and how to interpret it, after which a loader job can create a first `ExpressDataset`. \ No newline at end of file +However, at the start of a new pipeline, you won't yet have any express datasets to transform. Instead, an express pipeline can use an **ExpressLoaderComponent** as entry-point, which only takes the optional dict of arguments to construct an ExpressDatasetDraft. For example, the arguments could specify an external data location and how to interpret it, after which a loader job can create a first `ExpressDataset`. + +## **Data Manifest: a common approach to simplify different steps throughout the pipeline** +In order to keep track of the different data sources, we opt for a manifest-centered approach where +a manifest is simply a JSON file that is passed and modified throughout the different steps of the pipeline. + +```json +{ + "dataset_id":"-", + "index":"", + "associated_data":{ + "dataset":{ + "namespace_1":"", + "...":"" + }, + "caption":{ + "namespace_1":"", + "...":"" + }, + "embedding":{ + "namespace_1":"", + "...":"" + } + }, + "metadata":{ + "branch":"", + "commit_hash":"", + "creation_date":"", + "run_id":"" + } +} +``` +A deeper dive into some of the notation: + +* **namespace:** the namespace is used to identify the different data sources. For example, you can give +your seed images a specific namespace (e.g. `seed`). Then, the images retrieved with clip-retrieval will +have a different namespace (e.g. `knn`, `centroid`). + +* **index**: the index denotes a unique index for each image, prefixed by its namespace (e.g. `seed_00010`). +It indexes all the data sources in `associated_data`.
+**Note**: the index keeps track of all the namespaces (e.g. [`seed_00010`, `centroid_0001`, ...]) + +* **dataset**: a set of parquet files for each namespace that contain relevant metadata +(image size, location, ...) as well as the index. + +* **caption**: a set of parquet files for each namespace that contain +image captions as well as the index. + +* **metadata**: helps keep track of the step that generated the manifest, the code version and the pipeline run id. + +The Express pipeline consists of multiple steps defined as **Express steps** that are repeated +throughout the pipeline. The manifest pattern offers the required flexibility to promote its reuse and avoid +duplication of data sources. For example: + +* **Data filtering** (e.g. filtering on image size): update the `index` to the filtered subset but retain the associated data. + +* **Data creation** (e.g. clip retrieval): add new indices to the `index` and add another source of data under `associated_data` with a new namespace. + +* **Data transformation** (e.g. image formatting): retain the indices but replace the dataset source in `dataset`. \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/README.md b/examples/pipelines/finetune_stable_diffusion/README.md index 3496a4c6e..3dd335618 100644 --- a/examples/pipelines/finetune_stable_diffusion/README.md +++ b/examples/pipelines/finetune_stable_diffusion/README.md @@ -101,63 +101,3 @@ bash build_images.sh This will build all the components located in the `components` folder, you could also opt for building a specific component by passing the `--build-dir` and passing the folder name of the component you want to build. - - -#TODO: move those docs elsewhere -## **Data Manifest: a common approach to simplify different steps throughout the pipeline** -In order to keep track of the different data sources, we opt for a manifest-centered approach where -a manifest is simply a JSON file that is passed and modified throughout the different steps of the pipeline. - -```json -{ - "dataset_id":"-", - "index":"", - "associated_data":{ - "dataset":{ - "namespace_1":"", - "...":"" - }, - "caption":{ - "namespace_1":"", - "...":"" - }, - "embedding":{ - "namespace_1":"", - "commit_hash":"", - "creation_date":"", - "run_id":"" - } -} -``` -Further deep dive on some notations: - -* **namespace:** the namespace is used to identify the different data sources. For example, you can give -your seed images a specific namespace (e.g. `seed`). Then, the images retrieved with clip-retrieval will -have different namespace (e.g. `knn`, `centroid`). - -* **index**: the index denotes a unique index for each images with the format (e.g. `seed_00010`). -It indexes all the data sources in `associated_data`. -**Note**: the index keeps track of all the namespace (e.g. [`seed_00010`,`centroid_0001`, ...]) - -* **dataset**: a set of parquet files for each namespace that contain relevant metadata -(image size, location, ...) as well as the index. - -* **caption**: a set of parquet files for each namespace that contain captions -image captions as well as the index. - -* **metadata**: Helps keep track of the step that generated that manifest, code version and pipeline run id. - -The Express pipeline consists of multiple steps defines as **Express steps** that are repeated -throughout the pipeline. The manifest pattern offers the required flexibility to promote its reuse and avoid -duplication of data sources. For example: - -* **Data filtering** (e.g.
filtering on image size): add new indices to the `index` but retain associated data. - -* **Data creation** (e.g. clip retrieval): add new indicies to the new `index` and another source of data under associated data with a new namespace. - -* **Data transformation** (e.g. image formatting): retain indices but replace dataset source in `dataset`. diff --git a/examples/pipelines/finetune_stable_diffusion/build_images.sh b/examples/pipelines/finetune_stable_diffusion/build_images.sh index 25c5a9cbc..0f922353c 100755 --- a/examples/pipelines/finetune_stable_diffusion/build_images.sh +++ b/examples/pipelines/finetune_stable_diffusion/build_images.sh @@ -49,6 +49,7 @@ for dir in $component_dir/*/; do --build-arg GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD) \ --build-arg BUILD_TIMESTAMP=$(date '+%F_%H:%M:%S') \ --label org.opencontainers.image.source=https://github.com/${namespace}/${repo} \ + --platform=linux/arm64 \ . docker push "$full_image_name" fi diff --git a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/Dockerfile b/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/Dockerfile deleted file mode 100644 index 9ea8412ff..000000000 --- a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/Dockerfile +++ /dev/null @@ -1,13 +0,0 @@ -FROM europe-west1-docker.pkg.dev/storied-landing-366912/storied-landing-366912-default-repository/mlpipelines/kubeflow/components/base_component:latest - -# Set the working directory to the source folder -WORKDIR /src - -# Install packages -COPY requirements.txt . -RUN pip3 install -r requirements.txt - -# Copy over src-files of the component -COPY src /src - -ENTRYPOINT ["python", "main.py"] \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/README.MD b/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/README.MD deleted file mode 100644 index 33b2d68d2..000000000 --- a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/README.MD +++ /dev/null @@ -1,22 +0,0 @@ -# dataset_loader_component - -### Description - -This is the first component of the pipeline, it loads in an image dataset from a specific -[Google Cloud Storage (GCS)](https://cloud.google.com/storage/docs) path and creates a -parquet files with different metadata (image path, format, size, ...). - -### **Inputs/Outputs** -The component accepts a `source-dataset-bucket` as input with a reference to the blob `source-dataset-blob` -where the image dataset is located. - -The component created the first data manifest that will be updated in subsequent components when new dataset sources -are added/filtered. - -See [`component.yaml`](component.yaml) for a more detailed description on all the input/output parameters. - -### **Practical considerations** - -* The main accepted formats of the input images are either `png`, `jpg` or `svg`. Eventually, all the formats will be converted to -`jgp` in a subsequent component since it is more suitable for machine learning inference and training. -* Make sure that your images are located in one directory on GCS and not spread across different directories. 
\ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/component.yaml b/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/component.yaml deleted file mode 100644 index 136cb9e6f..000000000 --- a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/component.yaml +++ /dev/null @@ -1,44 +0,0 @@ -name: dataset_loader_component -description: A component that takes an images dataset as input and initializes a data manifest - and writes it to an output file -inputs: - - name: run_id - description: The run id of the pipeline - type: String - - name: artifact_bucket - description: The GCS bucket used to store the artifact - type: String - - name: component_name - description: the name of the component (used to create gcs artefact path) - type: String - - name: project_id - description: The id of the gcp-project - type: String - - name: source_dataset_bucket - description: The GCS bucket containing the dataset to load - type: String - - name: source_dataset_blob - description: The GCS blob withing the specified bucket containing the dataset to load - type: String - - name: namespace - description: The dataset namespace (abbreviation for data source) - type: String - -outputs: - - name: data_manifest_path - description: Path to the local file containing the gcs path where the output has been stored - -implementation: - container: - image: europe-west1-docker.pkg.dev/storied-landing-366912/storied-landing-366912-default-repository/mlpipelines/kubeflow/components/dataset_loader_component:latest - command: [ - python3, main.py, - --run-id, { inputValue: run_id }, - --artifact-bucket, { inputValue: artifact_bucket }, - --component-name, { inputValue: component_name }, - --project-id, { inputValue: project_id }, - --source-dataset-bucket, { inputValue: source_dataset_bucket }, - --source-dataset-blob, { inputValue: source_dataset_blob }, - --namespace, { inputValue: namespace }, - --data-manifest-path, { outputPath: data_manifest_path }, - ] \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/requirements.txt b/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/requirements.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/src/main.py b/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/src/main.py deleted file mode 100644 index 78f2d3042..000000000 --- a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/src/main.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -This file is the entrypoint of the component. It will parse all arguments -and give them to the actual core of the component. 
-""" -import os -import argparse -import logging -import tempfile -from pathlib import Path -from datetime import datetime - -from google.cloud import storage - -# pylint: disable=import-error -from helpers import storage_helpers, parquet_helpers -from helpers.logger import get_logger -from helpers.manifest_helpers import DataManifest - - -def parse_args(): - """Parse component arguments""" - - parser = argparse.ArgumentParser() - parser.add_argument('--run-id', - type=str, - required=True, - help='The run id of the pipeline') - parser.add_argument('--artifact-bucket', - type=str, - required=True, - help='The GCS bucket used to store the artifacts') - parser.add_argument('--component-name', - type=str, - required=True, - help='The name of the component') - parser.add_argument('--project-id', - type=str, - required=True, - help='The id of the gcp-project') - parser.add_argument('--source-dataset-bucket', - type=str, - required=True, - help='The GCS bucket containing the dataset to load') - parser.add_argument('--source-dataset-blob', - type=str, - required=True, - help='The GCS blob withing the specified bucket containing the dataset' - ' to load') - parser.add_argument('--namespace', - type=str, - required=True, - help='The dataset namespace (abbreviation for data source)') - parser.add_argument('--data-manifest-path', - type=str, - required=True, - help='The data manifest output artifact') - return parser.parse_args() - - -# pylint: disable=too-many-locals, too-many-arguments -def dataset_loader_component(run_id: str, - artifact_bucket: str, - component_name: str, - project_id: str, - source_dataset_bucket: str, - source_dataset_blob: str, - namespace: str, - data_manifest_path: str) -> None: - """ - A component that takes an images dataset as input and initializes a data manifest - and writes it to an output file. 
- Args: - run_id (str): the run id of the pipeline - artifact_bucket (str): The GCS bucket used to store the artifacts - component_name (str): the name of the component (used to create gcs artefact path) - project_id (str): The id of the gcp-project - source_dataset_bucket (str): The GCS bucket containing the dataset to load - source_dataset_blob (str): The GCS blob withing the specified bucket containing the dataset - to load - namespace (str): The dataset namespace (abbreviation for data source) - data_manifest_path (str): the path to write the manifest - """ - logger = get_logger(name=__name__, level=logging.INFO) - logger.info('Started job...') - with tempfile.TemporaryDirectory() as tmp_dir: - logger.info('created temporary directory %s', tmp_dir) - - # Initialize storage client - storage_client = storage.Client(project=project_id) - - # Initialize GCS temporary storage paths - # Parse to match the directory convention from MINIO (-) - component_artifact_dir = run_id.rpartition('-')[0] - artifact_bucket_blob_path = f"custom_artifact/{component_artifact_dir}/{component_name}" - - logger.info("custom artifact will be uploaded to %s", - f'gs://{artifact_bucket}/{artifact_bucket_blob_path}') - - dataset_parquet_tmp_path = os.path.join(tmp_dir, 'dataset.parquet') - dataset_parquet_blob_path = os.path.join(artifact_bucket_blob_path, 'dataset.parquet') - index_parquet_tmp_path = os.path.join(tmp_dir, 'index.parquet') - index_parquet_blob_path = os.path.join(artifact_bucket_blob_path, 'index.parquet') - Path(dataset_parquet_tmp_path).parent.mkdir(parents=True, exist_ok=True) - Path(index_parquet_tmp_path).parent.mkdir(parents=True, exist_ok=True) - - logger.info('GCS and temporary paths initialized') - - # Write parquet index file - parquet_helpers.write_index_parquet(index_parquet_path=index_parquet_tmp_path, - data_iterable_producer=storage_helpers.get_blob_id, - storage_client=storage_client, - bucket_name=source_dataset_bucket, - prefix=source_dataset_blob, - id_prefix=namespace) - # Write parquet dataset - parquet_helpers.write_dataset_parquet \ - (dataset_parquet_path=dataset_parquet_tmp_path, - data_iterable_producer=storage_helpers.get_blob_metadata, - storage_client=storage_client, - bucket_name=source_dataset_bucket, - prefix=source_dataset_blob, - id_prefix=namespace) - - logger.info('Parquet manifest files updated') - - # Upload the parquet - storage_helpers.upload_file_to_bucket(storage_client=storage_client, - file_to_upload_path=dataset_parquet_tmp_path, - bucket_name=artifact_bucket, - blob_path=dataset_parquet_blob_path) - storage_helpers.upload_file_to_bucket(storage_client=storage_client, - file_to_upload_path=index_parquet_tmp_path, - bucket_name=artifact_bucket, - blob_path=index_parquet_blob_path) - logger.info('Parquet manifest files uploaded to GCS') - - data_manifest = DataManifest() - data_manifest.dataset_id = f"{run_id}_{component_name}" - data_manifest.index = f'gs://{artifact_bucket}/{index_parquet_blob_path}' - data_manifest.associated_data.dataset[namespace] = \ - f'gs://{artifact_bucket}/{dataset_parquet_blob_path}' - data_manifest.metadata.branch = "" # TODO: Fill from docker build env var - data_manifest.metadata.commit_hash = "" # TODO: Fill from docker build env var - data_manifest.metadata.creation_date = datetime.now().strftime("%d-%m-%Y_%H-%M-%S") - data_manifest.metadata.run_id = run_id - - logger.info('Manifest file created and updated') - - # Write manifest to outputPath - Path(data_manifest_path).parent.mkdir(parents=True, exist_ok=True) - 
Path(data_manifest_path).write_text(data_manifest.to_json()) - - logger.info('Manifest file written to %s', data_manifest_path) - - # Clean up temporary storage - logger.info('Files removed from temporary storage.') - logger.info('Job completed.') - - -if __name__ == '__main__': - args = parse_args() - dataset_loader_component(run_id=args.run_id, - artifact_bucket=args.artifact_bucket, - component_name=args.component_name, - project_id=args.project_id, - source_dataset_bucket=args.source_dataset_bucket, - source_dataset_blob=args.source_dataset_blob, - namespace=args.namespace, - data_manifest_path=args.data_manifest_path) diff --git a/examples/pipelines/finetune_stable_diffusion/components/embedding_component/Dockerfile b/examples/pipelines/finetune_stable_diffusion/components/embedding_component/Dockerfile new file mode 100644 index 000000000..2d626921d --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/components/embedding_component/Dockerfile @@ -0,0 +1,29 @@ +FROM --platform=linux/amd64 pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime + +## System dependencies +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install git curl -y + +# Downloading gcloud package +RUN curl https://dl.google.com/dl/cloudsdk/release/google-cloud-sdk.tar.gz > /tmp/google-cloud-sdk.tar.gz + +# Installing the package +RUN mkdir -p /usr/local/gcloud \ +&& tar -C /usr/local/gcloud -xvf /tmp/google-cloud-sdk.tar.gz \ +&& /usr/local/gcloud/google-cloud-sdk/install.sh + +# Adding the package path to local +ENV PATH $PATH:/usr/local/gcloud/google-cloud-sdk/bin + +# install requirements +COPY requirements.txt /tmp/requirements.txt +RUN python3 -m pip install -r /tmp/requirements.txt + +# Copy over src-files of the component +COPY src /src + +# Set the working directory to the source folder +WORKDIR /src + +ENTRYPOINT ["python", "main.py"] \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/embedding_component/README.MD b/examples/pipelines/finetune_stable_diffusion/components/embedding_component/README.MD new file mode 100644 index 000000000..2c32f83ed --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/components/embedding_component/README.MD @@ -0,0 +1,20 @@ +# embedding_component + +### Description + +This component extracts [embeddings](https://rom1504.medium.com/image-embeddings-ed1b194d113e) +from the converted images using a [CLIP model](https://huggingface.co/docs/transformers/model_doc/clip). + +Since image embeddings are good at capturing the features of the image in a compact and useful way, they +will be used in the next steps to retrieve images similar to our seed images. + +The CLIP model is downloaded once inside the component during build time (only the vision encoder is downloaded). The images are then embedded and the embeddings are then added to the manifest as a separate πŸ€— dataset. The data manifest is then updated with a reference to the location of the embeddings. + +### **Inputs/Outputs** + +See [`component.yaml`](component.yaml) for a more detailed description on all the input/output parameters. + +### **Practical considerations** + +* There exists many variants for the CLIP model, the current variant that is used is the [`Vit-L/14`](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
The reason for that is that the embeddings produced from this variant are the same ones that are used for building the indices for the [LAION-5B dataset](https://laion.ai/blog/laion-5b/#:~:text=Pre%2DComputed%20Embeddings). + diff --git a/examples/pipelines/finetune_stable_diffusion/components/embedding_component/component.yaml b/examples/pipelines/finetune_stable_diffusion/components/embedding_component/component.yaml new file mode 100644 index 000000000..164476847 --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/components/embedding_component/component.yaml @@ -0,0 +1,29 @@ +name: embedding_component +description: A component that embeds images using a CLIP model from the πŸ€— hub. +inputs: + - name: extra_args + description: Additional arguments passed to the component, as a json dict string + type: String + + - name: metadata + description: Metadata arguments, passed as a json dict string + type: String + + - name: input_manifest + description: Path to the input manifest + type: String + +outputs: + - name: output_manifest + description: Path to the output manifest + +implementation: + container: + image: ghcr.io/ml6team/embedding_component:latest + command: [ + python3, main.py, + --input-manifest, {inputPath: input_manifest}, + --metadata, {inputValue: metadata}, + --extra-args, {inputValue: extra_args}, + --output-manifest, {outputPath: output_manifest}, + ] \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/embedding_component/requirements.txt b/examples/pipelines/finetune_stable_diffusion/components/embedding_component/requirements.txt new file mode 100644 index 000000000..3e9bce5e0 --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/components/embedding_component/requirements.txt @@ -0,0 +1,4 @@ +torch==2.0.0 +transformers==4.27.3 +git+https://github.com/ml6team/express.git@0f43bbd88826ffd1b620100275700f9cbe7c0b6e +Pillow==9.4.0 \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/embedding_component/src/main.py b/examples/pipelines/finetune_stable_diffusion/components/embedding_component/src/main.py new file mode 100644 index 000000000..e91d24e20 --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/components/embedding_component/src/main.py @@ -0,0 +1,104 @@ +""" +This component adds a data source to the manifest by embedding the images. 
+""" +import logging +from typing import Optional, Union, Dict + + +from express.components.hf_datasets_components import ( + HFDatasetsTransformComponent, + HFDatasetsDataset, + HFDatasetsDatasetDraft, +) +from express.logger import configure_logging + +import torch + +from transformers import CLIPProcessor, CLIPVisionModelWithProjection + +configure_logging() +logger = logging.getLogger(__name__) + + +cuda_available = torch.cuda.is_available() +device = "cuda" if cuda_available else "cpu" +logger.info("CUDA device availability:%s", cuda_available) + +if cuda_available: + logger.info(torch.cuda.get_device_name(0)) + logger.info("CUDA device: %s", torch.cuda.get_device_name(0)) + logger.info("Num of GPUs: %s", torch.cuda.device_count()) + + +@torch.no_grad() +def embed(examples, processor, model): + images = examples["image"] + + # prepare images for the model + inputs = processor(images=images, return_tensors="pt").to(device) + + # embed to get (batch_size, hidden_size) embeddings + outputs = model(**inputs) + image_embeds = outputs.image_embeds + + # flatten into list of embeddings + examples["embeddings"] = image_embeds.cpu().tolist() + + return examples + + +class EmbeddingComponent(HFDatasetsTransformComponent): + """ + Component that embeds the images using a CLIP model from Hugging Face. + """ + + @classmethod + def transform( + cls, + data: HFDatasetsDataset, + extra_args: Optional[Dict[str, Union[str, int, float, bool]]] = None, + ) -> HFDatasetsDatasetDraft: + """ + An example function showcasing the data transform component using Express functionalities + + Args: + data (HFDatasetsDataset[TIndex, TData]): express dataset providing access to data of a + given type + extra_args (Optional[Dict[str, Union[str, int, float, bool]]): optional args to pass to + the function + Returns: + HFDatasetsDatasetDraft: a dataset draft that creates a plan for an output manifest + """ + + # 1) Get one particular data source from the manifest + # TODO check whether we can leverage streaming + logger.info("Loading image dataset...") + image_dataset = data.load(data_source="images") + + # 2) Create embedding dataset + logger.info("Loading CLIP...") + processor = CLIPProcessor.from_pretrained(extra_args["model_id"]) + model = CLIPVisionModelWithProjection.from_pretrained(extra_args["model_id"]) + model.to(device) + + logger.info("Embedding images...") + embedded_dataset = image_dataset.map( + embed, + batched=True, + batch_size=extra_args["batch_size"], + fn_kwargs=dict(processor=processor, model=model), + remove_columns=["image", "width", "height", "byte_size"], + ) + + # 3) Create dataset draft which adds a data source to the manifest + logger.info("Creating draft...") + data_sources = {"embeddings": embedded_dataset} + dataset_draft = HFDatasetsDatasetDraft( + data_sources=data_sources, extending_dataset=data + ) + + return dataset_draft + + +if __name__ == "__main__": + EmbeddingComponent.run() diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/Dockerfile b/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/Dockerfile deleted file mode 100644 index dc88da1b2..000000000 --- a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/Dockerfile +++ /dev/null @@ -1,27 +0,0 @@ -FROM europe-west1-docker.pkg.dev/storied-landing-366912/storied-landing-366912-default-repository/mlpipelines/kubeflow/components/base_component:latest - -ARG DEBIAN_FRONTEND=noninteractive -ENV 
COMMIT_SHA_CLIP=d50d76daa670286dd6cacf3bcd80b5e4823fc8e1 - -## System dependencies -RUN apt-get update && \ - apt-get upgrade -y && \ - apt-get install git wget -y - -# Set the working directory to the source folder -WORKDIR /src - -## Download model -# (Download Link:https://github.com/openai/CLIP/issues/199) -RUN wget https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt - -# Install CLIP -RUN pip3 install git+https://github.com/openai/CLIP.git@${COMMIT_SHA_CLIP} - -# Install packages -COPY requirements.txt . -RUN pip3 install -r requirements.txt - -COPY src /src - -ENTRYPOINT ["python", "main.py"] \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/README.MD b/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/README.MD deleted file mode 100644 index 2dec621d7..000000000 --- a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/README.MD +++ /dev/null @@ -1,23 +0,0 @@ -# image_embedding_component - -### Description - -This component extracts [image embeddings](https://rom1504.medium.com/image-embeddings-ed1b194d113e) -from the converted images using a [CLIP model](https://www.google.com/search?q=clip+embeddings&oq=clip+embeddings&aqs=chrome..69i57j0i22i30j69i60j69i64l2j69i60j69i64j69i60.6764j0j7&sourceid=chrome&ie=UTF-8). - -Since image embeddings are good at capturing the features of the image in a compact and useful way, they -will be used in the next steps to retrieve images similar to our seed images. - -The CLIP model is downloaded once inside the component during build time. The images are then embedded -and the embeddings are then saved as numpy files `.npy` files and uploaded to GCS. The data manifest is then updated with the reference to the embeddings path. - -### **Inputs/Outputs** - -See [`component.yaml`](component.yaml) for a more detailed description on all the input/output parameters. - -### **Practical considerations** - -* There exists many variants for the CLIP model, the current variant that is used is the `Vit-L/14`. The reason -for that is that the embeddings produced from this variant are the same ones that are used for building the indices for the -[LAION-5B dataset](https://laion.ai/blog/laion-5b/#:~:text=Pre%2DComputed%20Embeddings). - diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/__init__.py b/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/component.yaml b/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/component.yaml deleted file mode 100644 index ccc4bb6d8..000000000 --- a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/component.yaml +++ /dev/null @@ -1,40 +0,0 @@ -name: image_embedding_component -description: A component that takes a data manifest as input and embeds images with a CLIP model. - The output manifest will contain references to the blob path where the embeddings are stored. 
-inputs: - - name: run_id - description: The run id of the pipeline - type: String - - name: artifact_bucket - description: The GCS bucket used to store the artifact - type: String - - name: component_name - description: the name of the component (used to create gcs artefact path) - type: String - - name: project_id - description: The id of the gcp-project - type: String - - name: batch_size - description: The number of images to batch before embedding - type: Integer - - name: data_manifest_path - description: The previous component manifest path - type: Stringc - -outputs: - - name: data_manifest_path_embedding_component - description: Path to the local file containing the gcs path where the output has been stored - -implementation: - container: - image: europe-west1-docker.pkg.dev/storied-landing-366912/storied-landing-366912-default-repository/mlpipelines/kubeflow/components/image_embedding_component:latest - command: [ - python3, main.py, - --run-id, { inputValue: run_id }, - --artifact-bucket, { inputValue: artifact_bucket }, - --component-name, { inputValue: component_name }, - --project-id, { inputValue: project_id }, - --batch-size, { inputValue: batch_size }, - --data-manifest-path, { inputPath: data_manifest_path }, - --data-manifest-path-embedding-component, { outputPath: data_manifest_path_embedding_component }, - ] \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/requirements.txt b/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/requirements.txt deleted file mode 100644 index a4ebd6025..000000000 --- a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -fairscale==0.4.4 -ftfy==6.1.1 -Pillow==9.3.0 -regex==2022.10.31 -torchvision==0.13.0 -transformers==4.15.0 -tqdm==4.64.1 - - diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/src/__init__.py b/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/src/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/src/main.py b/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/src/main.py deleted file mode 100644 index 2e314938e..000000000 --- a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/src/main.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -This file is the entrypoint of the component. It will parse all arguments -and give them to the actual core of the component. 
-""" -import os -import tempfile -import argparse -import logging -import json -from pathlib import Path -from datetime import datetime - -import pyarrow.compute as pc -from google.cloud import storage - -# pylint: disable=import-error -from helpers.logger import get_logger -from helpers import storage_helpers, parquet_helpers, kfp_helpers -from helpers.manifest_helpers import DataManifest -from utils.image_embedding import KfpPipelineImageEmbedder - - -def parse_args(): - """Parse component arguments""" - - parser = argparse.ArgumentParser() - parser.add_argument('--run-id', - type=str, - required=True, - help='The run id of the pipeline') - parser.add_argument('--artifact-bucket', - type=str, - required=True, - help='The GCS bucket used to store the artifacts') - parser.add_argument('--component-name', - type=str, - required=True, - help='The name of the component') - parser.add_argument('--project-id', - type=str, - required=True, - help='The id of the gcp-project') - parser.add_argument('--batch-size', - type=int, - required=True, - help='The number of images to batch before embedding') - parser.add_argument('--data-manifest-path', - type=str, - required=True, - help='The previous component manifest path') - parser.add_argument('--data-manifest-path-embedding-component', - type=str, - required=True, - help='The path to the output manifest file') - - return parser.parse_args() - - -# pylint: disable=too-many-locals, too-many-arguments -def image_embedding_component(run_id: str, - artifact_bucket: str, - component_name: str, - project_id: str, - batch_size: int, - data_manifest_path: str, - data_manifest_path_embedding_component: str) -> None: - """ - A component that takes an images dataset as input and generated image embeddings out of them - Args: - run_id (str): the run id of the pipeline - artifact_bucket (str): The GCS bucket used to store the artifacts - component_name (str): the name of the component (used to create gcs artefact path) - project_id (str): The id of the gcp-project - batch_size (int): the number of images to batch before embedding - data_manifest_path (str): The previous component manifest path - data_manifest_path_embedding_component (str): the path to the output manifest file - """ - logger = get_logger(name=__name__, level=logging.INFO) - logger.info('Started job...') - - # Show CUDA availability - kfp_helpers.get_cuda_availability() - - # Initialize storage client - storage_client = storage.Client(project=project_id) - - with tempfile.TemporaryDirectory() as tmp_dir: - logger.info('created temporary directory %s', tmp_dir) - - # Initialize GCS custom artifact path - component_artifact_dir = run_id.rpartition('-')[0] - artifact_bucket_blob_path = f"custom_artifact/{component_artifact_dir}/{component_name}" - embedding_blob_path = f"gs://{artifact_bucket}/{artifact_bucket_blob_path}/embeddings" - logger.info("custom artifact will be uploaded to %s", - f'gs://{artifact_bucket}/{artifact_bucket_blob_path}') - - tmp_img_dir_path = os.path.join(tmp_dir, 'img_dir') - tmp_embedding_dir_path = os.path.join(tmp_dir, 'embeddings_dir') - os.makedirs(tmp_img_dir_path, exist_ok=True) - os.makedirs(tmp_embedding_dir_path, exist_ok=True) - dataset_id_parquet_tmp_path = os.path.join(tmp_dir, 'dataset.parquet') - Path(dataset_id_parquet_tmp_path).parent.mkdir(parents=True, exist_ok=True) - - # Read manifest - with open(data_manifest_path) as f: - manifest_load = json.load(f) - data_manifest = DataManifest.from_dict(manifest_load) - - # Get index and dataset parquet gcs paths - 
index_parquet_prev_gcs_path = data_manifest.index - # TODO: replace harcoded namespace - dataset_parquet_prev_gcs_path = data_manifest.associated_data.dataset['cf'] - - # Download parquet files locally - index_parquet_prev_tmp_path = storage_helpers.download_file_from_bucket( - storage_client, index_parquet_prev_gcs_path, tmp_dir) - dataset_parquet_prev_tmp_path = storage_helpers.download_file_from_bucket( - storage_client, dataset_parquet_prev_gcs_path, tmp_dir) - - # Get index_ids - index_ids_images_to_embed = parquet_helpers.get_column_list_from_parquet( - parquet_scanner_or_path=index_parquet_prev_tmp_path, - column_name='index') - - # Construct parquet filters and filter based on the criteria - filters = (pc.field("file_id").isin(index_ids_images_to_embed)) - - filtered_dataset_scanner = parquet_helpers.filter_parquet_file( - file_path=dataset_parquet_prev_tmp_path, - filters=filters, - batch_size=batch_size) - - # Caption images and store them in a parquet file - kfp_image_embedder = KfpPipelineImageEmbedder( - parquet_dataset=filtered_dataset_scanner, - embedding_blob_path=embedding_blob_path, - tmp_img_path=tmp_img_dir_path, - tmp_embedding_path=tmp_embedding_dir_path, - download_list_file_path=os.path.join(tmp_dir, 'download_gcs.txt') - ) - - kfp_image_embedder.start() - - # Update manifest - data_manifest.dataset_id = f"{run_id}_{component_name}" - data_manifest.metadata.branch = "" # TODO: Fill from docker build env var - data_manifest.metadata.commit_hash = "" # TODO: Fill from docker build env var - # TODO: replace harcoded namespace with string or list input - data_manifest.associated_data.embedding['cf'] = embedding_blob_path - data_manifest.metadata.creation_date = datetime.now().strftime("%d-%m-%Y_%H-%M-%S") - data_manifest.metadata.run_id = run_id - - logger.info('Manifest file created and updated') - - # Write manifest to outputPath - Path(data_manifest_path_embedding_component).parent.mkdir(parents=True, exist_ok=True) - Path(data_manifest_path_embedding_component).write_text(data_manifest.to_json()) - - logger.info('Manifest file written to %s', data_manifest_path_embedding_component) - - logger.info('Job completed.') - - -if __name__ == '__main__': - args = parse_args() - image_embedding_component \ - (run_id=args.run_id, - artifact_bucket=args.artifact_bucket, - component_name=args.component_name, - project_id=args.project_id, - batch_size=args.batch_size, - data_manifest_path=args.data_manifest_path, - data_manifest_path_embedding_component=args.data_manifest_path_embedding_component) diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/src/utils/__init__.py b/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/src/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/src/utils/image_embedding.py b/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/src/utils/image_embedding.py deleted file mode 100644 index 5f6860fd8..000000000 --- a/examples/pipelines/finetune_stable_diffusion/components/image_embedding_component/src/utils/image_embedding.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Image embedding""" -import os -import logging -from typing import List - -import torch -import numpy as np -from tqdm import tqdm -from PIL import Image -from pyarrow.dataset import Scanner - -# pylint: disable=import-error -import clip -from helpers import storage_helpers, 
io_helpers -from helpers.logger import get_logger - -LOGGER = get_logger(name=__name__, level=logging.INFO) - -# Set to allow pillow to process images of different sizes -Image.MAX_IMAGE_PIXELS = None - - -# pylint: disable=too-many-instance-attributes, too-few-public-methods -class KfpPipelineImageEmbedder: - """kfp image embedder """ - - # pylint: disable=too-many-arguments - def __init__(self, parquet_dataset: Scanner, embedding_blob_path: str, - tmp_img_path: str, tmp_embedding_path: str, download_list_file_path: str): - """ - Class that structures the kfp images conversion loop - Args: - embedding_blob_path (str): the blob path where the image embeddings will be stored - parquet_dataset (Scanner): the scanned parquet dataset - tmp_img_path (str) the temporary path used to store the downloaded images - tmp_embedding_path (str): the temporary path to save the image embeddings - download_list_file_path (str): path to file list containing one-per-line list of GCS - URIs to download - """ - self.device = "cuda" if torch.cuda.is_available() else "cpu" - LOGGER.info('CLIP model initialized with %s device', self.device) - self.clip_model_vit, self.clip_preprocess_vit = clip.load('ViT-L-14', - device=self.device) - self.parquet_dataset = parquet_dataset - self.embedding_blob_path = embedding_blob_path - self.tmp_img_path = tmp_img_path - self.tmp_embedding_path = tmp_embedding_path - self.download_list_file_path = download_list_file_path - self.device = "cuda" if torch.cuda.is_available() else "cpu" - - def _write_gcs_file_lists(self): - """ - Function that writes the gcs download and upload list files for bulk download/upload - """ - with open(self.download_list_file_path, "w") as download_list_file: - for batch in self.parquet_dataset.to_batches(): - for row in batch.to_pylist(): - file_uri = row['file_uri'] - download_list_file.write(file_uri + "\n") - - LOGGER.info("GCS download file list written to %s", self.download_list_file_path) - - def _download_images_to_embed(self): - """Function that download the images to caption locally""" - LOGGER.info("Downloading images to caption") - storage_helpers.copy_files_bulk(self.download_list_file_path, - self.tmp_img_path) - LOGGER.info("The images to be captioned were successfully downloaded to %s", - self.tmp_img_path) - - def _write_embeddings(self, embeddings: np.array, file_ids: List[str]): - """ - Function that saves the image embeddings as a numpy array - Args: - embeddings (np.array): the embeddings array - file_ids (List[str]): the list of file ids associated with the embeddings - """ - for embedding_array, file_id in zip(embeddings, file_ids): - np.save(os.path.join(self.tmp_embedding_path, f"{file_id}.npy"), embedding_array) - - def _embed_images(self): - """ - Function that embeds the images with CLIP - """ - LOGGER.info("Starting image embedding with CLIP") - - for batch in self.parquet_dataset.to_batches(): - # pyarrow's batch_size approximates the number of batches to return - # Batches may be smaller if there aren’t enough rows in the file and can return zero - # in some occasions - if len(batch) > 0: - img_list, file_ids = [], [] - for row in tqdm(batch.to_pylist()): - file_uri, file_id = row['file_uri'], row['file_id'] - file_name = io_helpers.get_file_name(file_uri, return_extension=True) - local_file_path = os.path.join(self.tmp_img_path, file_name) - img_list.append(self.clip_preprocess_vit(Image.open(local_file_path))) - file_ids.append(file_id) - # pylint: disable=no-member - img_stack = torch.tensor(np.stack(img_list), 
device=self.device) - embeddings = self.clip_model_vit.encode_image(img_stack).cpu().detach().numpy() - self._write_embeddings(embeddings, file_ids) - - LOGGER.info("Image embedding completed") - - def _upload_image_embeddings(self): - """Function that uploads the image embeddings from local disk to gcs""" - storage_helpers.copy_folder_bulk(self.tmp_embedding_path, self.embedding_blob_path) - - def start(self): - """ - Function that starts the image embedding loop - """ - self._write_gcs_file_lists() - self._download_images_to_embed() - self._embed_images() - self._upload_image_embeddings() diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/Dockerfile b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/Dockerfile index 611ccfb16..49b7639f9 100644 --- a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/Dockerfile +++ b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/Dockerfile @@ -1,14 +1,29 @@ -FROM europe-west1-docker.pkg.dev/storied-landing-366912/storied-landing-366912-default-repository/mlpipelines/kubeflow/components/base_component:latest +FROM --platform=linux/amd64 python:3.8-slim -# Set the working directory to the source folder -WORKDIR /src +## System dependencies +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install git curl -y -# Copy over src-files of the component -COPY requirements.txt . +# Downloading gcloud package +RUN curl https://dl.google.com/dl/cloudsdk/release/google-cloud-sdk.tar.gz > /tmp/google-cloud-sdk.tar.gz + +# Installing the package +RUN mkdir -p /usr/local/gcloud \ +&& tar -C /usr/local/gcloud -xvf /tmp/google-cloud-sdk.tar.gz \ +&& /usr/local/gcloud/google-cloud-sdk/install.sh + +# Adding the package path to local +ENV PATH $PATH:/usr/local/gcloud/google-cloud-sdk/bin -# Install packages -RUN pip3 install -r requirements.txt +# install requirements +COPY requirements.txt / +RUN pip3 install --no-cache-dir -r requirements.txt +# Copy over src-files of the component COPY src /src +# Set the working directory to the source folder +WORKDIR /src + ENTRYPOINT ["python", "main.py"] \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/README.MD b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/README.MD index afe998d6b..7873090c2 100644 --- a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/README.MD +++ b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/README.MD @@ -1,20 +1,9 @@ -# dataset_filter_component +# image_filter_component ### Description -This component is used to filter the images based on the metadata -attributes to only keep images of a certain size and format. +This component is used to filter the images based on the metadata attributes to only keep images of a certain height and width. ### **Inputs/Outputs** -The component accepts as input different filtering parameters of image size and formats to keep. The filtering -is done against the `dataset.parquet` which contains the relevant metadata used for filtering. -The **index.parquet** of the dataset source is then updated to include only the images that passed the filter. -See [`component.yaml`](component.yaml) for a more detailed description on all the input/output parameters. 
- -### **Practical considerations** - -* The reason for only allowing `png` and `svg` (as well as `jpg`) is that the [`image_conversion_component`](../image_conversion_component) -currently only supports converting those formats to `jpg`. If you have a different data format that you want to include, -make sure to add a method to convert to `jpg` in the `image_conversion_component` and add that format to the `image_formats` so that -images of those formats won't get filtered. \ No newline at end of file +See [`component.yaml`](component.yaml) for a more detailed description on all the input/output parameters. \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/__init__.py b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/component.yaml b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/component.yaml index 9d56b051c..c7f691180 100644 --- a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/component.yaml +++ b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/component.yaml @@ -1,49 +1,29 @@ name: image_filter_component -description: A component that takes a data manifest as input and filters images based on a set of - criteria. The output manifest will contain the same dataset parquet reference and a modified - index id reference. -inputs: - - name: run_id - description: The run id of the pipeline - type: String - - name: artifact_bucket - description: The GCS bucket used to store the artifact - type: String - - name: component_name - description: the name of the component (used to create gcs artefact path) - type: String - - name: project_id - description: The id of the gcp-project - type: String - - name: min_file_size - description: The minimum size of an image (filter) - type: Integer - - name: max_file_size - description: The maximum size of an image (filter) - type: Integer - - name: image_formats - description: The image formats to keep (filter) - type: JsonArray - - name: data_manifest_path - description: The previous component manifest path - type: String +description: A component that filters images based on certain conditions. 
+inputs: + - name: extra_args + description: Additional arguments passed to the component, as a json dict string + type: String + + - name: metadata + description: Metadata arguments, passed as a json dict string + type: String + - name: input_manifest + description: Path to the input manifest + type: String + outputs: - - name: data_manifest_path_filter_component - description: Path to the local file containing the gcs path where the output has been stored + - name: output_manifest + description: Path to the output manifest implementation: - container: - image: europe-west1-docker.pkg.dev/storied-landing-366912/storied-landing-366912-default-repository/mlpipelines/kubeflow/components/image_filter_component:latest - command: [ - python3, main.py, - --run-id, { inputValue: run_id }, - --artifact-bucket, { inputValue: artifact_bucket }, - --component-name, { inputValue: component_name }, - --project-id, { inputValue: project_id }, - --min-file-size, { inputValue: min_file_size }, - --max-file-size, { inputValue: max_file_size }, - --image-formats, { inputValue: image_formats }, - --data-manifest-path, { inputPath: data_manifest_path }, - --data-manifest-path-filter-component, { outputPath: data_manifest_path_filter_component }, - ] \ No newline at end of file + container: + image: ghcr.io/ml6team/image_filter_component:latest + command: [ + python3, main.py, + --input-manifest, {inputPath: input_manifest}, + --metadata, {inputValue: metadata}, + --extra-args, {inputValue: extra_args}, + --output-manifest, {outputPath: output_manifest}, + ] \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/requirements.txt b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/requirements.txt index 8b1378917..c3f48e49c 100644 --- a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/requirements.txt +++ b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/requirements.txt @@ -1 +1,2 @@ - +datasets==2.10.1 +git+https://github.com/ml6team/express.git@2d630255f82c6a9d7d15e7be23ff3d9b9ba16723 \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/src/main.py b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/src/main.py index 285978b4a..34c6e973f 100644 --- a/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/src/main.py +++ b/examples/pipelines/finetune_stable_diffusion/components/image_filter_component/src/main.py @@ -1,199 +1,78 @@ """ -This file is the entrypoint of the component. It will parse all arguments -and give them to the actual core of the component. +This component filters images of the dataset based on image size (minimum height and width). + +Technically, it updates the index of the manifest. 
""" -import os -import argparse import logging -import json -import tempfile -from pathlib import Path -from datetime import datetime - -import pyarrow.compute as pc -from google.cloud import storage - -# pylint: disable=import-error -from helpers.logger import get_logger -from helpers import storage_helpers, parquet_helpers, kfp_helpers -from helpers.manifest_helpers import DataManifest - - -def parse_args(): - """Parse component arguments""" - - parser = argparse.ArgumentParser() - parser.add_argument('--run-id', - type=str, - required=True, - help='The run id of the pipeline') - parser.add_argument('--artifact-bucket', - type=str, - required=True, - help='The GCS bucket used to store the artifacts') - parser.add_argument('--component-name', - type=str, - required=True, - help='The name of the component') - parser.add_argument('--project-id', - type=str, - required=True, - help='The id of the gcp-project') - parser.add_argument('--min-file-size', - type=int, - required=True, - help='The minimum size of an image (filter)') - parser.add_argument('--max-file-size', - type=int, - required=True, - help='The maximum size of an image (filter)') - parser.add_argument('--image-formats', - type=list, - required=True, - help='The image formats to keep (filter)') - parser.add_argument('--data-manifest-path', - type=str, - required=True, - help='The previous component manifest path') - parser.add_argument('--data-manifest-path-filter-component', - type=str, - required=True, - help='The path to the output manifest file') - - return parser.parse_args() - - -# pylint: disable=too-many-locals, too-many-arguments -def image_filter_component(run_id: str, - artifact_bucket: str, - component_name: str, - project_id: str, - min_file_size: int, - max_file_size: int, - image_formats: list, - data_manifest_path: str, - data_manifest_path_filter_component: str) -> None: - """ - A component that takes a data manifest as an input and filters it according to metadata related - information. 
- Args: - run_id (str): The run id of the pipeline - artifact_bucket (str): The GCS bucket used to store the artifacts - component_name (str): The name of the component (used to create gcs artefact path) - project_id (str): The id of the gcp-project - min_file_size (int): The minimum size of an image (filter) - max_file_size (int): The maximum size of an image (filter) - image_formats (list): The image formats to keep (filter) - data_manifest_path (str): The previous component manifest path - data_manifest_path_filter_component (str): the path to the output manifest file - """ +from typing import Optional, Union, Dict + +from datasets import Dataset + +from express.components.hf_datasets_components import ( + HFDatasetsTransformComponent, + HFDatasetsDataset, + HFDatasetsDatasetDraft, +) +from express.logger import configure_logging + +configure_logging() +logger = logging.getLogger(__name__) - logger = get_logger(name=__name__, level=logging.INFO) - logger.info('Started job...') - - with tempfile.TemporaryDirectory() as tmp_dir: - logger.info('created temporary directory %s', tmp_dir) - # Parse list variables - image_formats = kfp_helpers.parse_kfp_list(image_formats) - - # Initialize storage client - storage_client = storage.Client(project=project_id) - - # Initialize GCS custom artifact path - component_artifact_dir = run_id.rpartition('-')[0] - artifact_bucket_blob_path = f"custom_artifact/{component_artifact_dir}/{component_name}" - index_parquet_blob_path = os.path.join(artifact_bucket_blob_path, 'index.parquet') - logger.info("custom artifact will be uploaded to %s", - f'gs://{artifact_bucket}/{artifact_bucket_blob_path}') - - index_parquet_tmp_path = os.path.join(tmp_dir, 'index.parquet') - Path(index_parquet_tmp_path).parent.mkdir(parents=True, exist_ok=True) - - # Read manifest from previous component - with open(data_manifest_path) as f: - manifest_load = json.load(f) - data_manifest = DataManifest.from_dict(manifest_load) - - # Get index and dataset parquet gcs paths - index_parquet_prev_gcs_path = data_manifest.index - # TODO: replace harcoded namespace with string or list input - dataset_parquet_prev_gcs_path = data_manifest.associated_data.dataset['cf'] - - # Download parquet fies locally - index_parquet_prev_tmp_path = storage_helpers.download_file_from_bucket( - storage_client, index_parquet_prev_gcs_path, tmp_dir) - dataset_parquet_prev_tmp_path = storage_helpers.download_file_from_bucket( - storage_client, dataset_parquet_prev_gcs_path, tmp_dir) - - # Get indices - index_before_filtering = parquet_helpers.get_column_list_from_parquet( - parquet_scanner_or_path=index_parquet_prev_tmp_path, - column_name='index') - # Construct parquet filters and filter based on the criteria - filters = (pc.field("file_id").isin(index_before_filtering)) & \ - (pc.field("file_size") > pc.scalar(min_file_size)) & \ - (pc.field("file_size") < pc.scalar(max_file_size)) & \ - (pc.field("file_extension").isin(image_formats)) - - filtered_dataset_scanner = parquet_helpers.filter_parquet_file( - file_path=dataset_parquet_prev_tmp_path, - filters=filters) - - # Write new index ids parquet file and upload it to gcs - index_after_filtering = parquet_helpers.get_column_list_from_parquet( - parquet_scanner_or_path=filtered_dataset_scanner, - column_name='file_id') - - parquet_helpers.write_index_parquet( - index_parquet_path=index_parquet_tmp_path, - data_iterable_producer=lambda id_iterable: (id_element for id_element in id_iterable), - id_iterable=index_after_filtering) - - 
storage_helpers.upload_file_to_bucket(storage_client=storage_client, - file_to_upload_path=index_parquet_tmp_path, - bucket_name=artifact_bucket, - blob_path=index_parquet_blob_path) - - # Estimate the total number of filtered images - nb_images_before_filtering = len(index_before_filtering) - nb_images_after_filtering = parquet_helpers.get_nb_rows_from_parquet(index_parquet_tmp_path) - nb_filtered_image = nb_images_before_filtering - nb_images_after_filtering - percentage_filtered_images = round( - 100 * (nb_filtered_image / nb_images_before_filtering), 2) - - logger.info( - "The original number of images was %s. A total of %s images were filtered (%s%%)", - nb_images_before_filtering, nb_filtered_image, percentage_filtered_images) - - # Update manifest - data_manifest.dataset_id = f"{run_id}_{component_name}" - data_manifest.index = f"gs://{artifact_bucket}/{index_parquet_blob_path}" - data_manifest.metadata.branch = "" # TODO: Fill from docker build env var - data_manifest.metadata.commit_hash = "" # TODO: Fill from docker build env var - data_manifest.metadata.creation_date = datetime.now().strftime("%d-%m-%Y_%H-%M-%S") - - logger.info('Manifest file updated') - - # Write manifest to outputPath - Path(data_manifest_path_filter_component).parent.mkdir(parents=True, exist_ok=True) - Path(data_manifest_path_filter_component).write_text(data_manifest.to_json()) - - logger.info('Manifest file written to %s', data_manifest_path_filter_component) - - # Clean up temporary storage - logger.info('Files removed from temporary storage.') - logger.info('Job completed.') - - -if __name__ == '__main__': - args = parse_args() - image_filter_component( - run_id=args.run_id, - artifact_bucket=args.artifact_bucket, - component_name=args.component_name, - project_id=args.project_id, - max_file_size=args.max_file_size, - min_file_size=args.min_file_size, - image_formats=args.image_formats, - data_manifest_path=args.data_manifest_path, - data_manifest_path_filter_component=args.data_manifest_path_filter_component) + +def check_min_size(example, min_width, min_height): + width, height = example["width"], example["height"] + + return width > min_width and height > min_height + + +class ImageFilterComponent(HFDatasetsTransformComponent): + """ + Class that inherits from Hugging Face data transform. 
+ + Goal is to leverage streaming.""" + + @classmethod + def transform( + cls, + data: HFDatasetsDataset, + extra_args: Optional[Dict[str, Union[str, int, float, bool]]] = None, + ) -> HFDatasetsDatasetDraft: + """ + An example function showcasing the data transform component using Express functionalities + + Args: + data (HFDatasetsDataset[TIndex, TData]): express dataset providing access to data of a + given type + extra_args (Optional[Dict[str, Union[str, int, float, bool]]): optional args to pass to + the function + Returns: + HFDatasetsDatasetDraft: a dataset draft that creates a plan for an output manifest + """ + + # 1) Load one particular data source from the manifest + logger.info("Loading image dataset...") + metadata_dataset = data.load( + data_source="images", columns=["index", "width", "height"] + ) + + # 2) Update index by filtering + logger.info("Filtering dataset...") + min_width, min_height = extra_args["min_width"], extra_args["min_height"] + filtered_dataset = metadata_dataset.filter( + lambda example: example["width"] > min_width + and example["height"] > min_height + ) + index_dataset = Dataset.from_dict({"index": filtered_dataset["index"]}) + + # 3) Create dataset draft which updates the index + # but maintains the same data sources + logger.info("Creating draft...") + dataset_draft = HFDatasetsDatasetDraft( + index=index_dataset, data_sources=data.manifest.data_sources + ) + + return dataset_draft + + +if __name__ == "__main__": + ImageFilterComponent.run() diff --git a/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/Dockerfile b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/Dockerfile new file mode 100644 index 000000000..49b7639f9 --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/Dockerfile @@ -0,0 +1,29 @@ +FROM --platform=linux/amd64 python:3.8-slim + +## System dependencies +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install git curl -y + +# Downloading gcloud package +RUN curl https://dl.google.com/dl/cloudsdk/release/google-cloud-sdk.tar.gz > /tmp/google-cloud-sdk.tar.gz + +# Installing the package +RUN mkdir -p /usr/local/gcloud \ +&& tar -C /usr/local/gcloud -xvf /tmp/google-cloud-sdk.tar.gz \ +&& /usr/local/gcloud/google-cloud-sdk/install.sh + +# Adding the package path to local +ENV PATH $PATH:/usr/local/gcloud/google-cloud-sdk/bin + +# install requirements +COPY requirements.txt / +RUN pip3 install --no-cache-dir -r requirements.txt + +# Copy over src-files of the component +COPY src /src + +# Set the working directory to the source folder +WORKDIR /src + +ENTRYPOINT ["python", "main.py"] \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/README.MD b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/README.MD new file mode 100644 index 000000000..96afd1bf1 --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/README.MD @@ -0,0 +1,13 @@ +# dataset_loader_component + +### Description + +This is the first component of the pipeline, it loads an image dataset from the πŸ€— [hub](https://huggingface.co/) and creates the initial manifest. This manifest includes 3 data sources: images, captions and image metadata. + +### **Inputs/Outputs** +The component accepts a `dataset_name` as input which refers to a dataset on the πŸ€— hub. 
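
For illustration, the pipeline wraps this input in a JSON-encoded `extra_args` string before handing it to the component. A minimal sketch (keys and values borrowed from `LoadFromHubConfig` and `dataset_creation_pipeline.py` elsewhere in this diff, shown only as an example, not as a contract):

```python
import json

# Hypothetical extra_args payload for this component; the keys mirror the ones
# built in dataset_creation_pipeline.py (dataset_name, batch_size).
extra_args = json.dumps(
    {
        "dataset_name": "lambdalabs/pokemon-blip-captions",
        "batch_size": 1000,
    }
)
print(extra_args)  # '{"dataset_name": "lambdalabs/pokemon-blip-captions", "batch_size": 1000}'
```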
+ +The component creates the first manifest as output that will be updated by subsequent components when new data sources +are added/filtered. + +See [`component.yaml`](component.yaml) for a more detailed description on all the input/output parameters. \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/__init__.py b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/__init__.py similarity index 100% rename from examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/__init__.py rename to examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/__init__.py diff --git a/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/component.yaml b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/component.yaml new file mode 100644 index 000000000..6eae2dd79 --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/component.yaml @@ -0,0 +1,24 @@ +name: load_from_hub_component +description: A basic component that takes a dataset name from the πŸ€— hub as input and uploads it to a GCS bucket. +inputs: + - name: extra_args + description: Additional arguments passed to the component, as a json dict string + type: String + + - name: metadata_args + description: Metadata arguments, passed as a json dict string + type: String + +outputs: + - name: output_manifest + description: Path to the output manifest + +implementation: + container: + image: ghcr.io/ml6team/load_from_hub_component:latest + command: [ + python3, main.py, + --extra-args, {inputValue: extra_args}, + --metadata-args, {inputValue: metadata_args}, + --output-manifest, {outputPath: output_manifest}, + ] \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/requirements.txt b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/requirements.txt new file mode 100644 index 000000000..db1781885 --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/requirements.txt @@ -0,0 +1,3 @@ +datasets==2.10.1 +git+https://github.com/ml6team/express.git@main +Pillow==9.4.0 \ No newline at end of file diff --git a/examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/src/__init__.py b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/src/__init__.py similarity index 100% rename from examples/pipelines/finetune_stable_diffusion/components/dataset_loader_component/src/__init__.py rename to examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/src/__init__.py diff --git a/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/src/main.py b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/src/main.py new file mode 100644 index 000000000..08a782ee5 --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/components/load_from_hub_component/src/main.py @@ -0,0 +1,84 @@ +""" +This component loads a seed dataset from the hub and creates the initial manifest. 
+""" +import logging +import sys +from typing import Optional, Union, Dict + +from datasets import Dataset, load_dataset + +from express.components.hf_datasets_components import ( + HFDatasetsLoaderComponent, + HFDatasetsDatasetDraft, +) +from express.logger import configure_logging + +configure_logging() +logger = logging.getLogger(__name__) + + +def create_image_metadata(batch): + images = batch["image"] + + # add width, height and byte size columns + widths, heights = zip(*[image.size for image in images]) + batch["width"] = list(widths) + batch["height"] = list(heights) + batch["byte_size"] = [sys.getsizeof(image.tobytes()) for image in images] + + return batch + + +class LoadFromHubComponent(HFDatasetsLoaderComponent): + """Component that loads a dataset from the hub and creates the initial manifest.""" + + @classmethod + def load( + cls, extra_args: Optional[Dict[str, Union[str, int, float, bool]]] = None + ) -> HFDatasetsDatasetDraft: + """ + An example function showcasing the data loader component using Express functionalities + Args: + extra_args (Optional[Dict[str, Union[str, int, float, bool]]): optional args to pass to + the function (e.g. seed data source) + Returns: + HFDatasetsDatasetDraft: a dataset draft that creates a plan for an output manifest + """ + + # 1) Create data source + logger.info("Loading caption dataset from the hub...") + # TODO perhaps leverage streaming + dataset = load_dataset(extra_args["dataset_name"], split="train") + + # 2) Create an example index + logger.info("Creating index...") + index_list = [f"image_{idx}" for idx in range(len(dataset))] + + # 3) Create dataset draft (manifest without metadata) + # We store the index itself also as a HF Dataset + logger.info("Creating draft...") + index_dataset = Dataset.from_dict({"index": index_list}) + image_dataset = dataset.remove_columns(["text"]).add_column( + name="index", column=index_list + ) + text_dataset = dataset.remove_columns(["image"]).add_column( + name="index", column=index_list + ) + image_dataset = image_dataset.map( + create_image_metadata, + batched=True, + batch_size=extra_args["batch_size"], + ) + data_sources = { + "images": image_dataset, + "captions": text_dataset, + } + dataset_draft = HFDatasetsDatasetDraft( + index=index_dataset, data_sources=data_sources + ) + + return dataset_draft + + +if __name__ == "__main__": + LoadFromHubComponent.run() diff --git a/examples/pipelines/finetune_stable_diffusion/config/components_config.py b/examples/pipelines/finetune_stable_diffusion/config/components_config.py new file mode 100644 index 000000000..df19ccc14 --- /dev/null +++ b/examples/pipelines/finetune_stable_diffusion/config/components_config.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass + + +@dataclass +class LoadFromHubConfig: + """ + Configs for the dataset loader component + Params: + DATASET_NAME (str): Name of the dataset on the hub. + BATCH_SIZE (int): Batch size to use when creating image metadata. + """ + + DATASET_NAME = "lambdalabs/pokemon-blip-captions" + BATCH_SIZE = 1000 + + +@dataclass +class ImageFilterConfig: + """ + Configs for the image filter component + Params: + MIN_HEIGHT (int): Minimum height for each image. + MIN_WIDTH (int): Minimum width for each image. + """ + + MIN_HEIGHT = 200 + MIN_WIDTH = 200 + + +@dataclass +class EmbeddingConfig: + """ + Configs for the embedding component + Params: + MODEL_ID (int): HF model id to use for embedding. + BATCH_SIZE (int): Batch size to use when embedding. 
+    """
+
+    MODEL_ID = "openai/clip-vit-large-patch14"
+    BATCH_SIZE = 10
+
+
+@dataclass
+class ClipRetrievalConfig:
+    """
+    Configs for CLIP image retrieval component
+    Params:
+        LAION_INDEX_URL (str): url of the indices of the metadata. Those indices need to be
+         transformed in case you decide to use only a subset of the dataset
+        LAION_METADATA_URL (str): url to the metadata of the LAION dataset (arrow format). It
+         can either contain a subset of the LAION 5B metadata (e.g. laion-en) or all of it
+        NUM_IMAGES_KNN (int): The number of images to retrieve via the knn strategy
+         (per image)
+        NUM_IMAGES_CENTROID (int): The number of images to retrieve via the centroid strategy
+    """
+
+    LAION_INDEX_URL = "gs://express-sd-datasets/laion-5b/2b-en/image.index/*"
+    LAION_METADATA_URL = (
+        "gs://express-sd-datasets/laion-5b/metadata/metadata/2B-en.arrow"
+    )
+    NUM_IMAGES_KNN = 500
+    NUM_IMAGES_CENTROID = 1_000_000
diff --git a/examples/pipelines/finetune_stable_diffusion/config/general_config.py b/examples/pipelines/finetune_stable_diffusion/config/general_config.py
new file mode 100644
index 000000000..f54635e3e
--- /dev/null
+++ b/examples/pipelines/finetune_stable_diffusion/config/general_config.py
@@ -0,0 +1,32 @@
+"""General config"""
+
+import os
+
+from dataclasses import dataclass
+
+@dataclass
+class GeneralConfig:
+    """
+    General configs
+    Params:
+        GCP_PROJECT_ID (str): GCP project ID
+        ENV (str): the project run environment (sbx, dev, prd)
+    """
+    GCP_PROJECT_ID = "soy-audio-379412"
+    ENV = os.environ.get('ENV', 'dev')
+
+
+@dataclass
+class KubeflowConfig(GeneralConfig):
+    """
+    Configs for the Kubeflow cluster
+    Params:
+        ARTIFACT_BUCKET (str): the GCS bucket used to store the artifacts
+        CLUSTER_NAME (str): the name of the k8s cluster hosting KFP
+        CLUSTER_ZONE (str): the zone of the k8s cluster hosting KFP
+        HOST (str): the kfp host url
+    """
+    ARTIFACT_BUCKET = f"{GeneralConfig.GCP_PROJECT_ID}_kfp-artifacts"
+    CLUSTER_NAME = "kfp-express"
+    CLUSTER_ZONE = "europe-west4-a"
+    HOST = "https://472c61c751ab9be9-dot-europe-west1.pipelines.googleusercontent.com"
diff --git a/examples/pipelines/finetune_stable_diffusion/dataset_creation_pipeline.py b/examples/pipelines/finetune_stable_diffusion/dataset_creation_pipeline.py
index 4b6b290a8..815ee8c91 100644
--- a/examples/pipelines/finetune_stable_diffusion/dataset_creation_pipeline.py
+++ b/examples/pipelines/finetune_stable_diffusion/dataset_creation_pipeline.py
@@ -1,234 +1,163 @@
-"""Pipeline used to create a stable diffusion dataset from a set of given images.
This is done by -using clip retrieval on the LAION dataset""" +"""Pipeline used to create a stable diffusion dataset from a set of given images.""" # pylint: disable=import-error -import sys -import os -import logging - -sys.path.insert(0, os.path.abspath('..')) - -from config import GeneralConfig, KubeflowConfig -from pipelines_config.dataset_creation_config import DatasetLoaderConfig, ImageFilterConfig, \ - ImageConversionConfig, ImageEmbeddingConfig, ImageCaptionConfig, ClipRetrievalConfig, \ - ClipDownloaderConfig -from express.pipeline_utils import compile_and_upload_pipeline -from express.logger import configure_logging +import json from kfp import components as comp from kfp import dsl + from kubernetes import client as k8s_client -configure_logging() +from config.general_config import KubeflowConfig +from config.components_config import ( + LoadFromHubConfig, + ImageFilterConfig, + EmbeddingConfig, + ClipRetrievalConfig, +) +from express.pipeline_utils import compile_and_upload_pipeline -LOGGER = logging.getLogger(__name__) +# Load Components +run_id = "{{workflow.name}}" +artifact_bucket = KubeflowConfig.ARTIFACT_BUCKET -# Load Component -dataset_loader_component = comp.load_component( - 'components/dataset_loader_component/component.yaml') -image_filter_component = comp.load_component( - 'components/image_filter_component/component.yaml') -image_conversion_component = comp.load_component( - 'components/image_conversion_component/component.yaml') -image_embedding_component = comp.load_component( - 'components/image_embedding_component/component.yaml') -clip_retrieval_component = comp.load_component( - 'components/clip_retrieval_component/component.yaml') -clip_downloader_component = comp.load_component( - 'components/clip_downloader_component/component.yaml') -image_caption_component = comp.load_component( - 'components/image_caption_component/component.yaml') +# Component 1 +load_from_hub_op = comp.load_component( + "components/load_from_hub_component/component.yaml" +) +load_from_hub_extra_args = { + "dataset_name": LoadFromHubConfig.DATASET_NAME, + "batch_size": LoadFromHubConfig.BATCH_SIZE, +} +load_from_hub_metadata_args = { + "run_id": run_id, + "component_name": load_from_hub_op.__name__, + "artifact_bucket": artifact_bucket, +} +load_from_hub_extra_args = json.dumps(load_from_hub_extra_args) +load_from_hub_metadata_args = json.dumps(load_from_hub_metadata_args) + +# Component 2 +image_filter_op = comp.load_component( + "components/image_filter_component/component.yaml" +) +image_filter_extra_args = { + "min_height": ImageFilterConfig.MIN_HEIGHT, + "min_width": ImageFilterConfig.MIN_WIDTH, +} +image_filter_metadata_args = { + "run_id": run_id, + "component_name": image_filter_op.__name__, + "artifact_bucket": artifact_bucket, +} +image_filter_extra_args = json.dumps(image_filter_extra_args) +image_filter_metadata_args = json.dumps(image_filter_metadata_args) + + +# Component 3 +embedding_op = comp.load_component("components/embedding_component/component.yaml") +embedding_extra_args = { + "model_id": EmbeddingConfig.MODEL_ID, + "batch_size": EmbeddingConfig.BATCH_SIZE, +} +embedding_metadata_args = { + "run_id": run_id, + "component_name": embedding_op.__name__, + "artifact_bucket": artifact_bucket, +} +embedding_extra_args = json.dumps(embedding_extra_args) +embedding_metadata_args = json.dumps(embedding_metadata_args) + + +# Component 4 +clip_retrieval_op = comp.load_component( + "components/clip_retrieval_component/component.yaml" +) +clip_retrieval_extra_args = { + 
"model_id": ClipRetrievalConfig.LAION_INDEX_URL, + "batch_size": ClipRetrievalConfig.LAION_METADATA_URL, + "num_images_knn": ClipRetrievalConfig.NUM_IMAGES_KNN, + "num_images_centroid": ClipRetrievalConfig.NUM_IMAGES_CENTROID, +} +clip_retrieval_metadata_args = { + "run_id": run_id, + "component_name": clip_retrieval_op.__name__, + "artifact_bucket": artifact_bucket, +} +clip_retrieval_extra_args = json.dumps(clip_retrieval_extra_args) +clip_retrieval_metadata_args = json.dumps(clip_retrieval_metadata_args) # Pipeline @dsl.pipeline( - name='image-generator-dataset', - description='Pipeline that takes example images as input and returns an expanded dataset of ' - 'similar images as outputs' + name="image-generator-dataset", + description="Pipeline that takes example images as input and returns an expanded dataset of " + "similar images as outputs", ) # pylint: disable=too-many-arguments, too-many-locals def sd_dataset_creator_pipeline( - source_dataset_bucket: str = DatasetLoaderConfig.SOURCE_DATASET_BUCKET, - project_id: str = GeneralConfig.GCP_PROJECT_ID, - source_dataset_blob: str = DatasetLoaderConfig.SOURCE_DATASET_BLOB, - namespace: str = DatasetLoaderConfig.NAMESPACE, - max_file_size: int = ImageFilterConfig.MAX_FILE_SIZE, - min_file_size: int = ImageFilterConfig.MIN_FILE_SIZE, - image_formats: list = ImageFilterConfig.IMAGE_FORMATS, - file_extensions: list = ImageConversionConfig.FILE_EXTENSIONS, - svg_image_width: int = ImageConversionConfig.SVG_IMAGE_WIDTH, - svg_image_height: int = ImageConversionConfig.SVG_IMAGE_HEIGHT, - clip_batch_size: int = ImageEmbeddingConfig.BATCH_SIZE, - laion_index_url: str = ClipRetrievalConfig.LAION_INDEX_URL, - laion_metadata_url: str = ClipRetrievalConfig.LAION_METADATA_URL, - nb_images_knn: int = ClipRetrievalConfig.NB_IMAGES_KNN, - nb_images_centroid: int = ClipRetrievalConfig.NB_IMAGES_CENTROID, - image_resize: int = ClipDownloaderConfig.IMAGE_RESIZE, - timeout: int = ClipDownloaderConfig.TIMEOUT, - min_image_size: int = ClipDownloaderConfig.MIN_IMAGE_SIZE, - max_image_area: int = ClipDownloaderConfig.MAX_IMAGE_AREA, - min_length: int = ImageCaptionConfig.MIN_LENGTH, - max_length: int = ImageCaptionConfig.MAX_LENGTH, - blip_batch_size: int = ImageCaptionConfig.BATCH_SIZE, - beams: int = ImageCaptionConfig.BEAMS + load_from_hub_extra_args: str = load_from_hub_extra_args, + load_from_hub_metadata_args: str = load_from_hub_metadata_args, + image_filter_extra_args: str = image_filter_extra_args, + image_filter_metadata_args: str = image_filter_metadata_args, + embedding_extra_args: str = embedding_extra_args, + embedding_metadata_args: str = embedding_metadata_args, + clip_retrieval_extra_args: str = clip_retrieval_extra_args, + clip_retrieval_metadata_args: str = clip_retrieval_metadata_args, ): - """ - Pipeline that takes example images as input and returns an expanded dataset of - similar images as outputs - Args: - # General - project_id (str): project ID string - # Dataset loader component - source_dataset_bucket (str): The GCS bucket containing the dataset to load - source_dataset_blob (str): The GCS blob withing the specified bucket containing the - dataset to load - namespace (str): The dataset namespace (abbreviation for data source) - # Dataset filter component - max_file_size (int): The maximum size of an image (filter) - min_file_size (int): The minimum size of an image (filter) - image_formats (list): The image formats to keep (filter) - # Dataset conversion component - file_extensions (list): The list of image file extensions to convert 
- svg_image_width (int): the desired width to scale the converted SVG image to - svg_image_height (int): the desired width to scale the converted SVG image to - # Dataset embedding component - clip_batch_size (int): the bath size used to batch the images before embedding - # Clip retrieval component - laion_index_url (str): contains the indices of the metadata. Those indices need to be - transformed in case you decide to use only a subset of the dataset - laion_metadata_url (str): url to the metadata of laion dataset metadata (arrow format). - It can either contain a subset of the laion 5b metadata (e.g. laion-en) or all of the - metadata - nb_images_knn (int): The number of images to return with knn method (per image) - nb_images_centroid (int): The number of images to return with centroid method - # Clip downloader component - image_resize (int): the size to resize the image - timeout (int): maximum time (in seconds to wait) when trying to download an image - min_image_size (int): minimum size of the image to download - (considers the min of width and height) - max_image_area (int): The maximum area (nr of pixels) of the images to download - # Dataset caption component - min_length (str): The minimum caption length to generate - max_length (str): the maximum caption length to generate - blip_batch_size (int): the batch size of the images to caption - beams (int): The blip beam parameters - """ - # pylint: disable=not-callable,unused-variable - run_id = '{{pod.name}}' - artifact_bucket = KubeflowConfig.ARTIFACT_BUCKET - # Define necessary volume mounts (local ssd) - local_ssd_volume = dsl.PipelineVolume(volume=k8s_client.V1Volume( - name="scratch-volume", - empty_dir=k8s_client.V1EmptyDirVolumeSource())) - - # Define components - dataset_loader_task = dataset_loader_component( - run_id=run_id, - artifact_bucket=artifact_bucket, - component_name=dataset_loader_component.__name__, - project_id=project_id, - source_dataset_bucket=source_dataset_bucket, - source_dataset_blob=source_dataset_blob, - namespace=namespace).set_display_name('Load Images') - - image_filter_task = image_filter_component( - run_id=run_id, - artifact_bucket=artifact_bucket, - component_name=image_filter_component.__name__, - project_id=project_id, - max_file_size=max_file_size, - min_file_size=min_file_size, - image_formats=image_formats, - data_manifest_path=dataset_loader_task.outputs['data_manifest_path']) \ - .set_display_name('Filter Images') - - image_conversion_task = image_conversion_component( - run_id=run_id, - artifact_bucket=artifact_bucket, - component_name=image_conversion_component.__name__, - project_id=project_id, - file_extensions=file_extensions, - svg_image_width=svg_image_width, - svg_image_height=svg_image_height, - data_manifest_path=image_filter_task.outputs['data_manifest_path_filter_component']) \ - .set_display_name('Convert Image Format') \ - .add_node_selector_constraint('node_pool', 'burst-zone') \ - .add_toleration( - k8s_client.V1Toleration(effect='NoSchedule', key='reserved-pool', operator='Equal', - value='true')) - - image_embedding_task = image_embedding_component( - run_id=run_id, - artifact_bucket=artifact_bucket, - component_name=image_embedding_component.__name__, - project_id=project_id, - batch_size=clip_batch_size, - data_manifest_path=image_conversion_task.outputs[ - 'data_manifest_path_image_conversion_component']) \ - .set_display_name('Embed Images') \ - .set_gpu_limit(1) \ - .add_node_selector_constraint('node_pool', 'model-inference-pool') \ + local_ssd_volume = 
dsl.PipelineVolume( + volume=k8s_client.V1Volume( + name="scratch-volume", empty_dir=k8s_client.V1EmptyDirVolumeSource() + ) + ) + + # Component 1 + load_from_hub_task = load_from_hub_op( + extra_args=load_from_hub_extra_args, + metadata_args=load_from_hub_metadata_args, + ).set_display_name("Load initial images") + + # Component 2 + image_filter_task = image_filter_op( + extra_args=image_filter_extra_args, + metadata=image_filter_metadata_args, + input_manifest=load_from_hub_task.outputs["output_manifest"], + ).set_display_name("Filter images") + + # Component 3 + embedding_task = ( + embedding_op( + extra_args=embedding_extra_args, + metadata=embedding_metadata_args, + input_manifest=image_filter_task.outputs["output_manifest"], + ) + .set_display_name("Embed images") + .set_gpu_limit(1) + .add_node_selector_constraint("node_pool", "model-inference-pool") .add_toleration( - k8s_client.V1Toleration(effect='NoSchedule', key='reserved-pool', operator='Equal', - value='true')) - - clip_retrieval_task = clip_retrieval_component( - run_id=run_id, - artifact_bucket=artifact_bucket, - component_name=clip_retrieval_component.__name__, - project_id=project_id, - laion_index_url=laion_index_url, - laion_metadata_url=laion_metadata_url, - nb_images_knn=nb_images_knn, - nb_images_centroid=nb_images_centroid, - data_manifest_path=image_embedding_task.outputs[ - 'data_manifest_path_embedding_component']) \ - .set_display_name('Clip retrieval') \ - .set_ephemeral_storage_request('2T') \ - .add_pvolumes({'/cache': local_ssd_volume}) \ - .add_node_selector_constraint('node_pool', 'nvme-pool') - - clip_downloader_task = clip_downloader_component( - run_id=run_id, - artifact_bucket=artifact_bucket, - component_name=clip_downloader_component.__name__, - project_id=project_id, - image_resize=image_resize, - timeout=timeout, - min_image_size=min_image_size, - max_image_area=max_image_area, - data_manifest_path=clip_retrieval_task.outputs[ - 'data_manifest_path_clip_retrieval_component'], - parquet_path_clip_knn_retrieval=clip_retrieval_task.outputs[ - 'parquet_path_clip_knn_retrieval'], - parquet_path_clip_centroid_retrieval=clip_retrieval_task.outputs[ - 'parquet_path_clip_centroid_retrieval']) \ - .set_display_name('Clip Image downloader') \ - .add_node_selector_constraint('node_pool', 'burst-zone') \ - .add_toleration(k8s_client.V1Toleration - (effect='NoSchedule', key='reserved-pool', operator='Equal', - value='true')) - - image_caption_task = image_caption_component( - run_id=run_id, - artifact_bucket=artifact_bucket, - component_name=image_caption_component.__name__, - project_id=project_id, - min_length=min_length, - max_length=max_length, - batch_size=blip_batch_size, - beams=beams, - data_manifest_path=clip_downloader_task.outputs[ - 'data_manifest_path_clip_downloader_component']) \ - .set_display_name('Caption Images') \ - .set_gpu_limit(1) \ - .add_node_selector_constraint('node_pool', 'model-inference-pool') \ - .add_toleration( - k8s_client.V1Toleration(effect='NoSchedule', key='reserved-pool', operator='Equal', - value='true')) - - -if __name__ == '__main__': - compile_and_upload_pipeline(pipeline=sd_dataset_creator_pipeline, - host=KubeflowConfig.HOST, - env=KubeflowConfig.ENV) + k8s_client.V1Toleration( + effect="NoSchedule", key="reserved-pool", operator="Equal", value="true" + ) + ) + ) + + # Component 4 + clip_retrieval_op( + extra_args=clip_retrieval_extra_args, + metadata=clip_retrieval_metadata_args, + input_manifest=embedding_task.outputs["output_manifest"], + 
).set_display_name("Retrieve images").set_ephemeral_storage_request( + "2T" + ).add_pvolumes( + {"/cache": local_ssd_volume} + ).add_node_selector_constraint( + "node_pool", "nvme-pool" + ) + + +if __name__ == "__main__": + compile_and_upload_pipeline( + pipeline=sd_dataset_creator_pipeline, + host=KubeflowConfig.HOST, + env=KubeflowConfig.ENV, + ) diff --git a/examples/pipelines/finetune_stable_diffusion/pipelines_config/dataset_creation_config.py b/examples/pipelines/finetune_stable_diffusion/pipelines_config/dataset_creation_config.py deleted file mode 100644 index 454df0619..000000000 --- a/examples/pipelines/finetune_stable_diffusion/pipelines_config/dataset_creation_config.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Dataset creation pipeline config""" - -from dataclasses import dataclass - - -@dataclass -class DatasetLoaderConfig: - """ - Configs for the dataset loader component - Params: - SOURCE_DATASET_BUCKET (str): The GCS bucket containing the dataset to load - SOURCE_DATASET_BLOB (str): the zone of the k8 cluster hosting KFP - NAMESPACE (str): The dataset namespace (abbreviation for data source) - """ - SOURCE_DATASET_BUCKET = "express-datasets" - SOURCE_DATASET_BLOB = "initial-clean-cut-dataset" - NAMESPACE = "cf" - - -@dataclass -class ImageFilterConfig: - """ - Configs for the dataset filter component - Params: - MIN_FILE_SIZE (str): The minimum size of an image in bytes (filter) - MAX_FILE_SIZE (str): The maximum size of an image in bytes (filter) - IMAGE_FORMATS (list): The image formats to keep, any other formats that are not included - will be filtered from the dataset - """ - MIN_FILE_SIZE = 100_000 # 100kb - MAX_FILE_SIZE = 5_000_000 # 5mb - # Image formats notation from 'type' in field in GCS - IMAGE_FORMATS = ['jpeg', 'jpg', 'png', 'svg'] - - -@dataclass -class ImageConversionConfig: - """ - Configs for dataset image converter component - Params: - FILE_EXTENSIONS (list): The list of image file extensions to convert from - SVG_IMAGE_WIDTH (int): the desired width to scale the converted SVG image to - SVG_IMAGE_HEIGHT (int): the desired width to scale the converted SVG image to - """ - FILE_EXTENSIONS = ['png', 'svg'] - SVG_IMAGE_WIDTH = 1024 - SVG_IMAGE_HEIGHT = 1024 - - -@dataclass -class ImageEmbeddingConfig: - """ - Configs for dataset image embedding component - Params: - BATCH_SIZE (int): the batch size used to batch the images before embedding - """ - BATCH_SIZE = 8 - - -@dataclass -class ClipRetrievalConfig: - """ - Configs for dataset image converter component - Params: - LAION_INDEX_URL(str): contains the indices of the metadata. Those indices need to be - transformed in case you decide to use only a subset of the dataset - LAION_METADATA_URL (str): url to the metadata of laion dataset metadata (arrow format). It - can either contain a subset of the laion 5b metadata (e.g. 
laion-en) or all of the metadata - NB_IMAGES_KNN (int): The ratio of number of image to retrieve via the knn strategy - (per image) - NB_IMAGES_CENTROID (int): The ratio of number of image to retrieve via the centroid strategy - """ - LAION_INDEX_URL = "gs://express-sd-datasets/laion-5b/2b-en/image.index/*" - LAION_METADATA_URL = "gs://express-sd-datasets/laion-5b/metadata/metadata/2B-en.arrow" - NB_IMAGES_KNN = 500 - NB_IMAGES_CENTROID = 1_000_000 - - -@dataclass -class ClipDownloaderConfig: - """ - Configs for dataset image converter component - Params: - IMAGE_RESIZE (int): the size to resize the image - TIMEOUT (int): maximum time (in seconds to wait) when trying to download an image - MIN_IMAGE_SIZE (int): minimum size of the image to download - (considers the min of width and height) - MAX_IMAGE_AREA (int): The maximum area (nr of pixels) of the images to download - """ - IMAGE_RESIZE = 512 - TIMEOUT = 5 - MIN_IMAGE_SIZE = 100 - MAX_IMAGE_AREA = 178956870 - - -@dataclass -class ImageCaptionConfig: - """ - Configs for dataset image converter component - Params: - MIN_LENGTH (str): The minimum caption length - MAX_LENGTH (str): the maximum caption length - BEAMS (int): The blip beam parameters - BATCH_SIZE (int): The batch size of images to pass to the blip model - """ - MIN_LENGTH = 10 - MAX_LENGTH = 20 - BATCH_SIZE = 100 - BEAMS = 1 diff --git a/examples/pipelines/finetune_stable_diffusion/pipelines_config/sd_finetuning_config.py b/examples/pipelines/finetune_stable_diffusion/pipelines_config/sd_finetuning_config.py deleted file mode 100644 index 2ef71469e..000000000 --- a/examples/pipelines/finetune_stable_diffusion/pipelines_config/sd_finetuning_config.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Stable diffusion pipeline config""" - -from dataclasses import dataclass - - -# pylint: disable=too-many-instance-attributes -@dataclass -class StableDiffusionFinetuningConfig: - """ - Configs for dataset image converter component - Params: - DATA_MANIFEST_GCS_PATH (str): the path to the data manifest that contains information on the - training set - pretrained_model_gcs_path (str): Model identifier from huggingface.co/models or gcs path to a model - SEED (int): A seed for reproducible training. - RESOLUTION (int): The resolution for input images, all the images in the train/validation - dataset will be resized to this resolution - TRAIN_BATCH_SIZE (int): Batch size (per device) for the training dataloader - NUM_TRAIN_EPOCHS (int): Total number of epochs to perform - MAX_TRAIN_STEPS (int): Total number of training steps to perform. If provided overrides - `num_train_epochs` - CHECKPOINTING_STEPS (int): Save a checkpoint of the training state every X updates. These checkpoints are only - suitable for resuming training using `--resume_from_checkpoint`. - GRADIENT_ACCUMULATION_STEPS (int): The number of updates steps to accumulate before - performing a backward/update pass - GRADIENT_CHECKPOINTING (Union[str,None]): Whether to use gradient checkpointing to save memory - at the expense of slower backward pass - LEARNING_RATE (float): Initial learning rate (after the potential warmup period) to use. - SCALE_LR (Union[str,None]): Scale the learning rate by the number of GPUs, gradient accumulation steps, - and batch size - LR_WARMUP_STEPS (int): Scale the learning rate by the number of GPUs, gradient accumulation - steps, and batch size - LR_SCHEDULER (str): The scheduler type to use. 
Choose between ["linear", "cosine", - cosine_with_restarts", "polynomial", "constant","constant_with_warmup"] - USE_EMA (Union[str,None]): Whether to use EMA model - MIXED_PRECISION (str): Whether to use mixed precision. Choose between fp16 and bf16 - (bfloat16). Bf16 requires PyTorch >=1.10.and an Nvidia Ampere GPU. - Default to the value of accelerate config of the current system or the flag passed with the - `accelerate.launch` command. Use this argument to override the `accelerate` config - CENTER_CROP (Union[str,None]): whether to center crop images before resizing to resolution (if not set, - random crop will be used) - RANDOM_FLIP (Union[str,None]): whether to randomly flip images horizontally - - """ - DATA_MANIFEST_GCS_PATH = "gs://storied-landing-366912-kfp-output/artifacts/image-generator-dataset-mlfdr/2022/12/21/image-generator-dataset-mlfdr-1067088695/image-caption-component-data_manifest_path_caption_component.tgz" - PRETRAINED_MODEL_GCS_PATH = "gs://express-models/stable-diffusion-v1-5-fp32" - # TODO: Only the most relevant params from ":https://github.com/huggingface/diffusers/blob/ - # main/examples/ text_to_image/train_text_to_image.py" were specified here. Check later whether - # to include additional relevant arguments (right now it uses the default arguments) - SEED = 1024 - RESOLUTION = 512 - # Batch size of 4 is the maximum batch size that can be set to train when training on two A100s - # GPUs without running in 'Out of memory' issues - TRAIN_BATCH_SIZE = 4 - NUM_TRAIN_EPOCHS = 100 - MAX_TRAIN_STEPS = 25 # overwrites training epochs if defined - CHECKPOINTING_STEPS = 10 - GRADIENT_ACCUMULATION_STEPS = 4 - GRADIENT_CHECKPOINTING = "True" # "True" or None* - LEARNING_RATE = 1e-5 - SCALE_LR = None # "True" or None* - LR_WARMUP_STEPS = 0 - LR_SCHEDULER = "constant" - USE_EMA = "True" # "True" or None* - MIXED_PRECISION = "fp16" - CENTER_CROP = "True" # "True" or None* - RANDOM_FLIP = "True" # "True" or None* - RESUME_FROM_CHECKPOINT = None -# *Kubeflow does not support passing in boolean variables (bug where True is always returned). 
-# Workaround for optional arguments implemented where "True" is passed as a string to include the -# argument and None is passed to omit the argument diff --git a/examples/pipelines/finetune_stable_diffusion/sd_finetuning_pipeline.py b/examples/pipelines/finetune_stable_diffusion/sd_finetuning_pipeline.py index 53085d292..a6a02ce91 100644 --- a/examples/pipelines/finetune_stable_diffusion/sd_finetuning_pipeline.py +++ b/examples/pipelines/finetune_stable_diffusion/sd_finetuning_pipeline.py @@ -100,7 +100,7 @@ def sd_finetuning_pipeline( run_id = '{{pod.name}}' artifact_bucket = KubeflowConfig.ARTIFACT_BUCKET - sd_finetuning_task = sd_finetuning_component( + sd_finetuning_component( project_id=GeneralConfig.GCP_PROJECT_ID, run_id=run_id, artifact_bucket=artifact_bucket, diff --git a/examples/pipelines/hf_dataset_pipeline/components/load_from_hub/src/main.py b/examples/pipelines/hf_dataset_pipeline/components/load_from_hub/src/main.py index e1313abea..c1491916a 100644 --- a/examples/pipelines/hf_dataset_pipeline/components/load_from_hub/src/main.py +++ b/examples/pipelines/hf_dataset_pipeline/components/load_from_hub/src/main.py @@ -5,7 +5,7 @@ from typing import Optional, Union, Dict import pandas as pd -from datasets import Dataset, load_dataset, concatenate_datasets +from datasets import Dataset, load_dataset from express.components.hf_datasets_components import HFDatasetsLoaderComponent, HFDatasetsDatasetDraft from express.logger import configure_logging diff --git a/examples/pipelines/hf_dataset_pipeline/hf_dataset_pipeline.py b/examples/pipelines/hf_dataset_pipeline/hf_dataset_pipeline.py index 7e6599bd6..b6110fdcd 100644 --- a/examples/pipelines/hf_dataset_pipeline/hf_dataset_pipeline.py +++ b/examples/pipelines/hf_dataset_pipeline/hf_dataset_pipeline.py @@ -45,7 +45,7 @@ def hf_dataset_pipeline(load_from_hub_extra_args: str = load_from_hub_extra_args ).set_display_name('Load from hub component') # Component 2 - add_captions_task = add_captions_op(extra_args=add_captions_extra_args, + add_captions_op(extra_args=add_captions_extra_args, metadata=add_captions_metadata_args, input_manifest=load_from_hub_task.outputs["output_manifest"], ).set_display_name('Add captions component') diff --git a/express/components/common.py b/express/components/common.py index 10d19e0e4..680c3e2de 100644 --- a/express/components/common.py +++ b/express/components/common.py @@ -48,7 +48,9 @@ def load_index(self) -> IndexT: Loads the index data. """ - def load(self, data_source: str, for_index: Optional[IndexT] = None) -> DataT: + def load( + self, data_source: str, for_index: Optional[IndexT] = None, **kwargs + ) -> DataT: """ Load data from a named data source. @@ -57,6 +59,7 @@ def load(self, data_source: str, for_index: Optional[IndexT] = None) -> DataT: for_index (Optional[TIndex]): Pass in an index to filter the data on. By default, the original Dataset index is used. This argument can be used to use a different index instead. + kwargs (dict): Additional keyword arguments forwarded to the _load_data_source method. 
Returns: TData: Data of type TData @@ -69,13 +72,15 @@ def load(self, data_source: str, for_index: Optional[IndexT] = None) -> DataT: if for_index is None: for_index = self._index_data return self._load_data_source( - self.manifest.data_sources[data_source], for_index + self.manifest.data_sources[data_source], for_index, **kwargs ) @staticmethod @abstractmethod def _load_data_source( - data_source: DataSource, index_filter: Optional[IndexT] + data_source: DataSource, + index_filter: Optional[IndexT], + **kwargs, ) -> DataT: """ Load data from a (possibly remote) path. @@ -130,7 +135,7 @@ def __init__( "added to an extending dataset draft after it's been constructed." ) self.index = extending_dataset.manifest.index - for name, dataset in extending_dataset.manifest.associated_data.items(): + for name, dataset in extending_dataset.manifest.data_sources.items(): self.with_data_source(name, dataset, replace_ok=False) @classmethod @@ -339,9 +344,10 @@ def run(cls) -> DataManifest: output_dataset_draft = cls.transform( data=input_dataset, extra_args=json.loads(args.extra_args) ) + metadata = Metadata.from_dict(json.loads(args.metadata)) output_manifest = cls._create_output_dataset( draft=output_dataset_draft, - metadata=json.loads(args.metadata), + metadata=metadata, save_path=args.output_manifest, ) return output_manifest diff --git a/express/components/hf_datasets_components.py b/express/components/hf_datasets_components.py index 6b283bc50..d2502aaf7 100644 --- a/express/components/hf_datasets_components.py +++ b/express/components/hf_datasets_components.py @@ -33,9 +33,7 @@ # pylint: disable=too-few-public-methods -class HFDatasetsDataset( - ExpressDataset[List[str], Union[datasets.Dataset, datasets.DatasetDict]] -): +class HFDatasetsDataset(ExpressDataset[List[str], datasets.Dataset]): """Hugging Face Datasets dataset""" def load_index(self) -> datasets.Dataset: @@ -54,8 +52,10 @@ def load_index(self) -> datasets.Dataset: @staticmethod def _load_data_source( - data_source: DataSource, index_filter: datasets.Dataset - ) -> Union[datasets.Dataset, datasets.DatasetDict]: + data_source: DataSource, + index_filter: datasets.Dataset, + **kwargs, + ) -> datasets.Dataset: """Function that loads in a data source""" if data_source.type != DataType.PARQUET: raise TypeError("Only reading from parquet is currently supported.") @@ -67,19 +67,24 @@ def _load_data_source( data_source_location, tmp_dir ) - data_source_hf_datasets = load_dataset( + if "columns" in kwargs: + if "index" not in kwargs["columns"]: + raise ValueError( + "Please also include the index when specifying columns" + ) + + dataset = load_dataset( "parquet", data_files=local_parquet_path, split="train", + **kwargs, ) if index_filter: index = index_filter["index"] - return data_source_hf_datasets.filter( - lambda example: example["index"] in index - ) + return dataset.filter(lambda example: example["index"] in index) - return data_source_hf_datasets + return dataset class HFDatasetsDatasetHandler(ExpressDatasetHandler[List[str], datasets.Dataset]): @@ -87,7 +92,7 @@ class HFDatasetsDatasetHandler(ExpressDatasetHandler[List[str], datasets.Dataset @staticmethod def _upload_parquet( - data: Union[datasets.Dataset, datasets.DatasetDict], name: str, remote_path: str + data: datasets.Dataset, name: str, remote_path: str ) -> DataSource: with tempfile.TemporaryDirectory() as temp_folder: # TODO: uploading without writing to temp file @@ -120,7 +125,7 @@ def _upload_index(cls, index: datasets.Dataset, remote_path: str) -> DataSource: def 
_upload_data_source( cls, name: str, - data: Union[datasets.Dataset, datasets.DatasetDict], + data: datasets.Dataset, remote_path: str, ) -> DataSource: data_source = cls._upload_parquet(data=data, name=name, remote_path=remote_path)
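
Taken together, the `**kwargs` plumbing above lets a transform component request only the columns it needs when loading a data source. A minimal sketch of that call pattern (the `images` source name and the column list are illustrative, mirroring `ImageFilterComponent` earlier in this diff):

```python
from express.components.hf_datasets_components import HFDatasetsDataset


def load_image_sizes(dataset: HFDatasetsDataset):
    # "columns" is forwarded to datasets.load_dataset via **kwargs; the index
    # column must be included, otherwise _load_data_source raises a ValueError.
    return dataset.load(data_source="images", columns=["index", "width", "height"])
```

Selecting columns this way keeps a metadata-only step from materializing image bytes it never looks at.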