Skip to content

Commit

Permalink
[Alpha pipeline] Add segmentation component (#149)
Browse files Browse the repository at this point in the history
This PR adds the final component of the ControlNet pipeline, namely the
segmentation one.

---------

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: Philippe Moussalli <philippe.moussalli95@gmail.com>
Co-authored-by: Robbe Sneyders <robbe.sneyders@ml6.eu>
  • Loading branch information
4 people authored May 22, 2023
1 parent e063937 commit 1310e65
Show file tree
Hide file tree
Showing 15 changed files with 512 additions and 24 deletions.
17 changes: 7 additions & 10 deletions components/prompt_based_laion_retrieval/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ class LAIONRetrievalComponent(TransformComponent):
"""

def transform(
self,
dataframe: dd.DataFrame,
*,
num_images: int,
aesthetic_score: int,
aesthetic_weight: float
self,
dataframe: dd.DataFrame,
*,
num_images: int,
aesthetic_score: int,
aesthetic_weight: float
) -> dd.DataFrame:
"""
Args:
Expand All @@ -70,8 +70,6 @@ def transform(
modality=Modality.IMAGE,
)

print("Input dataframe:", dataframe.head(2))

logger.info("Retrieving URLs...")
dataframe["images_url"] = dataframe["prompts_text"].apply(
lambda example: query_clip_client(example, client),
Expand All @@ -91,12 +89,11 @@ def transform(

dataframe = dataframe.astype({'id': 'string', 'source': 'string'})

print("Final dataframe:", dataframe.head(4))
dataframe = dataframe.reset_index(drop=True)

return dataframe


if __name__ == "__main__":
component = LAIONRetrievalComponent.from_file()
component.run()
component.run()
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
FROM --platform=linux/amd64 python:3.8-slim
FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel

## System dependencies
RUN apt-get update && \
apt-get upgrade -y && \
apt-get install git -y

# install requirements
COPY requirements.txt /
COPY requirements.txt ./
RUN pip3 install --no-cache-dir -r requirements.txt

# Set the working directory to the component folder
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
git+https://github.com/ml6team/fondant.git@main
gcsfs==2023.4.0
Pillow==9.4.0
torch==2.0.0
torch==2.0.1
transformers==4.29.2
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,6 @@ def load(self) -> dd.DataFrame:

df = dd.from_pandas(pandas_df, npartitions=1)

# TODO remove, just use a tiny df for testing purposes
data = {
"prompts_text": [
"comfortable bathroom, art deco interior design",
"comfortable bathroom, bauhaus interior design",
]
}
pandas_df = pd.DataFrame.from_dict(data)
df = dd.from_pandas(pandas_df, npartitions=1)
# end of TODO

# add id and source columns
df["id"] = df.assign(id=1).id.cumsum()
df["source"] = "seed"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Base image: PyTorch 2.0.1 with CUDA 11.7 and cuDNN 8 (devel variant), built for linux/amd64
FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel

## System dependencies
# Refresh the package index, upgrade installed packages and install git
# (presumably needed by pip for a git+https requirement - confirm against requirements.txt)
RUN apt-get update && \
apt-get upgrade -y && \
apt-get install git -y

# Install the Python requirements; --no-cache-dir keeps pip's download cache out of the image
COPY requirements.txt ./
RUN pip3 install --no-cache-dir -r requirements.txt

# Set the working directory to the component folder
WORKDIR /component/src

# Copy over src-files and spec of the component
# (the spec is placed one level above the sources, i.e. in /component)
COPY src/ .
COPY fondant_component.yaml ../

# Run the component's main script when the container starts
ENTRYPOINT ["python", "main.py"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Segment images
description: Component that creates segmentation masks for images using a model from the Hugging Face hub
image: ghcr.io/ml6team/segment_images:latest

input_subsets:
images:
fields:
data:
type: binary

output_subsets:
segmentations:
fields:
data:
type: binary

args:
model_id:
description: id of the model on the Hugging Face hub
type: str
batch_size:
description: batch size to use
type: int
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
git+https://github.com/ml6team/fondant.git@main
gcsfs==2023.4.0
transformers==4.29.2
Pillow==9.4.0
torch==2.0.1
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
This component segments images using a model from the Hugging Face hub.
"""
import io
import itertools
import logging
import toolz

import dask
import dask.dataframe as dd
from PIL import Image
import pandas as pd
import numpy as np
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
import torch

from palette import palette

from fondant.component import TransformComponent
from fondant.logger import configure_logging

# Set up logging through fondant's helper, then create this module's logger.
configure_logging()
logger = logging.getLogger(__name__)


def convert_to_rgb(seg: np.ndarray) -> bytes:
    """
    Colorizes a 2D segmentation map and serializes it to raw bytes.

    Each pixel whose label equals an index of the module-level `palette` is
    painted with the palette entry at that index. Pixels whose label has no
    palette entry keep the initial value (0, 0, 0).

    Args:
        seg: 2D segmentation map as a NumPy array, one label per pixel.
    Returns:
        The raw buffer (`ndarray.tobytes()`) of the (height, width, 3) uint8
        color map. NOTE(review): this is not an encoded image, so height and
        width are not stored in the returned bytes - confirm consumers know
        the shape.
    """
    color_seg = np.zeros(
        (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
    )  # height, width, 3

    # presumably each palette entry is an RGB triple - verify against palette.py
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color

    # The array is already uint8, so no extra cast is needed before serializing.
    return color_seg.tobytes()


@dask.delayed
def load(example):
    """
    Lazily decodes one image from its raw bytes.

    Args:
        example: bytes-like object holding the encoded image.
    Returns:
        The decoded PIL image, converted to RGB mode.
    """
    return Image.open(io.BytesIO(example)).convert("RGB")


@dask.delayed
def transform(image, processor, device):
    """
    Lazily preprocesses one image into model-ready pixel values.

    Args:
        image: PIL image (as returned by `load`).
        processor: image processor, called with `images=image` and
            `return_tensors="pt"`.
        device: device the pixel values are moved to.
    Returns:
        Tuple of (pixel values on `device`, (height, width) of the image).
    """
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    # PIL reports the size as (width, height), whereas the sizes are later
    # passed as `target_sizes` to the Hugging Face post-processing, which
    # expects (height, width). Swap them so non-square images keep their shape.
    width, height = image.size

    return pixel_values, (height, width)


@dask.delayed
def collate(examples):
    """
    Lazily merges preprocessed examples into a single batch.

    Args:
        examples: sequence of (pixel_values, image_size) pairs.
    Returns:
        Dict with the concatenated "pixel_values" tensor and the list of
        "image_sizes", in the same order as `examples`.
    """
    return {
        "pixel_values": torch.cat([example[0] for example in examples]),
        "image_sizes": [example[1] for example in examples],
    }


@dask.delayed
@torch.no_grad()
def segment(batch, model, processor):
    """
    Lazily runs the segmentation model on one batch, without gradients.

    Args:
        batch: dict with "pixel_values" (model input) and "image_sizes"
            (target size per image), as built by `collate`.
        model: segmentation model, called on the pixel values.
        processor: image processor providing
            `post_process_semantic_segmentation`.
    Returns:
        List with one RGB segmentation (output of `convert_to_rgb`) per image.
    """
    outputs = model(batch["pixel_values"])
    segmentations = processor.post_process_semantic_segmentation(
        outputs, target_sizes=batch["image_sizes"]
    )
    # turn into RGB images; move each map to host memory first, since
    # `Tensor.numpy()` raises for tensors that live on a CUDA device
    segmentations = [convert_to_rgb(seg.cpu().numpy()) for seg in segmentations]

    return segmentations


@dask.delayed
def flatten(lst):
    """
    Lazily chains a list of lists into one pandas Series.

    Args:
        lst: iterable of iterables whose items are concatenated in order.
    Returns:
        pandas Series holding all items of the inner iterables.
    """
    chained = itertools.chain.from_iterable(lst)
    return pd.Series(chained)


class SegmentImagesComponent(TransformComponent):
    """
    Component that segments images using a model from the Hugging Face hub.
    """

    def transform(
        self,
        dataframe: dd.DataFrame,
        model_id: str,
        batch_size: int,
    ) -> dd.DataFrame:
        """
        Builds a Dask dataframe with one RGB segmentation per input image.

        Args:
            dataframe: Dask dataframe; the "images_data", "id" and "source"
                columns are read
            model_id: id of the model on the Hugging Face hub
            batch_size: batch size to use
        Returns:
            Dask dataframe with the "segmentations_data", "id" and "source"
            columns
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Pass the device as a %-style argument: a message without a
        # placeholder makes the logging module fail to format the record.
        logger.info("Device: %s", device)

        processor = AutoImageProcessor.from_pretrained(model_id)
        model = AutoModelForSemanticSegmentation.from_pretrained(model_id)

        # load and transform the images
        images = dataframe["images_data"]
        loaded_images = [load(image) for image in images]
        transformed_images = [
            transform(image, processor, device) for image in loaded_images
        ]

        # batch images together
        batches = [
            collate(batch)
            for batch in toolz.partition_all(batch_size, transformed_images)
        ]

        # segment images; the model is wrapped as a delayed object so it can
        # be passed to every delayed `segment` call
        delayed_model = dask.delayed(model.to(device))
        segmentations = [segment(batch, delayed_model, processor) for batch in batches]

        # join lists into a single Dask delayed object
        segmentations = flatten(segmentations)
        delayed_series = dd.from_delayed(segmentations, meta=pd.Series(dtype="object"))
        segmentations_df = delayed_series.to_frame(name="segmentations_data")

        # add index columns
        # NOTE(review): assumes the segmentations come out in the same row
        # order as the input dataframe - confirm the indices align
        segmentations_df["id"] = dataframe["id"].reset_index(drop=True)
        segmentations_df["source"] = dataframe["source"].reset_index(drop=True)

        segmentations_df = segmentations_df.reset_index(drop=True)

        return segmentations_df


if __name__ == "__main__":
    # Create the component through `from_file()` and run it right away.
    SegmentImagesComponent.from_file().run()
Loading

0 comments on commit 1310e65

Please sign in to comment.