Commit d41dd7c

Patch issues with table rec

1 parent 2375929 commit d41dd7c

8 files changed (+42, -30 lines)

ocr_app.py (+1, -3)

@@ -3,10 +3,8 @@
 
 import pypdfium2
 import streamlit as st
-from pypdfium2 import PdfiumError
 
 from surya.detection import batch_text_detection
-from surya.input.pdflines import get_page_text_lines, get_table_blocks
 from surya.layout import batch_layout_detection
 from surya.model.detection.model import load_model, load_processor
 from surya.model.layout.model import load_model as load_layout_model
@@ -24,7 +22,7 @@
 from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
 from surya.settings import settings
 from surya.tables import batch_table_recognition
-from surya.postprocessing.util import rescale_bboxes, rescale_bbox
+from surya.postprocessing.util import rescale_bbox
 
 
 @st.cache_resource()

surya/model/common/adetr/decoder.py (+1, -1)

@@ -374,7 +374,7 @@ def forward(
         # Do cross-attention on encoder outputs
         cross_attn_inputs = self.cross_pre_norm(hidden_states)
         cross_attn_path = self.cross_attn_block(
-            cross_attn_inputs, position_ids, encoder_hidden_states, attention_mask, encoder_attention_mask
+            cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
        )
         hidden_states = cross_attn_path + hidden_states
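The cross-attention call now passes encoder_hidden_states as the second positional argument and forwards use_cache by keyword; the old argument list, which still included position_ids, apparently no longer lined up with the block it was calling. A minimal, hypothetical sketch of why keyword arguments are the safer pattern here (the signature below is invented for illustration and is not the real surya API):

import torch

# Hypothetical stand-in for the cross-attention block; the real signature lives
# in surya/model/common/adetr/decoder.py and may differ.
def cross_attn_block(hidden_states, encoder_hidden_states, attention_mask=None,
                     encoder_attention_mask=None, use_cache=False):
    # A real block would attend from hidden_states to encoder_hidden_states here.
    return hidden_states

x = torch.zeros(1, 4, 8)
enc = torch.zeros(1, 16, 8)

# Passing the trailing arguments by keyword means a stale extra positional
# argument (like position_ids above) cannot silently shift the rest.
out = cross_attn_block(x, enc, attention_mask=None, encoder_attention_mask=None, use_cache=True)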

surya/model/table_rec/decoder.py (+4, -1)

@@ -43,7 +43,10 @@ def __init__(self, config):
 
         self.config = config
 
-    def forward(self, boxes: torch.LongTensor):
+    def forward(self, boxes: torch.LongTensor, *args):
+        # Need to keep *args for compatibility with common decoder
+        boxes = boxes.to(torch.long).clamp(0, self.config.vocab_size)
+
         boxes_unbound = boxes.to(torch.long).unbind(dim=-1)
         cx, cy, w, h, xskew, yskew = boxes_unbound[self.component_idxs["bbox"][0]:self.component_idxs["bbox"][1]]
         category = boxes_unbound[self.component_idxs["category"][0]:self.component_idxs["category"][1]][0]
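The added clamp bounds incoming box values to the decoder's vocabulary range before they are unbound and embedded, so a stray out-of-range value cannot trigger an index error inside the embedding lookup. A minimal sketch of the failure mode it guards against (sizes and names below are made up, not the model's real configuration):

import torch
import torch.nn as nn

vocab_size = 1024
embed = nn.Embedding(vocab_size, 64)

boxes = torch.tensor([[5, 2000, -3]])                  # values outside [0, vocab_size - 1]
boxes = boxes.to(torch.long).clamp(0, vocab_size - 1)  # same idea as the clamp above
vectors = embed(boxes)                                 # safe after clamping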

surya/model/table_rec/encoderdecoder.py (-1)

@@ -25,7 +25,6 @@ def __init__(
         self,
         config: Optional[PretrainedConfig] = None,
         encoder: Optional[PreTrainedModel] = None,
-        text_encoder: Optional[PreTrainedModel] = None,
         decoder: Optional[PreTrainedModel] = None,
     ):
         # initialize with config

surya/model/table_rec/model.py (+2, -3)

@@ -1,7 +1,6 @@
 from surya.model.table_rec.encoder import DonutSwinModel
-from surya.model.table_rec.config import SuryaTableRecConfig, SuryaTableRecDecoderConfig, DonutSwinTableRecConfig, \
-    SuryaTableRecTextEncoderConfig
-from surya.model.table_rec.decoder import SuryaTableRecDecoder, SuryaTableRecTextEncoder
+from surya.model.table_rec.config import SuryaTableRecConfig, SuryaTableRecDecoderConfig, DonutSwinTableRecConfig
+from surya.model.table_rec.decoder import SuryaTableRecDecoder
 from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
 from surya.settings import settings

surya/model/table_rec/shaper.py (+10, -6)

@@ -1,3 +1,4 @@
+import math
 from typing import List, Dict
 import numpy as np
 import torch
@@ -120,21 +121,24 @@ def convert_bbox_to_polygon(self, box, skew_scaler=BOX_DIM // 2, skew_min=.001):
         y1 = cy - height / 2
         x2 = cx + width / 2
         y2 = cy + height / 2
-        skew_x = torch.floor((box[4] - skew_scaler) / 2)
-        skew_y = torch.floor((box[5] - skew_scaler) / 2)
+        skew_x = math.floor((box[4] - skew_scaler) / 2)
+        skew_y = math.floor((box[5] - skew_scaler) / 2)
 
         # Ensures we don't get slightly warped boxes
         # Note that the values are later scaled, so this is in 1/1024 space
-        skew_x[torch.abs(skew_x) < skew_min] = 0
-        skew_y[torch.abs(skew_y) < skew_min] = 0
+        if abs(skew_x) < skew_min:
+            skew_x = 0
+
+        if abs(skew_y) < skew_min:
+            skew_y = 0
 
         polygon = [x1 - skew_x, y1 - skew_y, x2 - skew_x, y1 + skew_y, x2 + skew_x, y2 + skew_y, x1 + skew_x,
                    y2 - skew_y]
         poly = []
         for i in range(4):
             poly.append([
-                polygon[2 * i].item(),
-                polygon[2 * i + 1].item()
+                polygon[2 * i],
+                polygon[2 * i + 1]
             ])
         return poly
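convert_bbox_to_polygon now receives plain Python numbers rather than tensors, so torch.floor and boolean-mask assignment (which only work on tensors) are replaced with math.floor and ordinary conditionals, and the .item() calls are dropped. A standalone sketch of the new skew handling with invented inputs (BOX_DIM and the shaper's component layout are simplified away; the 512 default merely stands in for BOX_DIM // 2):

import math

def skew_components(box, skew_scaler=512, skew_min=0.001):
    # box holds plain Python floats, so math.floor / abs are used instead of
    # torch.floor and tensor masking.
    skew_x = math.floor((box[4] - skew_scaler) / 2)
    skew_y = math.floor((box[5] - skew_scaler) / 2)
    if abs(skew_x) < skew_min:
        skew_x = 0
    if abs(skew_y) < skew_min:
        skew_y = 0
    return skew_x, skew_y

print(skew_components([100, 100, 50, 20, 512.4, 515.0]))  # -> (0, 1)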

surya/tables.py (+24, -14)

@@ -57,18 +57,23 @@ def pad_to_batch_size(tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
 
 def inference_loop(
     model: nn.Module,
-    processor: SuryaProcessor,
     encoder_hidden_states: torch.Tensor,
     batch_input_ids: torch.Tensor,
     current_batch_size: int,
+    batch_size: int
 ):
     shaper = LabelShaper()
-    max_batch_size = batch_input_ids.shape[0]
     batch_predictions = [[] for _ in range(current_batch_size)]
     max_tokens = settings.TABLE_REC_MAX_BOXES
     decoder_position_ids = torch.ones_like(batch_input_ids[0, :, 0], dtype=torch.int64, device=model.device).cumsum(
         0) - 1
+    inference_token_count = batch_input_ids.shape[1]
 
+    if settings.TABLE_REC_STATIC_CACHE:
+        encoder_hidden_states = pad_to_batch_size(encoder_hidden_states, batch_size)
+        batch_input_ids = pad_to_batch_size(batch_input_ids, batch_size)
+
+    model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
 
     with torch.inference_mode():
         token_count = 0
@@ -94,15 +99,15 @@ def inference_loop(
                 for (k, kcount, mode) in BOX_PROPERTIES:
                     k_logits = return_dict["box_property_logits"][k][j, -1, :]
                     if mode == "classification":
-                        item = torch.argmax(k_logits, dim=-1).item()
+                        item = int(torch.argmax(k_logits, dim=-1).item())
                         if k == "category":
-                            done.append(item == processor.tokenizer.eos_id or item == processor.tokenizer.pad_id)
+                            done.append(item == model.decoder.config.eos_token_id or item == model.decoder.config.pad_token_id)
                         item -= SPECIAL_TOKENS
                         box_property[k] = item
                     elif mode == "regression":
                         if k == "bbox":
                             k_logits *= BOX_DIM
-                        box_property[k] = k_logits
+                        box_property[k] = k_logits.tolist()
                 box_properties.append(box_property)
 
             all_done = all_done | torch.tensor(done, dtype=torch.bool)
@@ -111,6 +116,7 @@
                 break
 
             batch_input_ids = torch.tensor(shaper.dict_to_labels(box_properties), dtype=torch.long).to(model.device)
+            batch_input_ids = batch_input_ids.unsqueeze(1)  # Add sequence length dimension
 
             for j, (box_property, status) in enumerate(zip(box_properties, all_done)):
                 if not status:
@@ -120,7 +126,7 @@ def inference_loop(
             inference_token_count = batch_input_ids.shape[1]
 
             if settings.TABLE_REC_STATIC_CACHE:
-                batch_input_ids = pad_to_batch_size(batch_input_ids, max_batch_size)
+                batch_input_ids = pad_to_batch_size(batch_input_ids, batch_size)
     return batch_predictions
 
 
@@ -156,19 +162,14 @@ def batch_table_recognition(images: List, model: TableRecEncoderDecoderModel, pr
     batch_input_ids = model_inputs["input_ids"].to(model.device)
     batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
 
-    if settings.TABLE_REC_STATIC_CACHE:
-        batch_pixel_values = pad_to_batch_size(batch_pixel_values, batch_size)
-        batch_input_ids = pad_to_batch_size(batch_input_ids, batch_size)
-
-    model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
-    model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
     shaper = LabelShaper()
 
     # We only need to process each image once
     with torch.inference_mode():
         encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state
 
-    row_predictions = inference_loop(model, processor, encoder_hidden_states, batch_input_ids, current_batch_size)
+    row_predictions = inference_loop(model, encoder_hidden_states, batch_input_ids, current_batch_size, batch_size)
+
     row_query_items = []
     row_encoder_hidden_states = []
     idx_map = []
@@ -186,7 +187,16 @@
 
     row_encoder_hidden_states = torch.stack(row_encoder_hidden_states)
     row_inputs = processor(images=None, query_items=row_query_items, convert_images=False)
-    cell_predictions = inference_loop(model, processor, row_encoder_hidden_states, row_inputs["input_ids"], len(row_query_items))
+    row_input_ids = row_inputs["input_ids"].to(model.device)
+    cell_predictions = []
+    for j in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"):
+        cell_batch_hidden_states = row_encoder_hidden_states[j:j+batch_size]
+        cell_batch_input_ids = row_input_ids[j:j+batch_size]
+        cell_batch_size = len(cell_batch_input_ids)
+
+        cell_predictions.extend(
+            inference_loop(model, cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size)
+        )
 
     batch_predictions = []
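Two batching fixes land here: the static-cache padding and cache setup move from batch_table_recognition into inference_loop, and the cell pass is split into batch_size chunks instead of one call over every row query at once. A standalone sketch of that chunk-pad-trim pattern under a fixed-size (static) cache; pad_to_batch_size below is only a plausible reading of the helper the diff references, and all shapes are invented:

import torch

def pad_to_batch_size(tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Pad dim 0 up to batch_size so a statically allocated cache always sees
    # the same batch dimension.
    if tensor.shape[0] >= batch_size:
        return tensor
    pad = torch.zeros((batch_size - tensor.shape[0], *tensor.shape[1:]), dtype=tensor.dtype)
    return torch.cat([tensor, pad], dim=0)

queries = torch.randn(10, 4, 8)  # e.g. one query per detected table row
batch_size = 4

outputs = []
for start in range(0, len(queries), batch_size):
    chunk = queries[start:start + batch_size]
    real = len(chunk)                              # the last chunk may be short
    padded = pad_to_batch_size(chunk, batch_size)  # fixed shape for the static cache
    scores = padded.sum(dim=(1, 2))                # stand-in for the real decoder call
    outputs.extend(scores[:real].tolist())         # drop the padded rows again

print(len(outputs))  # 10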

table_recognition.py (-1)

@@ -1,4 +1,3 @@
-import pypdfium2 as pdfium # Needs to be on top to avoid warning
 import os
 import argparse
 import copy
