Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Alpha pipeline] Add segmentation component #149

Merged
merged 10 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
FROM --platform=linux/amd64 python:3.8-slim

## System dependencies
RUN apt-get update && \
apt-get upgrade -y && \
apt-get install git -y

# install requirements
COPY requirements.txt /
RUN pip3 install --no-cache-dir -r requirements.txt

# Set the working directory to the component folder
WORKDIR /component/src

# Copy over src-files and spec of the component
COPY src/ .
COPY fondant_component.yaml ../

ENTRYPOINT ["python", "main.py"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Segment images
description: Component that creates segmentation masks for images using a model from the Hugging Face hub
image: ghcr.io/ml6team/segment_images:latest

input_subsets:
images:
fields:
data:
type: binary

output_subsets:
segmentations:
fields:
data:
type: binary

args:
model_id:
description: id of the model on the Hugging Face hub
type: str
batch_size:
description: batch size to use
type: int
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
git+https://github.com/ml6team/fondant.git@main
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
gcsfs==2023.4.0
transformers==4.29.2
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""
This component that segments images using a model from the Hugging Face hub.
"""
import io
import itertools
import logging
import toolz

from PIL import Image

import dask
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
import dask.dataframe as dd
import pandas as pd
import numpy as np

from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
import torch

from palette import palette

from fondant.component import TransformComponent
from fondant.logger import configure_logging

configure_logging()
logger = logging.getLogger(__name__)


def convert_to_rgb(seg):
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
color_seg = np.zeros(
(seg.shape[0], seg.shape[1], 3), dtype=np.uint8
) # height, width, 3

for label, color in enumerate(palette):
color_seg[seg == label, :] = color

color_seg = color_seg.astype(np.uint8).tobytes()

return color_seg


@dask.delayed
def load(example):
bytes = io.BytesIO(example)
image = Image.open(bytes).convert("RGB")
return image


@dask.delayed
def transform(image, processor, device):
inputs = processor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to(device)

return pixel_values, image.size


@dask.delayed
def collate(examples):
encoding = {}
encoding["pixel_values"] = torch.cat([ex[0] for ex in examples])
encoding["image_sizes"] = [ex[1] for ex in examples]
return encoding


@dask.delayed
@torch.no_grad()
def segment(batch, model, processor):
outputs = model(batch["pixel_values"])
segmentations = processor.post_process_semantic_segmentation(
outputs, target_sizes=batch["image_sizes"]
)
# turn into RGB images
segmentations = [convert_to_rgb(seg.numpy()) for seg in segmentations]

return segmentations


@dask.delayed
def flatten(lst):
return pd.Series(itertools.chain(*lst))


class SegmentImagesComponent(TransformComponent):
"""
Component that segments images using a model from the Hugging Face hub.
"""

def transform(
self,
dataframe: dd.DataFrame,
model_id: str,
batch_size: int,
) -> dd.DataFrame:
"""
Args:
dataframe: Dask dataframe
model_id: id of the model on the Hugging Face hub
batch_size: batch size to use

Returns:
Dask dataframe
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Device:", device)

processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForSemanticSegmentation.from_pretrained(model_id)

print("Length of the input dataframe:", len(dataframe))
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
print("First rows of dataframe:", dataframe.head(2))

# load and transform the images
images = dataframe["images_data"]
loaded_images = [load(image) for image in images]
transformed_images = [
transform(image, processor, device) for image in loaded_images
]

# batch images together
batches = [
collate(batch)
for batch in toolz.partition_all(batch_size, transformed_images)
]

# caption images
delayed_model = dask.delayed(model.to(device))
segmentations = [segment(batch, delayed_model, processor) for batch in batches]

# join lists into a single Dask delayed object
segmentations = flatten(segmentations)
delayed_series = dd.from_delayed(segmentations, meta=pd.Series(dtype="object"))
segmentations_df = delayed_series.to_frame(name="segmentations_data")

# add index columns
segmentations_df["id"] = dataframe["id"].reset_index(drop=True)
segmentations_df["source"] = dataframe["source"].reset_index(drop=True)

segmentations_df = segmentations_df.reset_index(drop=True)

print("Final dataframe:", segmentations_df.head(4))
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
print("Length of the final dataframe:", len(segmentations_df))

return segmentations_df


if __name__ == "__main__":
component = SegmentImagesComponent.from_file()
component.run()
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import numpy as np
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved

palette = np.asarray(
[
[0, 0, 0],
[120, 120, 120],
[180, 120, 120],
[6, 230, 230],
[80, 50, 50],
[4, 200, 3],
[120, 120, 80],
[140, 140, 140],
[204, 5, 255],
[230, 230, 230],
[4, 250, 7],
[224, 5, 255],
[235, 255, 7],
[150, 5, 61],
[120, 120, 70],
[8, 255, 51],
[255, 6, 82],
[143, 255, 140],
[204, 255, 4],
[255, 51, 7],
[204, 70, 3],
[0, 102, 200],
[61, 230, 250],
[255, 6, 51],
[11, 102, 255],
[255, 7, 71],
[255, 9, 224],
[9, 7, 230],
[220, 220, 220],
[255, 9, 92],
[112, 9, 255],
[8, 255, 214],
[7, 255, 224],
[255, 184, 6],
[10, 255, 71],
[255, 41, 10],
[7, 255, 255],
[224, 255, 8],
[102, 8, 255],
[255, 61, 6],
[255, 194, 7],
[255, 122, 8],
[0, 255, 20],
[255, 8, 41],
[255, 5, 153],
[6, 51, 255],
[235, 12, 255],
[160, 150, 20],
[0, 163, 255],
[140, 140, 140],
[250, 10, 15],
[20, 255, 0],
[31, 255, 0],
[255, 31, 0],
[255, 224, 0],
[153, 255, 0],
[0, 0, 255],
[255, 71, 0],
[0, 235, 255],
[0, 173, 255],
[31, 0, 255],
[11, 200, 200],
[255, 82, 0],
[0, 255, 245],
[0, 61, 255],
[0, 255, 112],
[0, 255, 133],
[255, 0, 0],
[255, 163, 0],
[255, 102, 0],
[194, 255, 0],
[0, 143, 255],
[51, 255, 0],
[0, 82, 255],
[0, 255, 41],
[0, 255, 173],
[10, 0, 255],
[173, 255, 0],
[0, 255, 153],
[255, 92, 0],
[255, 0, 255],
[255, 0, 245],
[255, 0, 102],
[255, 173, 0],
[255, 0, 20],
[255, 184, 184],
[0, 31, 255],
[0, 255, 61],
[0, 71, 255],
[255, 0, 204],
[0, 255, 194],
[0, 255, 82],
[0, 10, 255],
[0, 112, 255],
[51, 0, 255],
[0, 194, 255],
[0, 122, 255],
[0, 255, 163],
[255, 153, 0],
[0, 255, 10],
[255, 112, 0],
[143, 255, 0],
[82, 0, 255],
[163, 255, 0],
[255, 235, 0],
[8, 184, 170],
[133, 0, 255],
[0, 255, 92],
[184, 0, 255],
[255, 0, 31],
[0, 184, 255],
[0, 214, 255],
[255, 0, 112],
[92, 255, 0],
[0, 224, 255],
[112, 224, 255],
[70, 184, 160],
[163, 0, 255],
[153, 0, 255],
[71, 255, 0],
[255, 0, 163],
[255, 204, 0],
[255, 0, 143],
[0, 255, 235],
[133, 255, 0],
[255, 0, 235],
[245, 0, 255],
[255, 0, 122],
[255, 245, 0],
[10, 190, 212],
[214, 255, 0],
[0, 204, 255],
[20, 0, 255],
[255, 255, 0],
[0, 153, 255],
[0, 41, 255],
[0, 255, 204],
[41, 0, 255],
[41, 255, 0],
[173, 0, 255],
[0, 245, 255],
[71, 0, 255],
[122, 0, 255],
[0, 255, 184],
[0, 92, 255],
[184, 255, 0],
[0, 133, 255],
[255, 214, 0],
[25, 194, 194],
[102, 255, 0],
[92, 0, 255],
]
)
8 changes: 8 additions & 0 deletions examples/pipelines/controlnet-interior-design/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,20 @@
number_of_gpus=1,
node_pool_name="model-inference-pool",
)
segment_images_op = ComponentOp(
component_spec_path="components/segment_images/fondant_component.yaml",
arguments={
"model_id": "openmmlab/upernet-convnext-small",
"batch_size": 2,
},
)

pipeline = Pipeline(pipeline_name=pipeline_name, base_path=PipelineConfigs.BASE_PATH)

pipeline.add_op(generate_prompts_op)
pipeline.add_op(laion_retrieval_op, dependencies=generate_prompts_op)
pipeline.add_op(download_images_op, dependencies=laion_retrieval_op)
pipeline.add_op(caption_images_op, dependencies=download_images_op)
pipeline.add_op(segment_images_op, dependencies=caption_images_op)

client.compile_and_run(pipeline=pipeline)
Loading