Skip to content

Commit 96acba1

Browse files
committed
comments
1 parent 08ada4d commit 96acba1

File tree

5 files changed

+101
-25
lines changed

5 files changed

+101
-25
lines changed

exps/cbad/demo_processing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,27 @@ def baseline_extraction(model_dir: str,
4040
os.makedirs(drawing_dir)
4141

4242
with tf.Session(config=config):
43+
# Load the model
4344
m = LoadedModel(model_dir, predict_mode='filename_original_shape')
4445
for filename in tqdm(filenames_to_process, desc='Prediction'):
46+
# Inference
4547
prediction = m.predict(filename)
48+
# Take the first element of the 'probs' dictionary (batch size = 1)
4649
probs = prediction['probs'][0]
4750
original_shape = probs.shape
4851

52+
# The baselines probs are on the second channel
4953
baseline_probs = probs[:, :, 1]
5054
contours, _ = line_extraction_v1(baseline_probs, low_threshold=0.2, high_threshold=0.4, sigma=1.5)
5155

5256
basename = os.path.basename(filename).split('.')[0]
5357

58+
# Compute the ratio to save the coordinates in the original image coordinates reference.
5459
ratio = (original_shape[0] / probs.shape[0], original_shape[1] / probs.shape[1])
5560
xml_filename = os.path.join(output_dir, basename + '.xml')
5661
page_object = PAGE.save_baselines(xml_filename, contours, ratio, predictions_shape=probs.shape[:2])
5762

63+
# If specified, saves the images with the annotated baslines
5864
if draw_extractions:
5965
image = imread(filename)
6066
page_object.draw_baselines(image, color=(255, 0, 0), thickness=5)

exps/cbad/evaluation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ def eval_fn(input_dir: str,
2020
jar_tool_path: str=CBAD_JAR,
2121
masks_dir: str=None) -> dict:
2222
"""
23+
Evaluates a model against the selected set ('groundtruth_dir' contains XML files)
2324
2425
:param input_dir: Input directory containing probability maps (.npy)
2526
:param groudtruth_dir: directory containg XML groundtruths
2627
:param output_dir: output directory for results
2728
:param post_process_params: parameters form post processing of probability maps
29+
:param channel_baselines: the baseline class chanel
2830
:param jar_tool_path: path to cBAD evaluation tool (.jar file)
2931
:param masks_dir: optional, directory where binary masks of the page are stored (.png)
3032
:return:
@@ -82,7 +84,7 @@ def eval_fn(input_dir: str,
8284
}
8385

8486

85-
def parse_score_txt(score_txt, output_csv):
87+
def parse_score_txt(score_txt: str, output_csv: str):
8688
lines = score_txt.splitlines()
8789
header_ind = next((i for i, l in enumerate(lines)
8890
if l == '#P value, #R value, #F_1 value, #TruthFileName, #HypoFileName'))

exps/cbad/make_cbad.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,14 @@ def generate_cbad_dataset(downloading_dir: str, masks_dir: str):
3333
for dir_tuple in dirs_tuple:
3434
input_dir, output_dir = dir_tuple
3535
os.makedirs(output_dir, exist_ok=True)
36+
# For each set create the folder with the annotated data
3637
cbad_set_generator(input_dir=input_dir,
3738
output_dir=output_dir,
3839
img_size=2e6,
3940
draw_baselines=True,
4041
draw_endpoints=False)
4142

43+
# Split the 'official' train set into training and validation set
4244
if 'train' in output_dir:
4345
print('Make eval set from the given training data (0.15/0.85 eval/train)')
4446
csv_filename = os.path.join(output_dir, 'set_data.csv')

exps/cbad/process.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ def prediction_fn(model_dir: str,
2222
"""
2323
Given a model directory this function will load the model and apply it to the files (.jpg, .png) found in input_dir.
2424
The predictions will be saved in output_dir as .npy files (values ranging [0,255])
25+
2526
:param model_dir: Directory containing the saved model
2627
:param input_dir: input directory where the images to predict are
2728
:param output_dir: output directory to save the predictions (probability images)
29+
:param config: ConfigProto object to pass to the session in order to define which GPU to use
2830
:return:
2931
"""
3032
if not output_dir:
@@ -50,15 +52,17 @@ def cbad_post_processing_fn(probs: np.array,
5052
vertical_maxima: bool=False,
5153
output_basename=None) -> Tuple[List[np.ndarray], np.ndarray]:
5254
"""
55+
Given a probability map, returns the contour of lines and the corresponding mask.
56+
Saves the results in .pkl file if requested.
5357
5458
:param probs: output of the model (probabilities) in range [0, 255]
5559
:param baseline_chanel: channel where the baseline class is detected
56-
:param sigma:
57-
:param low_threshold:
58-
:param high_threshold:
59-
:param filter_width:
60-
:param output_basename:
61-
:param vertical_maxima:
60+
:param sigma: sigma value for gaussian filtering
61+
:param low_threshold: hysteresis low threshold
62+
:param high_threshold: hysteresis high threshold
63+
:param filter_width: percentage of the image width to filter out lines that are close to borders (default 0.0)
64+
:param output_basename: name of file to save the intermediaty result as .pkl file.
65+
:param vertical_maxima: set to True to use vertical local maxima as candidates for the hysteresis thresholding
6266
:return: contours, mask
6367
WARNING : contours IN OPENCV format List[np.ndarray(n_points, 1, (x,y))]
6468
"""
@@ -76,6 +80,17 @@ def line_extraction_v1(probs: np.ndarray,
7680
sigma: float=0.0,
7781
filter_width: float=0.00,
7882
vertical_maxima: bool=False) -> Tuple[List[np.ndarray], np.ndarray]:
83+
"""
84+
Given a probability map, returns the contour of lines and the corresponding mask
85+
86+
:param probs: probability map (numpy array)
87+
:param low_threshold: hysteresis low threshold
88+
:param high_threshold: hysteresis high threshold
89+
:param sigma: sigma value for gaussian filtering
90+
:param filter_width: percentage of the image width to filter out lines that are close to borders (default 0.0)
91+
:param vertical_maxima: set to True to use vertical local maxima as candidates for the hysteresis thresholding
92+
:return:
93+
"""
7994
# Smooth
8095
probs2 = cleaning_probs(probs, sigma=sigma)
8196

@@ -129,6 +144,7 @@ def extract_lines(npy_filename: str,
129144
debug: bool=False):
130145
"""
131146
From the prediction files (probs) (.npy) finds and extracts the lines into PAGE-XML format.
147+
132148
:param npy_filename: filename of saved predictions (probs) in range (0,255)
133149
:param output_dir: output direcoty to save the xml files
134150
:param original_shape: shpae of the original input image (to rescale the extracted lines if necessary)

exps/cbad/utils.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from tqdm import tqdm
1313
from dh_segment.io import PAGE
1414

15+
# Constant definitions
1516
TARGET_HEIGHT = 1100
1617
DRAWING_COLOR_BASELINES = (255, 0, 0)
1718
DRAWING_COLOR_LINES = (0, 255, 0)
@@ -22,6 +23,12 @@
2223

2324

2425
def get_page_filename(image_filename: str) -> str:
26+
"""
27+
Given an path to a .jpg or .png file, get the corresponding .xml file.
28+
29+
:param image_filename: filename of the image
30+
:return: the filename of the corresponding .xml file, raises exception if .xml file does not exist
31+
"""
2532
page_filename = os.path.join(os.path.dirname(image_filename),
2633
'page',
2734
'{}.xml'.format(os.path.basename(image_filename)[:-4]))
@@ -33,13 +40,31 @@ def get_page_filename(image_filename: str) -> str:
3340

3441

3542
def get_image_label_basename(image_filename: str) -> str:
43+
"""
44+
Creates a new filename composed of the begining of the folder/collection (ex. EPFL, ABP) and the original filename
45+
46+
:param image_filename: path of the image filename
47+
:return:
48+
"""
3649
# Get acronym followed by name of file
3750
directory, basename = os.path.split(image_filename)
3851
acronym = directory.split(os.path.sep)[-1].split('_')[0]
3952
return '{}_{}'.format(acronym, basename.split('.')[0])
4053

4154

42-
def save_and_resize(img: np.array, filename: str, size=None, nearest: bool=False) -> None:
55+
def save_and_resize(img: np.array,
56+
filename: str,
57+
size=None,
58+
nearest: bool=False) -> None:
59+
"""
60+
Resizes the image if necessary and saves it. The resizing will keep the image ratio
61+
62+
:param img: the image to resize and save (numpy array)
63+
:param filename: filename of the saved image
64+
:param size: size of the image after resizing (in pixels). The ratio of the original image will be kept
65+
:param nearest: whether to use nearest interpolation method (default to False)
66+
:return:
67+
"""
4368
if size is not None:
4469
h, w = img.shape[:2]
4570
ratio = float(np.sqrt(size/(h*w)))
@@ -59,30 +84,36 @@ def annotate_one_page(image_filename: str,
5984
baseline_thickness: float=0.2,
6085
diameter_endpoint: int=20) -> Tuple[str, str]:
6186
"""
87+
Creates an annotated mask and corresponding original image and saves it in 'labels' and 'images' folders.
88+
Also copies the corresponding .xml file into 'gt' folder.
6289
63-
:param image_filename:
64-
:param output_dir:
90+
:param image_filename: filename of the image to process
91+
:param output_dir: directory to output the annotated label image
6592
:param size: Size of the resized image (# pixels)
6693
:param draw_baselines: Draws the baselines (boolean)
6794
:param draw_lines: Draws the polygon's lines (boolean)
6895
:param draw_endpoints: Predict beginning and end of baselines (True, False)
6996
:param baseline_thickness: Thickness of annotated baseline (percentage of the line's height)
7097
:param diameter_endpoint: Diameter of annotated start/end points
71-
:return:
98+
:return: (output_image_path, output_label_path)
7299
"""
73100

74101
page_filename = get_page_filename(image_filename)
102+
# Parse xml file and get TextLines
75103
page = PAGE.parse_file(page_filename)
76104
text_lines = [tl for tr in page.text_regions for tl in tr.text_lines]
77105
img = imread(image_filename, pilmode='RGB')
106+
# Create empty mask
78107
gt = np.zeros_like(img)
79108

80109
if text_lines:
81110
if draw_baselines:
82111
# Thickness : should be a percentage of the line height, for example 0.2
112+
# First, get the mean line height.
83113
mean_line_height, _, _ = _compute_statistics_line_height(page)
84114
absolute_baseline_thickness = int(max(gt.shape[0]*0.002, baseline_thickness*mean_line_height))
85115

116+
# Draw the baselines
86117
gt_baselines = np.zeros_like(img[:, :, 0])
87118
gt_baselines = cv2.polylines(gt_baselines,
88119
[PAGE.Point.list_to_cv2poly(tl.baseline) for tl in
@@ -92,6 +123,7 @@ def annotate_one_page(image_filename: str,
92123
gt[:, :, np.argmax(DRAWING_COLOR_BASELINES)] = gt_baselines
93124

94125
if draw_lines:
126+
# Draw the lines
95127
gt_lines = np.zeros_like(img[:, :, 0])
96128
for tl in text_lines:
97129
gt_lines = cv2.fillPoly(gt_lines,
@@ -100,6 +132,7 @@ def annotate_one_page(image_filename: str,
100132
gt[:, :, np.argmax(DRAWING_COLOR_LINES)] = gt_lines
101133

102134
if draw_endpoints:
135+
# Draw endpoints of baselines
103136
gt_points = np.zeros_like(img[:, :, 0])
104137
for tl in text_lines:
105138
try:
@@ -113,11 +146,14 @@ def annotate_one_page(image_filename: str,
113146
print('Length of baseline is {}'.format(len(tl.baseline)))
114147
gt[:, :, np.argmax(DRAWING_COLOR_POINTS)] = gt_points
115148

149+
# Make output filenames
116150
image_label_basename = get_image_label_basename(image_filename)
117151
output_image_path = os.path.join(output_dir, 'images', '{}.jpg'.format(image_label_basename))
118152
output_label_path = os.path.join(output_dir, 'labels', '{}.png'.format(image_label_basename))
153+
# Resize (if necessary) and save image and label
119154
save_and_resize(img, output_image_path, size=size)
120155
save_and_resize(gt, output_label_path, size=size, nearest=True)
156+
# Copy XML file to 'gt' folder
121157
shutil.copy(page_filename, os.path.join(output_dir, 'gt', '{}.xml'.format(image_label_basename)))
122158

123159
return os.path.abspath(output_image_path), os.path.abspath(output_label_path)
@@ -133,6 +169,7 @@ def cbad_set_generator(input_dir: str,
133169
draw_endpoints: bool=False,
134170
circle_thickness: int =20) -> None:
135171
"""
172+
Creates a set with 'images', 'labels', 'gt' folders, classes.txt file and .csv data
136173
137174
:param input_dir: Input directory containing images and PAGE files
138175
:param output_dir: Output directory to save images and labels
@@ -198,6 +235,12 @@ def cbad_set_generator(input_dir: str,
198235

199236

200237
def split_set_for_eval(csv_filename: str) -> None:
238+
"""
239+
Splits set into two sets (0.15 and 0.85).
240+
241+
:param csv_filename: path to csv file containing in each row image_filename,label_filename
242+
:return:
243+
"""
201244

202245
df_data = pd.read_csv(csv_filename, header=None)
203246

@@ -212,25 +255,26 @@ def split_set_for_eval(csv_filename: str) -> None:
212255
df_train.to_csv(os.path.join(saving_dir, 'train_data.csv'), header=False, index=False, encoding='utf8')
213256

214257

215-
def draw_lines_fn(xml_filename: str, output_dir: str):
216-
"""
217-
GIven an XML PAGE file, draws the corresponding lines in the original image.
218-
:param xml_filename:
219-
:param output_dir:
220-
:return:
221-
"""
222-
basename = os.path.basename(xml_filename).split('.')[0]
223-
generated_page = PAGE.parse_file(xml_filename)
224-
drawing_img = generated_page.image_filename
225-
generated_page.draw_baselines(drawing_img, color=(0, 0, 255))
226-
imsave(os.path.join(output_dir, '{}.jpg'.format(basename)), drawing_img)
258+
# def draw_lines_fn(xml_filename: str, output_dir: str):
259+
# """
260+
# Given an XML PAGE file, draws the corresponding lines in the original image.
261+
#
262+
# :param xml_filename:
263+
# :param output_dir:
264+
# :return:
265+
# """
266+
# basename = os.path.basename(xml_filename).split('.')[0]
267+
# generated_page = PAGE.parse_file(xml_filename)
268+
# drawing_img = generated_page.image_filename
269+
# generated_page.draw_baselines(drawing_img, color=(0, 0, 255))
270+
# imsave(os.path.join(output_dir, '{}.jpg'.format(basename)), drawing_img)
227271

228272

229273
def _compute_statistics_line_height(page_class: PAGE.Page, verbose: bool=False) -> Tuple[float, float, float]:
230274
"""
231-
Function to compute mean and std of line height among a page.
275+
Function to compute mean and std of line height in a page.
232276
233-
:param page_class: json Page
277+
:param page_class: PAGE.Page object
234278
:param verbose: either to print computational info or not
235279
:return: tuple (mean, standard deviation, median)
236280
"""
@@ -311,6 +355,12 @@ def update_to(b: int=1, bsize: int=1, tsize: int=None):
311355

312356

313357
def cbad_download(output_dir: str):
358+
"""
359+
Download BAD-READ dataset.
360+
361+
:param output_dir: folder where to download the data
362+
:return:
363+
"""
314364
os.makedirs(output_dir, exist_ok=True)
315365
zip_filename = os.path.join(output_dir, 'cbad-icdar17.zip')
316366

0 commit comments

Comments
 (0)