Skip to content

Commit

Permalink
Update inference_textmap.py
Browse files Browse the repository at this point in the history
  • Loading branch information
K-Hooshanfar authored Mar 4, 2024
1 parent e71e8e2 commit fc430be
Showing 1 changed file with 42 additions and 68 deletions.
110 changes: 42 additions & 68 deletions text_detector_module/inference_textmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@
and then save the modified images where the detected text areas are highlighted to the specified output directory.
"""

import argparse
import math
import os
import os.path as osp

import cv2
import keras.backend as K
import os.path as osp
import glob
import argparse
import numpy as np
import pyclipper
import tensorflow as tf
from keras import layers, models
from keras_resnet.models import ResNet50
from shapely.geometry import Polygon

from keras_resnet.models import ResNet50
from keras import layers, models
import tensorflow as tf
import pyclipper
import keras.backend as K

def parse_args():
"""
Expand All @@ -31,7 +29,6 @@ def parse_args():
parser.add_argument('--output_dir', type=str, help='Output directory path')
return parser.parse_args()


def balanced_crossentropy_loss(args, negative_ratio=3., scale=5.):
pred, gt, mask = args
pred = pred[..., 0]
Expand All @@ -52,6 +49,7 @@ def balanced_crossentropy_loss(args, negative_ratio=3., scale=5.):

def dice_loss(args):
"""
Args:
pred: (b, h, w, 1)
gt: (b, h, w)
Expand Down Expand Up @@ -87,6 +85,7 @@ def db_loss(args):
return l1_loss_ + balanced_ce_loss_ + dice_loss_



def dbnet(input_size=640, k=50):
"""
Construct DBNet for text detection.
Expand All @@ -98,7 +97,6 @@ def dbnet(input_size=640, k=50):
Returns:
Tuple[models.Model, models.Model]: Training and prediction models.
"""

image_input = layers.Input(shape=(None, None, 3))
gt_input = layers.Input(shape=(input_size, input_size))
mask_input = layers.Input(shape=(input_size, input_size))
Expand Down Expand Up @@ -158,6 +156,7 @@ def dbnet(input_size=640, k=50):
return training_model, prediction_model



def resize_image(image, image_short_side=736):
"""
Resize the image while maintaining aspect ratio.
Expand All @@ -181,6 +180,7 @@ def resize_image(image, image_short_side=736):


def box_score_fast(bitmap, _box):
# Compute the mean score of the region enclosed by the box.
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
Expand Down Expand Up @@ -267,65 +267,39 @@ def polygons_from_bitmap(pred, bitmap, dest_width, dest_height, max_candidates=1
# Parse command line arguments for input and output directories.
args = parse_args()

# Define mean values for image normalization.
mean = np.array([103.939, 116.779, 123.68])

# Initialize the DBNet model.
_, model = dbnet()

# Load pre-trained weights for the model.
model.load_weights('./model.h5', by_name=True, skip_mismatch=True)

input_dir = args.input_dir
output_dir = args.output_dir

for image_path in glob.glob(osp.join(input_dir, '*.png')) + glob.glob(osp.join(input_dir, '*.jpg')):
image = cv2.imread(image_path)
src_image = image.copy()
h, w = image.shape[:2]
image = resize_image(image)
image = image.astype(np.float32)
image -= mean
image_input = np.expand_dims(image, axis=0)
p = model.predict(image_input)[0]
bitmap = p > 0.3
boxes, scores = polygons_from_bitmap(p, bitmap, w, h, box_thresh=0.5)

# Create an empty mask image
mask = np.zeros(src_image.shape[:2], dtype=np.uint8)

# Draw the contours of the green boxes on the mask
for box in boxes:
cv2.drawContours(mask, [np.array(box, dtype=np.int32)], -1, 255, thickness=cv2.FILLED)

# Set pixels outside the contours to zero
src_image[np.where(mask == 0)] = 0

# Save the modified image with zeros outside the contours
image_fname = osp.split(image_path)[-1]
output_path = osp.join(output_dir, image_fname)
cv2.imwrite(output_path, src_image)

# Set the input and output directories.
input_dir = args.input_dir
output_dir = args.output_dir

# Walk through the directory tree of the input directory.
for root, dirs, files in os.walk(input_dir):
for name in files:
# Process only PNG files (feel free to change this check or add other formats).
if name.endswith(".png"):
image_path = osp.join(root, name)
image = cv2.imread(image_path)

# Make a copy of the original image for later use.
src_image = image.copy()
h, w = image.shape[:2]

# Resize the image maintaining aspect ratio.
image = resize_image(image)
# Convert image to float32 type.
image = image.astype(np.float32)
# Normalize the image by subtracting mean values.
image -= mean

# Add batch dimension to the image.
image_input = np.expand_dims(image, axis=0)
# Predict text regions in the image.
p = model.predict(image_input)[0]
# Thresholding to create a binary map.
bitmap = p > 0.3

# Extract polygons around text regions and their scores.
boxes, scores = polygons_from_bitmap(p, bitmap, w, h, box_thresh=0.5)

# Create an empty mask image of the same size as the source image.
mask = np.zeros(src_image.shape[:2], dtype=np.uint8)

# Draw the contours on the mask based on the boxes.
for box in boxes:
cv2.drawContours(mask, [np.array(box, dtype=np.int32)], -1, 255, thickness=cv2.FILLED)

# Set pixels outside the contours to zero in the source image.
src_image[np.where(mask == 0)] = 0

# Create the output directory structure if it doesn't exist.
relative_root = osp.relpath(root, input_dir)
output_folder = osp.join(output_dir, relative_root)
if not osp.exists(output_folder):
os.makedirs(output_folder)

# Save the modified image to the output directory.
output_path = osp.join(output_folder, name)
cv2.imwrite(output_path, src_image)

0 comments on commit fc430be

Please sign in to comment.