WIP: avoid unnecessary segmentation #783

Open · wants to merge 10 commits into main

211 changes: 211 additions & 0 deletions experiments/visualizing_segments.py
@@ -0,0 +1,211 @@
from loguru import logger
from PIL import Image
from skimage.metrics import structural_similarity as ssim
import numpy as np

from openadapt import vision, adapters


def extract_difference_image(
    new_image: Image.Image,
    old_image: Image.Image,
    tolerance: float = 0.05,
) -> Image.Image:
    """Extract the portion of the new image that differs from the old image.

    Args:
        new_image: The new image as a PIL Image object.
        old_image: The old image as a PIL Image object.
        tolerance: Fraction of the full 0-255 intensity range by which any channel
            must differ for a pixel to count as changed (default is 0.05).

    Returns:
        A PIL Image object containing only the changed pixels; all others are black.
    """
    new_image_np = np.array(new_image)
    old_image_np = np.array(old_image)

    # Compute the absolute per-channel difference; cast to int first so that
    # uint8 subtraction does not wrap around
    diff = np.abs(new_image_np.astype(int) - old_image_np.astype(int))

    # Create a mask for the regions where the difference is above the tolerance
    mask = np.any(diff > (255 * tolerance), axis=-1)

    # Initialize an array for the segmented image
    segmented_image_np = np.zeros_like(new_image_np)

    # Set the pixels that are different in the new image
    segmented_image_np[mask] = new_image_np[mask]

    # Convert the numpy array back to an image
    return Image.fromarray(segmented_image_np)
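
For intuition: with the default tolerance of 0.05, a channel must change by more than 255 × 0.05 ≈ 12.75 intensity levels to register as different. A minimal sketch (values are illustrative; assumes the cast-to-int fix above):

# Two 1x1 RGB images differing by 20 intensity levels per channel
a = Image.fromarray(np.full((1, 1, 3), 100, dtype=np.uint8))
b = Image.fromarray(np.full((1, 1, 3), 120, dtype=np.uint8))
diff = extract_difference_image(b, a)  # 20 > 12.75, so the pixel is kept
assert tuple(np.array(diff)[0, 0]) == (120, 120, 120)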


def combine_images_with_masks(
    image_1: Image.Image,
    difference_image: Image.Image,
    old_masks: list[np.ndarray],
    new_masks: list[np.ndarray],
) -> Image.Image:
    """Combine image_1 and difference_image using the masks.

    Args:
        image_1: The original image as a PIL Image object.
        difference_image: The difference image as a PIL Image object.
        old_masks: List of numpy arrays representing the masks from the original image.
        new_masks: List of numpy arrays representing the masks from the difference image.

    Returns:
        A PIL Image object representing the combined image.
    """
    image_1_np = np.array(image_1)
    difference_image_np = np.array(difference_image)

    # Create an empty canvas with the same dimensions and mode as image_1
    combined_image_np = np.zeros_like(image_1_np)

    def masks_overlap(mask1, mask2):
        """Check if two masks overlap."""
        return np.any(np.logical_and(mask1, mask2))

    # Apply old masks to the combined image where there is no overlap with new masks
    for old_mask in old_masks:
        if not any(masks_overlap(old_mask, new_mask) for new_mask in new_masks):
            combined_image_np[old_mask] = image_1_np[old_mask]

    # Apply new masks to the combined image
    for new_mask in new_masks:
        combined_image_np[new_mask] = difference_image_np[new_mask]

    # Fill in remaining pixels from image_1 where there are no masks
    unfilled = (combined_image_np == 0).all(axis=-1)
    combined_image_np[unfilled] = image_1_np[unfilled]

    # Convert the numpy array back to an image
    return Image.fromarray(combined_image_np)


def find_matching_sections_ssim(
    image_1: Image.Image,
    image_2: Image.Image,
    block_size: int = 50,
    threshold: float = 0.9,
) -> Image.Image:
    """Find and visualize matching sections between two images using SSIM.

    Args:
        image_1: The first image as a PIL Image object.
        image_2: The second image as a PIL Image object.
        block_size: The size of the blocks compared in the SSIM calculation. Default is 50.
        threshold: The SSIM score above which blocks count as matching. Default is 0.9.

    Returns:
        A PIL Image object with matching sections highlighted.
    """
    # Convert images to grayscale
    image_1_gray = np.array(image_1.convert("L"))
    image_2_gray = np.array(image_2.convert("L"))

    # Dimensions of the images
    height, width = image_1_gray.shape

    # Create an empty image to visualize matches
    matching_image = np.zeros_like(image_1_gray)

    # Iterate over the image in blocks
    for y in range(0, height, block_size):
        for x in range(0, width, block_size):
            # Define the block region
            block_1 = image_1_gray[y : y + block_size, x : x + block_size]
            block_2 = image_2_gray[y : y + block_size, x : x + block_size]

            # Compare only equal-shaped blocks at least as large as SSIM's
            # default 7x7 window (edge blocks may be smaller)
            if block_1.shape == block_2.shape and min(block_1.shape) >= 7:
                # Compute SSIM for the current block
                score, _ = ssim(block_1, block_2, full=True)

                # Highlight matching sections
                if score >= threshold:
                    matching_image[y : y + block_size, x : x + block_size] = 255

    # Create an overlay to highlight matching regions on the original image
    overlay = np.zeros_like(np.array(image_1), dtype=np.uint8)

    # Apply the overlay to the matching regions
    for c in range(3):  # For each color channel
        overlay[:, :, c] = np.where(
            matching_image == 255, np.array(image_1)[:, :, c], 0
        )

    # For RGBA images, set the alpha channel to 255 (fully opaque) for matching sections
    if image_1.mode == "RGBA":
        overlay[:, :, 3] = np.where(matching_image == 255, 255, 0)

    # Convert back to PIL Image
    return Image.fromarray(overlay)


def visualize(image_1: Image.Image, image_2: Image.Image) -> None:
    """Visualize matching sections, difference sections, and combined images between two images.

    Args:
        image_1: The first image as a PIL Image object.
        image_2: The second image as a PIL Image object.

    Returns:
        None
    """
    try:
        matching_image = find_matching_sections_ssim(image_1, image_2)

        difference_image = extract_difference_image(image_2, image_1, tolerance=0.05)

        old_masks = vision.get_masks_from_segmented_image(image_1)
        new_masks = vision.get_masks_from_segmented_image(difference_image)

        combined_image = combine_images_with_masks(
            image_1, difference_image, old_masks, new_masks
        )

        segmentation_adapter = adapters.get_default_segmentation_adapter()
        ref_segmented_image = segmentation_adapter.fetch_segmented_image(image_1)
        new_segmented_image = segmentation_adapter.fetch_segmented_image(image_2)
        matching_image_segment = segmentation_adapter.fetch_segmented_image(
            matching_image
        )
        non_matching_image_segment = segmentation_adapter.fetch_segmented_image(
            difference_image
        )
        combined_image_segment = segmentation_adapter.fetch_segmented_image(
            combined_image
        )

        images = [
            image_1,
            ref_segmented_image,
            image_2,
            new_segmented_image,
            matching_image,
            matching_image_segment,
            difference_image,
            non_matching_image_segment,
            combined_image,
            combined_image_segment,
        ]

        for image in images:
            image.show()

    except Exception as e:
        logger.error(f"An error occurred: {e}")


# Example usage
if __name__ == "__main__":
    img_2 = Image.open("../experiments/winCalNew.png")
    img_1 = Image.open("../experiments/winCalOld.png")
    visualize(img_1, img_2)
Binary file added experiments/winCalNew.png
Binary file added experiments/winCalOld.png
121 changes: 120 additions & 1 deletion openadapt/strategies/visual.py
@@ -374,6 +374,102 @@ def find_similar_image_segmentation(
    return similar_segmentation, similar_segmentation_diff


Review comment (Member): Please remove extra newline

def combine_segmentations(
    difference_image: Image.Image,
    previous_segmentation: Segmentation,
    new_descriptions: list[str],
    new_masked_images: list[Image.Image],
    new_masks: list[np.ndarray],
) -> Segmentation:
    """Combine the previous segmentation with the new segmentation of the differences.

    Args:
        difference_image: The difference image found in similar segmentation.
        previous_segmentation: The previous segmentation containing unchanged segments.
        new_descriptions: Descriptions of the new segments from the difference image.
        new_masked_images: Masked images of the new segments from the difference image.
        new_masks: Masks of the new segments.

    Returns:
        Segmentation: A new segmentation combining both previous and new segments.
    """
    image_1_np = np.array(previous_segmentation.image)
    difference_image_np = np.array(difference_image)

Review comment (Member): Please add function docstring including args and return values
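
If the request refers to the masks_overlap helper defined below (an assumption on our part), a docstring along these lines would address it (a sketch, not the author's wording):

def masks_overlap(mask1: np.ndarray, mask2: np.ndarray) -> bool:
    """Check whether two boolean masks share any pixels.

    Args:
        mask1: First boolean mask.
        mask2: Second boolean mask.

    Returns:
        True if any pixel is set in both masks, False otherwise.
    """
    # bool() unwraps the numpy scalar returned by np.any
    return bool(np.any(np.logical_and(mask1, mask2)))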

    # Create an empty canvas with the same dimensions and mode as image_1
    combined_image_np = np.zeros_like(image_1_np)

    # Ensure difference_image_np is 3 channels
    if difference_image_np.ndim == 2:  # Grayscale image
        difference_image_np = np.stack((difference_image_np,) * 3, axis=-1)

    def masks_overlap(mask1, mask2):
        """Check if two masks overlap."""
        return np.any(np.logical_and(mask1, mask2))

    # Calculate the bounding boxes and centroids for the new segments
    new_bounding_boxes, new_centroids = vision.calculate_bounding_boxes(new_masks)

    segmentation_adapter = adapters.get_default_segmentation_adapter()
    segmented_prev_image = segmentation_adapter.fetch_segmented_image(
        previous_segmentation.image
    )
    previous_masks = vision.get_masks_from_segmented_image(segmented_prev_image)

    # Filter out overlapping previous segments
    filtered_previous_masked_images = []
    # filtered_previous_descriptions = []
    filtered_previous_bounding_boxes = []
    filtered_previous_centroids = []
    for idx, prev_mask in enumerate(previous_masks):
        if not any(masks_overlap(prev_mask, new_mask) for new_mask in new_masks):
            # Apply previous masks to the combined image where there is no
            # overlap with new masks
            combined_image_np[prev_mask] = image_1_np[prev_mask]
            filtered_previous_masked_images.append(
                previous_segmentation.masked_images[idx]
            )
            # filtered_previous_descriptions.append(
            #     previous_segmentation.descriptions[idx]
            # )
            filtered_previous_bounding_boxes.append(
                previous_segmentation.bounding_boxes[idx]
            )
            filtered_previous_centroids.append(previous_segmentation.centroids[idx])

    # Apply new masks to the combined image
    for new_mask in new_masks:
        combined_image_np[new_mask] = difference_image_np[new_mask]

    # Fill in remaining pixels from image_1 where there are no masks
    unfilled = (combined_image_np == 0).all(axis=-1)
    combined_image_np[unfilled] = image_1_np[unfilled]

    # Combine filtered previous segments with new segments
    combined_masked_images = filtered_previous_masked_images + new_masked_images
    # combined_descriptions = filtered_previous_descriptions + new_descriptions
    combined_bounding_boxes = filtered_previous_bounding_boxes + new_bounding_boxes
    combined_centroids = filtered_previous_centroids + new_centroids

    # Convert the numpy array back to an image
    new_image = Image.fromarray(combined_image_np)

    marked_image = plotting.get_marked_image(
        new_image,
        new_masks,
    )

    return Segmentation(
        new_image,
        marked_image,
        combined_masked_images,
        new_descriptions,
        combined_bounding_boxes,
        combined_centroids,
    )


def get_window_segmentation(
    action_event: models.ActionEvent,
    exceptions: list[Exception] | None = None,

@@ -402,7 +498,30 @@ def get_window_segmentation(
        # TODO XXX: create copy of similar_segmentation, but overwrite with segments of
        # regions of new image where segments of similar_segmentation overlap non-zero
        # regions of similar_segmentation_diff
        logger.info("Found similar_segmentation")
        similar_segmentation_diff_image = Image.fromarray(similar_segmentation_diff)
        segmentation_adapter = adapters.get_default_segmentation_adapter()
        segmented_diff_image = segmentation_adapter.fetch_segmented_image(
            similar_segmentation_diff_image
        )
        new_masks = vision.get_masks_from_segmented_image(segmented_diff_image)
        new_masked_images = vision.extract_masked_images(
            similar_segmentation_diff_image, new_masks
        )
        new_descriptions = prompt_for_descriptions(
            similar_segmentation_diff_image,
            new_masked_images,
            action_event.active_segment_description,
            exceptions,
        )
        updated_segmentation = combine_segmentations(
            similar_segmentation_diff_image,
            similar_segmentation,
            new_descriptions,
            new_masked_images,
            new_masks,
        )
        return updated_segmentation

    segmentation_adapter = adapters.get_default_segmentation_adapter()
    segmented_image = segmentation_adapter.fetch_segmented_image(original_image)
15 changes: 11 additions & 4 deletions openadapt/vision.py
@@ -237,11 +237,18 @@ def extract_masked_images(
        cropped_mask = mask[rmin : rmax + 1, cmin : cmax + 1]
        cropped_image = original_image_np[rmin : rmax + 1, cmin : cmax + 1]

        # Ensure the cropped image has the correct shape
        if cropped_image.ndim == 2:  # Grayscale image
            cropped_image = cropped_image[:, :, None]
        elif cropped_image.shape[2] != 1:  # Color image
            cropped_image = cropped_image[:, :, :3]  # Keep RGB channels only

        # Ensure the mask has the correct shape
        reshaped_mask = cropped_mask[:, :, None]

        # Apply the mask
        masked_image = np.where(reshaped_mask, cropped_image, 0).astype(np.uint8)
        masked_images.append(Image.fromarray(masked_image.squeeze()))

    logger.info(f"{len(masked_images)=}")
    return masked_images
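
A quick sanity check of the new grayscale path (a sketch; shapes and values are illustrative, and it assumes the loop applies to every mask as shown):

# 4x4 grayscale image ("L" mode) with a 2x2 region selected
img = Image.fromarray(np.full((4, 4), 200, dtype=np.uint8))
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True
masked = vision.extract_masked_images(img, [mask])[0]
assert masked.size == (2, 2)  # cropped to the mask's bounding box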
2 changes: 1 addition & 1 deletion openadapt/window/__init__.py
@@ -33,7 +33,7 @@ def get_active_window_data(
"""
state = get_active_window_state(include_window_data)
if not state:
return None
return {}
title = state["title"]
left = state["left"]
top = state["top"]
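
This hunk changes the no-state return value from None to an empty dict, so get_active_window_data consistently returns a dict. Callers that test truthiness keep working either way; a minimal sketch (caller code is hypothetical):

data = get_active_window_data(include_window_data=False)
if not data:  # matches both the old None and the new {}
    logger.warning("No active window state available; skipping.")
else:
    logger.info(f"active window title: {data['title']}")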