Skip to content

Commit 30444a0

Browse files
committed
Update to new text detection model
1 parent 294f711 commit 30444a0

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

surya/common/polygon.py

+15
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,21 @@ def merge(self, other):
8686
y2 = max(self.bbox[3], other.bbox[3])
8787
self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
8888

89+
def expand(self, x_margin: float, y_margin: float):
90+
new_polygon = []
91+
x_margin = x_margin * self.width
92+
y_margin = y_margin * self.height
93+
for idx, poly in enumerate(self.polygon):
94+
if idx == 0:
95+
new_polygon.append([poly[0] - x_margin, poly[1] - y_margin])
96+
elif idx == 1:
97+
new_polygon.append([poly[0] + x_margin, poly[1] - y_margin])
98+
elif idx == 2:
99+
new_polygon.append([poly[0] + x_margin, poly[1] + y_margin])
100+
elif idx == 3:
101+
new_polygon.append([poly[0] - x_margin, poly[1] + y_margin])
102+
self.polygon = new_polygon
103+
89104
def intersection_polygon(self, other) -> List[List[float]]:
90105
new_poly = []
91106
for i in range(4):

surya/detection/heatmap.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def get_and_clean_boxes(textmap, processor_size, image_size, text_threshold=None
122122
bboxes = clean_boxes(bboxes)
123123
return bboxes
124124

125+
125126
def parallel_get_lines(preds, orig_sizes, include_maps=False):
126127
heatmap, affinity_map = preds
127128
heat_img, aff_img = None, None
@@ -143,18 +144,24 @@ def parallel_get_lines(preds, orig_sizes, include_maps=False):
143144
return result
144145

145146
def parallel_get_boxes(preds, orig_sizes, include_maps=False):
146-
heatmap, _ = preds
147+
heatmap, affinity_map = preds
147148
heat_img, aff_img = None, None
149+
148150
if include_maps:
149151
heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
152+
aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
150153
heatmap_size = list(reversed(heatmap.shape))
151154
bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
155+
for box in bboxes:
156+
#Skip for vertical boxes
157+
if box.height<3*box.width:
158+
box.expand(x_margin=0, y_margin=settings.DETECTOR_BOX_Y_EXPAND_MARGIN)
152159

153160
result = TextDetectionResult(
154161
bboxes=bboxes,
155162
vertical_lines=[],
156163
heatmap=heat_img,
157-
affinity_map=None,
164+
affinity_map=aff_img,
158165
image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]]
159166
)
160167
return result

surya/settings.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,14 @@ def TORCH_DEVICE_MODEL(self) -> str:
4848

4949
# Text detection
5050
DETECTOR_BATCH_SIZE: Optional[int] = None # Defaults to 2 for CPU/MPS, 32 otherwise
51-
DETECTOR_MODEL_CHECKPOINT: str = "s3://text_detection/2025_02_18"
51+
DETECTOR_MODEL_CHECKPOINT: str = "s3://text_detection/2025_02_28"
5252
DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench"
5353
DETECTOR_IMAGE_CHUNK_HEIGHT: int = 1400 # Height at which to slice images vertically
5454
DETECTOR_TEXT_THRESHOLD: float = 0.6 # Threshold for text detection (above this is considered text)
5555
DETECTOR_BLANK_THRESHOLD: float = 0.35 # Threshold for blank space (below this is considered blank)
5656
DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(8, os.cpu_count()) # Number of workers for postprocessing
5757
DETECTOR_MIN_PARALLEL_THRESH: int = 3 # Minimum number of images before we parallelize
58+
DETECTOR_BOX_Y_EXPAND_MARGIN: float = 0.025 #Margin by which to expand detected boxes vertically
5859
COMPILE_DETECTOR: bool = False
5960

6061
# Inline math detection

0 commit comments

Comments
 (0)