-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Alpha pipeline] Add segmentation component (#149)
This PR adds the final component of the ControlNet pipeline, namely the segmentation one. --------- Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local> Co-authored-by: Philippe Moussalli <philippe.moussalli95@gmail.com> Co-authored-by: Robbe Sneyders <robbe.sneyders@ml6.eu>
- Loading branch information
1 parent
e063937
commit 1310e65
Showing
15 changed files
with
512 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 2 additions & 2 deletions
4
examples/pipelines/controlnet-interior-design/components/caption_images/Dockerfile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
examples/pipelines/controlnet-interior-design/components/caption_images/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
git+https://github.com/ml6team/fondant.git@main | ||
gcsfs==2023.4.0 | ||
Pillow==9.4.0 | ||
torch==2.0.0 | ||
torch==2.0.1 | ||
transformers==4.29.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
19 changes: 19 additions & 0 deletions
19
examples/pipelines/controlnet-interior-design/components/segment_images/Dockerfile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel

# System dependencies
# git is required because requirements.txt installs fondant from a git+https URL.
# The apt package index is removed in the same layer so it does not end up in the image.
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install git -y && \
    rm -rf /var/lib/apt/lists/*

# Install the Python requirements (before copying the sources so this layer stays cached
# when only the component code changes)
COPY requirements.txt ./
RUN pip3 install --no-cache-dir -r requirements.txt

# Set the working directory to the component folder
WORKDIR /component/src

# Copy over src-files and spec of the component
COPY src/ .
COPY fondant_component.yaml ../

ENTRYPOINT ["python", "main.py"]
23 changes: 23 additions & 0 deletions
23
...les/pipelines/controlnet-interior-design/components/segment_images/fondant_component.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
name: Segment images | ||
description: Component that creates segmentation masks for images using a model from the Hugging Face hub | ||
image: ghcr.io/ml6team/segment_images:latest | ||
|
||
input_subsets: | ||
images: | ||
fields: | ||
data: | ||
type: binary | ||
|
||
output_subsets: | ||
segmentations: | ||
fields: | ||
data: | ||
type: binary | ||
|
||
args: | ||
model_id: | ||
description: id of the model on the Hugging Face hub | ||
type: str | ||
batch_size: | ||
description: batch size to use | ||
type: int |
5 changes: 5 additions & 0 deletions
5
examples/pipelines/controlnet-interior-design/components/segment_images/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
git+https://github.com/ml6team/fondant.git@main | ||
gcsfs==2023.4.0 | ||
transformers==4.29.2 | ||
Pillow==9.4.0 | ||
torch==2.0.1 |
148 changes: 148 additions & 0 deletions
148
examples/pipelines/controlnet-interior-design/components/segment_images/src/main.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
""" | ||
This component segments images using a model from the Hugging Face hub. | ||
""" | ||
import io | ||
import itertools | ||
import logging | ||
import toolz | ||
|
||
import dask | ||
import dask.dataframe as dd | ||
from PIL import Image | ||
import pandas as pd | ||
import numpy as np | ||
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation | ||
import torch | ||
|
||
from palette import palette | ||
|
||
from fondant.component import TransformComponent | ||
from fondant.logger import configure_logging | ||
|
||
configure_logging() | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def convert_to_rgb(seg: np.ndarray, colors=None) -> bytes:
    """
    Converts a 2D segmentation map to raw RGB bytes, which makes it possible to visualize it.

    Args:
        seg: 2D segmentation map as a NumPy array, holding one label per pixel.
        colors: optional sequence of RGB triplets; the entry at index `i` is the color
            painted on every pixel whose label equals `i`. Defaults to the module-level
            `palette`.

    Returns:
        The (height, width, 3) uint8 color map serialized with `ndarray.tobytes()`.
        Pixels whose label has no entry in `colors` stay black (all zeros).
    """
    if colors is None:
        colors = palette

    # height, width, 3
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)

    for label, color in enumerate(colors):
        color_seg[seg == label, :] = color

    # NOTE(review): the returned buffer is headerless, so a consumer needs the image
    # dimensions from elsewhere to rebuild the picture - confirm this is intended.
    return color_seg.tobytes()
|
||
|
||
@dask.delayed
def load(example):
    """
    Decodes one encoded image into an RGB PIL image (as a Dask delayed task).

    Args:
        example: bytes-like content of an encoded image file.
    Returns:
        The decoded image, converted to RGB mode.
    """
    # named `buffer` rather than `bytes` so the builtin type is not shadowed
    buffer = io.BytesIO(example)
    return Image.open(buffer).convert("RGB")
|
||
|
||
@dask.delayed
def transform(image, processor, device):
    """
    Preprocesses one image for the model (as a Dask delayed task).

    Args:
        image: PIL image to preprocess.
        processor: Hugging Face image processor, called with `return_tensors="pt"`.
        device: device the pixel values are moved to.
    Returns:
        Tuple of the pixel values on `device` and the original image size as
        (height, width).
    """
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    # PIL reports `size` as (width, height), whereas the size is later used as a
    # `target_sizes` entry for post-processing, which is expressed as (height, width).
    width, height = image.size

    return pixel_values, (height, width)
|
||
|
||
@dask.delayed
def collate(examples):
    """
    Merges a group of transformed examples into one batch (as a Dask delayed task).

    Args:
        examples: sequence of pairs whose first item is a tensor and whose second
            item is the matching image size.
    Returns:
        Dict with the concatenated tensors under "pixel_values" and the list of
        sizes under "image_sizes".
    """
    return {
        "pixel_values": torch.cat([item[0] for item in examples]),
        "image_sizes": [item[1] for item in examples],
    }
|
||
|
||
@dask.delayed
@torch.no_grad()
def segment(batch, model, processor):
    """
    Runs the segmentation model on one batch (as a Dask delayed task, without gradients).

    Args:
        batch: dict with the model input under "pixel_values" and the requested
            output size of every image under "image_sizes".
        model: semantic segmentation model, called on the pixel values.
        processor: image processor providing `post_process_semantic_segmentation`.
    Returns:
        List with one RGB-encoded segmentation (see `convert_to_rgb`) per image.
    """
    outputs = model(batch["pixel_values"])
    segmentations = processor.post_process_semantic_segmentation(
        outputs, target_sizes=batch["image_sizes"]
    )

    # Turn into RGB images. The maps are copied to host memory first, because
    # `Tensor.numpy()` raises for tensors that live on a CUDA device.
    return [convert_to_rgb(seg.cpu().numpy()) for seg in segmentations]
|
||
|
||
@dask.delayed
def flatten(lst):
    """
    Chains the items of several iterables into a single pandas Series
    (as a Dask delayed task).
    """
    chained = itertools.chain.from_iterable(lst)
    return pd.Series(chained)
|
||
|
||
class SegmentImagesComponent(TransformComponent):
    """
    Component that segments images using a model from the Hugging Face hub.
    """

    def transform(
        self,
        dataframe: dd.DataFrame,
        model_id: str,
        batch_size: int,
    ) -> dd.DataFrame:
        """
        Builds a lazy Dask graph that loads, preprocesses, batches and segments the
        images stored in the "images_data" column.

        Args:
            dataframe: Dask dataframe with the columns "images_data", "id" and "source"
            model_id: id of the model on the Hugging Face hub
            batch_size: batch size to use
        Returns:
            Dask dataframe with the columns "segmentations_data", "id" and "source"
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # logging interpolates its arguments lazily, so the message needs a placeholder
        logger.info("Device: %s", device)

        processor = AutoImageProcessor.from_pretrained(model_id)
        model = AutoModelForSemanticSegmentation.from_pretrained(model_id)

        # load and transform the images
        # (`transform` below resolves to the module-level delayed function, not this method)
        images = dataframe["images_data"]
        loaded_images = [load(image) for image in images]
        transformed_images = [
            transform(image, processor, device) for image in loaded_images
        ]

        # batch images together
        batches = [
            collate(batch)
            for batch in toolz.partition_all(batch_size, transformed_images)
        ]

        # segment images; the model is wrapped once so every task shares the same object
        delayed_model = dask.delayed(model.to(device))
        segmentations = [segment(batch, delayed_model, processor) for batch in batches]

        # join lists into a single Dask delayed object
        segmentations = flatten(segmentations)
        delayed_series = dd.from_delayed(segmentations, meta=pd.Series(dtype="object"))
        segmentations_df = delayed_series.to_frame(name="segmentations_data")

        # add index columns
        # NOTE(review): the assignments below rely on index alignment between the
        # segmentations and the reset index of `dataframe` - confirm this holds when
        # `dataframe` has more than one partition.
        segmentations_df["id"] = dataframe["id"].reset_index(drop=True)
        segmentations_df["source"] = dataframe["source"].reset_index(drop=True)

        segmentations_df = segmentations_df.reset_index(drop=True)

        return segmentations_df
|
||
|
||
if __name__ == "__main__":
    # Instantiate the component through `from_file()` and run it right away.
    SegmentImagesComponent.from_file().run()
Oops, something went wrong.