Merge pull request #37 from ml6team/fix_prefix
Remove cloud specific prefix
NielsRogge authored Apr 25, 2023
2 parents 7e146e2 + 9e7c721 commit 474e1ce
Showing 7 changed files with 94 additions and 52 deletions.
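In short, the commit drops the hard-coded "gcs://" prefix from the component code: the config and manifest now carry fully qualified remote paths, which Dask's fsspec-backed readers resolve directly. A minimal sketch of the idea (the path and column names below are illustrative, not taken from the repo):

```python
import dask.dataframe as dd

# Before this commit: locations were stored without a scheme and the code prepended "gcs://".
# location = "gcs://" + subset.location

# After: the configured base path already carries the scheme (gcs://, s3://, ...),
# so the location can be handed to Dask/fsspec as-is.
remote_path = "gcs://some-bucket/artifacts/images"  # illustrative fully qualified path
df = dd.read_parquet(remote_path, columns=["id", "source"])
```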
@@ -1 +1,3 @@
-git+https://github.com/ml6team/fondant.git@b1855308ca9251da5ddd8e6b88c34bc1c082a71b#egg=fondant
+git+https://github.com/ml6team/fondant.git@b4f7a43eae540e6708af76dafd071c77a2254d86
+pyarrow>=7.0
+gcsfs==2023.4.0
@@ -17,20 +17,20 @@ class ImageFilterComponent(FondantComponent):
     """
     Component that filters images based on height and width.
     """
-    def process(self, dataset: dd.DataFrame, args: Dict) -> dd.DataFrame:
+    def transform(self, df: dd.DataFrame, args: Dict) -> dd.DataFrame:
         """
         Args:
-            dataset
+            df: Dask dataframe
             args: args to pass to the function
         Returns:
             dataset
         """
         logger.info("Filtering dataset...")
         min_width, min_height = args.min_width, args.min_height
-        filtered_dataset = dataset.filter(lambda example: example["images_width"] > min_width and example["images_height"] > min_height)
+        filtered_df = df[(df["images_width"] > min_width) & (df["images_height"] > min_height)]
 
-        return filtered_dataset
+        return filtered_df
 
 
 if __name__ == "__main__":
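The filtering component above moves from a Hugging Face datasets style `dataset.filter(lambda ...)` to plain Dask boolean masking; the comparisons are parenthesized and combined with `&` because Python's `and` does not broadcast over dataframe columns. A small self-contained sketch of the new style (the data values are made up):

```python
import dask.dataframe as dd
import pandas as pd

# Toy dataframe with the same column names the component expects.
pdf = pd.DataFrame({"images_width": [640, 200, 1024], "images_height": [480, 150, 768]})
df = dd.from_pandas(pdf, npartitions=1)

min_width, min_height = 256, 256
# Element-wise boolean mask instead of dataset.filter(lambda example: ...):
filtered_df = df[(df["images_width"] > min_width) & (df["images_height"] > min_height)]
print(filtered_df.compute())  # keeps the 640x480 and 1024x768 rows
```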
@@ -1,4 +1,4 @@
 datasets==2.11.0
-git+https://github.com/ml6team/fondant.git@8ecfb9fcaf0b8d457626179fe44347df829b8979#egg=fondant
+git+https://github.com/ml6team/fondant.git@b04dec44ff198c34bf3862324ed5e431a9fd5366
 Pillow==9.4.0
 gcsfs==2023.4.0
11 changes: 5 additions & 6 deletions examples/pipelines/simple_pipeline/config/general_config.py
@@ -10,7 +10,6 @@ class GeneralConfig:
     General configs
     Params:
         GCP_PROJECT_ID (str): GCP project ID
-        DATASET_NAME (str): name of the Hugging Face dataset
         ENV (str): the project run environment (sbx, dev, prd)
     """
     GCP_PROJECT_ID = "soy-audio-379412"
@@ -22,12 +21,12 @@ class KubeflowConfig(GeneralConfig):
     """
     Configs for the Kubeflow cluster
     Params:
-        ARTIFACT_BUCKET (str): the GCS bucket used to store the artifacts
-        CLUSTER_NAME (str): the name of the k8 cluster hosting KFP
-        CLUSTER_ZONE (str): the zone of the k8 cluster hosting KFP
-        HOST (str): the kfp host url
+        BASE_PATH (str): base path to store the artifacts
+        CLUSTER_NAME (str): name of the k8 cluster hosting KFP
+        CLUSTER_ZONE (str): zone of the k8 cluster hosting KFP
+        HOST (str): kfp host url
     """
-    ARTIFACT_BUCKET = f"{GeneralConfig.GCP_PROJECT_ID}_kfp-artifacts"
+    BASE_PATH = f"gcs://soy-audio-379412_kfp-artifacts/custom_artifact"
     CLUSTER_NAME = "kfp-fondant"
     CLUSTER_ZONE = "europe-west4-a"
     HOST = "https://52074149b1563463-dot-europe-west1.pipelines.googleusercontent.com"
15 changes: 7 additions & 8 deletions examples/pipelines/simple_pipeline/simple_pipeline.py
@@ -13,7 +13,7 @@
 from fondant.pipeline_utils import compile_and_upload_pipeline
 
 # Load Components
-artifact_bucket = KubeflowConfig.ARTIFACT_BUCKET + "/custom_artifact"
+base_path = KubeflowConfig.BASE_PATH
 run_id = "{{workflow.name}}"
 
 # Component 1
@@ -22,12 +22,12 @@
 image_filtering_op = comp.load_component("components/image_filtering/kubeflow_component.yaml")
 
 load_from_hub_metadata = {
-    "base_path": artifact_bucket,
+    "base_path": base_path,
     "run_id": run_id,
     "component_id": load_from_hub_op.__name__,
 }
 image_filtering_metadata = {
-    "base_path": artifact_bucket,
+    "base_path": base_path,
     "run_id": run_id,
     "component_id": image_filtering_op.__name__,
 }
@@ -37,12 +37,11 @@
 
 # Pipeline
 @dsl.pipeline(
-    name="image-generator-dataset",
-    description="Pipeline that takes example images as input and returns an expanded dataset of "
-    "similar images as outputs",
+    name="simple-pipeline-v2",
+    description="Simple pipeline that takes example images as input and embeds them using CLIP",
 )
 # pylint: disable=too-many-arguments, too-many-locals
-def sd_dataset_creator_pipeline(
+def simple_pipeline_v2(
     load_from_hub_dataset_name: str = LoadFromHubConfig.DATASET_NAME,
     load_from_hub_batch_size: int = LoadFromHubConfig.BATCH_SIZE,
     load_from_hub_metadata: str = load_from_hub_metadata,
@@ -69,7 +68,7 @@ def sd_dataset_creator_pipeline(
 
 if __name__ == "__main__":
     compile_and_upload_pipeline(
-        pipeline=sd_dataset_creator_pipeline,
+        pipeline=simple_pipeline_v2,
         host=KubeflowConfig.HOST,
         env=KubeflowConfig.ENV,
     )
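The metadata dictionaries above now carry the fully qualified `base_path` instead of a bare bucket name. A hedged sketch of how such a dictionary is typically serialized before being handed to a Kubeflow component as a string argument (the op signature itself is not shown in this diff, so the last line is an assumption):

```python
import json

base_path = "gcs://some-bucket/custom_artifact"  # illustrative; in the pipeline this is KubeflowConfig.BASE_PATH
run_id = "{{workflow.name}}"  # placeholder resolved by Argo at runtime

load_from_hub_metadata = {
    "base_path": base_path,
    "run_id": run_id,
    "component_id": "load_from_hub",  # the pipeline uses load_from_hub_op.__name__
}

# Kubeflow component inputs are plain strings, so the dict is passed serialized:
metadata_arg = json.dumps(load_from_hub_metadata)
```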
88 changes: 65 additions & 23 deletions fondant/dataset.py
@@ -6,15 +6,18 @@
 from abc import abstractmethod
 import argparse
 import json
+import logging
 from pathlib import Path
 from typing import List, Mapping
 
 import dask.dataframe as dd
 
 from fondant.component_spec import FondantComponentSpec, kubeflow2python_type
-from fondant.manifest import Manifest
+from fondant.manifest import Manifest, Index
 from fondant.schema import Type, Field
 
+logger = logging.getLogger(__name__)
+
 
 class FondantDataset:
     """Wrapper around the manifest to download and upload data into a specific framework.
@@ -31,36 +34,74 @@ def __init__(self, manifest: Manifest):
             "__null_dask_index__": "int64",
         }
 
-    def _load_subset(self, name: str, fields: List[str]) -> dd.DataFrame:
+    def _load_subset(
+        self, name: str, fields: List[str], index: Index = None
+    ) -> dd.DataFrame:
         # get subset from the manifest
         subset = self.manifest.subsets[name]
-        # TODO remove prefix
-        location = "gcs://" + subset.location
+        # get remote path
+        remote_path = subset.location
 
         # add index fields
         index_fields = list(self.manifest.index.fields.keys())
         fields = index_fields + fields
 
+        logger.info(f"Loading subset {name} with fields {fields}...")
+
         df = dd.read_parquet(
-            location,
+            remote_path,
             columns=fields,
         )
 
+        # filter on default index of manifest if no index is provided
+        if index is None:
+            index_df = self._load_index()
+            ids = index_df["id"].compute()
+            sources = index_df["source"].compute()
+            df = df[df["id"].isin(ids) & df["source"].isin(sources)]
+
+        # add subset prefix to columns
+        df = df.rename(
+            columns={
+                col: name + "_" + col for col in df.columns if col not in index_fields
+            }
+        )
+
         return df
 
+    def _load_index(self):
+        # get index subset from the manifest
+        index = self.manifest.index
+        # get remote path
+        remote_path = index.location
+
+        df = dd.read_parquet(remote_path)
+
+        if list(df.columns) != ["id", "source"]:
+            raise ValueError(
+                f"Index columns should be 'id' and 'source', found {df.columns}"
+            )
+
+        return df
+
     def load_data(self, spec: FondantComponentSpec) -> dd.DataFrame:
-        subsets = []
+        subset_dfs = []
         for name, subset in spec.input_subsets.items():
             fields = list(subset.fields.keys())
             subset_df = self._load_subset(name, fields)
-            subsets.append(subset_df)
+            subset_dfs.append(subset_df)
 
+        # return a single dataframe with column_names called subset_field
+        # TODO perhaps leverage dd.merge here instead
+        df = dd.concat(subset_dfs)
+
-        # TODO this method should return a single dataframe with column_names called subset_field
-        # TODO add index
-        # df = concatenate_datasets(subsets)
+        logging.info("Columns of dataframe:", list(df.columns))
 
-        # return df
+        return df
 
     def _upload_index(self, df: dd.DataFrame):
-        # get location
-        # TODO remove prefix and suffix
-        remote_path = "gcs://" + self.manifest.index.location
+        # get remote path
+        remote_path = self.manifest.index.location
 
         # upload to the cloud
         dd.to_parquet(
@@ -79,8 +120,8 @@ def _upload_subset(self, name: str, fields: Mapping[str, Field], df: dd.DataFrame
         expected_schema = {field.name: field.type for field in fields.values()}
         expected_schema.update(self.index_schema)
 
-        # TODO remove prefix
-        remote_path = "gcs://" + self.manifest.subsets[name].location
+        # get remote path
+        remote_path = self.manifest.subsets[name].location
 
         # upload to the cloud
         dd.to_parquet(df, remote_path, schema=expected_schema, overwrite=True)
@@ -161,13 +202,14 @@ def run(self) -> dd.DataFrame:
             dataset.add_index(df)
             dataset.add_subsets(df, self.spec)
         else:
-            # create HF dataset, based on component spec
-            input_dataset = dataset.load_data(self.spec)
-            # provide this dataset to the user
+            # create dataframe, based on component spec
+            df = dataset.load_data(self.spec)
+            # provide this dataframe to the user
             df = self.transform(
-                dataset=input_dataset,
+                df=df,
                 args=args,
             )
+            # TODO update index, potentially add new subsets
 
         # step 4: create output manifest
         output_manifest = dataset.upload(save_path=args.output_manifest_path)
@@ -212,8 +254,8 @@ def _add_and_parse_args(self):
 
     @abstractmethod
     def load(self, args) -> dd.DataFrame:
-        """Load initial dataset"""
+        """Load initial dataframe"""
 
     @abstractmethod
-    def transform(self, dataset, args) -> dd.DataFrame:
-        """Transform existing dataset"""
+    def transform(self, df, args) -> dd.DataFrame:
+        """Transform existing dataframe"""
18 changes: 9 additions & 9 deletions fondant/pipeline_utils.py
@@ -29,15 +29,15 @@ def compile_and_upload_pipeline(
         pipeline_func=pipeline, package_path=pipeline_filename
     )
 
-    existing_pipelines = client.list_pipelines(page_size=100).pipelines
-    for existing_pipeline in existing_pipelines:
-        if existing_pipeline.name == pipeline_name:
-            # Delete existing pipeline before uploading
-            logger.warning(
-                f"Pipeline {pipeline_name} already exists. Deleting old pipeline..."
-            )
-            client.delete_pipeline_version(existing_pipeline.default_version.id)
-            client.delete_pipeline(existing_pipeline.id)
+    # existing_pipelines = client.list_pipelines(page_size=100).pipelines
+    # for existing_pipeline in existing_pipelines:
+    #     if existing_pipeline.name == pipeline_name:
+    #         # Delete existing pipeline before uploading
+    #         logger.warning(
+    #             f"Pipeline {pipeline_name} already exists. Deleting old pipeline..."
+    #         )
+    #         client.delete_pipeline_version(existing_pipeline.default_version.id)
+    #         client.delete_pipeline(existing_pipeline.id)
 
     logger.info(f"Uploading pipeline: {pipeline_name}")
     client.upload_pipeline(pipeline_filename, pipeline_name=pipeline_name)
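With the delete-before-upload block commented out, `client.upload_pipeline` will fail if a pipeline with the same name already exists on the cluster. One common alternative, sketched here under the assumption of the KFP v1 SDK (this is not what the commit does), is to upload a new version of the existing pipeline instead:

```python
import kfp

client = kfp.Client(host="https://<your-kfp-host>")  # placeholder host
pipeline_name = "simple-pipeline-v2"
pipeline_filename = "pipeline.tar.gz"

pipeline_id = client.get_pipeline_id(pipeline_name)
if pipeline_id is None:
    # first upload
    client.upload_pipeline(pipeline_filename, pipeline_name=pipeline_name)
else:
    # pipeline already exists: add a new version instead of deleting and re-uploading
    client.upload_pipeline_version(
        pipeline_filename,
        pipeline_version_name="fix-prefix",  # illustrative version name
        pipeline_id=pipeline_id,
    )
```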
