diff --git a/decimer_segmentation/complete_structure.py b/decimer_segmentation/complete_structure.py index 24ea12f..79aa576 100644 --- a/decimer_segmentation/complete_structure.py +++ b/decimer_segmentation/complete_structure.py @@ -366,7 +366,10 @@ def get_neighbour_pixels( return neighbour_pixels -def detect_horizontal_and_vertical_lines(image: np.ndarray) -> np.ndarray: +def detect_horizontal_and_vertical_lines( + image: np.ndarray, + average_depiction_size: Tuple[int, int] +) -> np.ndarray: """ This function takes an image and returns a binary mask that labels the pixels that are part of long horizontal or vertical lines. [Definition of long: 1/5 of the @@ -382,19 +385,19 @@ def detect_horizontal_and_vertical_lines(image: np.ndarray) -> np.ndarray: """ binarised_im = ~image * 255 binarised_im = binarised_im.astype("uint8") + + structure_height, structure_width = average_depiction_size - horizontal_kernel_size = int(binarised_im.shape[1] / 7) horizontal_kernel = cv2.getStructuringElement( - cv2.MORPH_RECT, (horizontal_kernel_size, 1) + cv2.MORPH_RECT, (structure_width, 1) ) horizontal_mask = cv2.morphologyEx( binarised_im, cv2.MORPH_OPEN, horizontal_kernel, iterations=2 ) horizontal_mask = horizontal_mask == 255 - vertical_kernel_size = int(binarised_im.shape[0] / 7) vertical_kernel = cv2.getStructuringElement( - cv2.MORPH_RECT, (1, vertical_kernel_size) + cv2.MORPH_RECT, (1, structure_height) ) vertical_mask = cv2.morphologyEx( binarised_im, cv2.MORPH_OPEN, vertical_kernel, iterations=2 @@ -472,13 +475,30 @@ def expansion_coordination( def complete_structure_mask( - image_array: np.array, mask_array: np.array, debug=False + image_array: np.array, + mask_array: np.array, + average_depiction_size: Tuple[int, int], + debug=False ) -> np.array: """ - This funtion takes an image (array) and an array containing the masks (shape: + This function takes an image (np.array) and an array containing the masks (shape: x,y,n where n is the amount of masks and x and y are the 
pixel coordinates). + Additionally, it takes the average depiction size of the structures in the image + which is used to define the kernel size for the vertical and horizontal line + detection for the exclusion masks. The exclusion mask is used to exclude pixels + from the mask expansion to avoid including whole tables. It detects objects on the contours of the mask and expands it until it frames the - complete object in the image. It returns the expanded mask array""" + complete object in the image. It returns the expanded mask array + + Args: + image_array (np.array): input image + mask_array (np.array): shape: y, x, n where n is the amount of masks + average_depiction_size (Tuple[int, int]): height, width + debug (bool, optional): More verbose if True. Defaults to False. + + Returns: + np.array: expanded mask array + """ if mask_array.size != 0: # Binarization of input image @@ -498,7 +518,8 @@ def complete_structure_mask( split_mask_arrays = np.array( [mask_array[:, :, index] for index in range(mask_array.shape[2])] ) - exclusion_mask = detect_horizontal_and_vertical_lines(blurred_image_array) + exclusion_mask = detect_horizontal_and_vertical_lines(blurred_image_array, + average_depiction_size) # Run expansion the expansion image_repeat = itertools.repeat(blurred_image_array, mask_array.shape[2]) exclusion_mask_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2]) diff --git a/decimer_segmentation/decimer_segmentation.py b/decimer_segmentation/decimer_segmentation.py index 673517d..73612bf 100644 --- a/decimer_segmentation/decimer_segmentation.py +++ b/decimer_segmentation/decimer_segmentation.py @@ -98,7 +98,6 @@ def segment_chemical_structures( if not expand: masks, bboxes, _ = get_mrcnn_results(image) else: - average_height, average_width = determine_average_depiction_size(bboxes) masks = get_expanded_masks(image) segments, bboxes = apply_masks(image, masks) @@ -227,9 +226,12 @@ def get_expanded_masks(image: np.array) -> np.array: np.array: 
expanded masks (shape: (h, w, num_masks)) """ # Structure detection with MRCNN - masks, _, _ = get_mrcnn_results(image) + masks, bboxes, _ = get_mrcnn_results(image) + size = determine_average_depiction_size(bboxes) # Mask expansion - expanded_masks = complete_structure_mask(image_array=image, mask_array=masks) + expanded_masks = complete_structure_mask(image_array=image, + mask_array=masks, + average_depiction_size=size,) return expanded_masks