@@ -57,18 +57,23 @@ def pad_to_batch_size(tensor: torch.Tensor, batch_size: int) -> torch.Tensor:

def inference_loop(
    model: nn.Module,
-    processor: SuryaProcessor,
    encoder_hidden_states: torch.Tensor,
    batch_input_ids: torch.Tensor,
    current_batch_size: int,
+    batch_size: int
):
    shaper = LabelShaper()
-    max_batch_size = batch_input_ids.shape[0]
    batch_predictions = [[] for _ in range(current_batch_size)]
    max_tokens = settings.TABLE_REC_MAX_BOXES
    decoder_position_ids = torch.ones_like(batch_input_ids[0, :, 0], dtype=torch.int64, device=model.device).cumsum(
        0) - 1
+    inference_token_count = batch_input_ids.shape[1]

+    if settings.TABLE_REC_STATIC_CACHE:
+        encoder_hidden_states = pad_to_batch_size(encoder_hidden_states, batch_size)
+        batch_input_ids = pad_to_batch_size(batch_input_ids, batch_size)
+
+    model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

    with torch.inference_mode():
        token_count = 0
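Side note on the `decoder_position_ids` expression in this hunk: it derives 0-based position indices from the token axis of `batch_input_ids`, whose `[0, :, 0]` slice implies a `(batch, tokens, properties)` layout. A minimal runnable illustration of just that expression, with made-up sizes and the `device` argument dropped:

    import torch

    # Made-up shape; only the (batch, tokens, properties) layout is implied by the diff.
    batch_input_ids = torch.zeros(2, 5, 7, dtype=torch.long)
    position_ids = torch.ones_like(batch_input_ids[0, :, 0], dtype=torch.int64).cumsum(0) - 1
    print(position_ids)  # tensor([0, 1, 2, 3, 4])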
@@ -94,15 +99,15 @@ def inference_loop(
                for (k, kcount, mode) in BOX_PROPERTIES:
                    k_logits = return_dict["box_property_logits"][k][j, -1, :]
                    if mode == "classification":
-                        item = torch.argmax(k_logits, dim=-1).item()
+                        item = int(torch.argmax(k_logits, dim=-1).item())
                        if k == "category":
-                            done.append(item == processor.tokenizer.eos_id or item == processor.tokenizer.pad_id)
+                            done.append(item == model.decoder.config.eos_token_id or item == model.decoder.config.pad_token_id)
                        item -= SPECIAL_TOKENS
                        box_property[k] = item
                    elif mode == "regression":
                        if k == "bbox":
                            k_logits *= BOX_DIM
-                        box_property[k] = k_logits
+                        box_property[k] = k_logits.tolist()
                box_properties.append(box_property)

            all_done = all_done | torch.tensor(done, dtype=torch.bool)
@@ -111,6 +116,7 @@ def inference_loop(
                break

            batch_input_ids = torch.tensor(shaper.dict_to_labels(box_properties), dtype=torch.long).to(model.device)
+            batch_input_ids = batch_input_ids.unsqueeze(1)  # Add sequence length dimension

            for j, (box_property, status) in enumerate(zip(box_properties, all_done)):
                if not status:
@@ -120,7 +126,7 @@ def inference_loop(
            inference_token_count = batch_input_ids.shape[1]

            if settings.TABLE_REC_STATIC_CACHE:
-                batch_input_ids = pad_to_batch_size(batch_input_ids, max_batch_size)
+                batch_input_ids = pad_to_batch_size(batch_input_ids, batch_size)

    return batch_predictions
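The hunks above make `inference_loop` self-contained: it now receives the fixed `batch_size` the static cache is allocated for, pads its own inputs, and primes the decoder cache, while `current_batch_size` still tracks how many real rows are present (so `batch_predictions` only covers real rows). `pad_to_batch_size` appears in the first hunk header with the signature `(tensor: torch.Tensor, batch_size: int) -> torch.Tensor`, but its body is not shown in this diff; a minimal sketch of what such a helper can look like, assuming right-padding of dim 0 with zeros:

    import torch
    import torch.nn.functional as F

    def pad_to_batch_size(tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
        # Nothing to do when the batch is already at (or over) capacity.
        current = tensor.shape[0]
        if current >= batch_size:
            return tensor
        # F.pad reads (left, right) pairs starting from the LAST dimension,
        # so the final pair targets dim 0 (the batch dimension).
        padding = (0, 0) * (tensor.dim() - 1) + (0, batch_size - current)
        return F.pad(tensor, padding)  # constant zero padding by default

    hidden = torch.randn(3, 10, 16)        # 3 real rows
    padded = pad_to_batch_size(hidden, 8)  # static cache expects exactly 8
    assert padded.shape == (8, 10, 16)
    assert torch.equal(padded[:3], hidden)  # real rows pass through untouched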
@@ -156,19 +162,14 @@ def batch_table_recognition(images: List, model: TableRecEncoderDecoderModel, pr
    batch_input_ids = model_inputs["input_ids"].to(model.device)
    batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)

-    if settings.TABLE_REC_STATIC_CACHE:
-        batch_pixel_values = pad_to_batch_size(batch_pixel_values, batch_size)
-        batch_input_ids = pad_to_batch_size(batch_input_ids, batch_size)
-
-        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
-        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
    shaper = LabelShaper()

    # We only need to process each image once
    with torch.inference_mode():
        encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state

-        row_predictions = inference_loop(model, processor, encoder_hidden_states, batch_input_ids, current_batch_size)
+        row_predictions = inference_loop(model, encoder_hidden_states, batch_input_ids, current_batch_size, batch_size)
+
        row_query_items = []
        row_encoder_hidden_states = []
        idx_map = []
@@ -186,7 +187,16 @@ def batch_table_recognition(images: List, model: TableRecEncoderDecoderModel, pr

        row_encoder_hidden_states = torch.stack(row_encoder_hidden_states)
        row_inputs = processor(images=None, query_items=row_query_items, convert_images=False)
-        cell_predictions = inference_loop(model, processor, row_encoder_hidden_states, row_inputs["input_ids"], len(row_query_items))
+        row_input_ids = row_inputs["input_ids"].to(model.device)
+        cell_predictions = []
+        for j in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"):
+            cell_batch_hidden_states = row_encoder_hidden_states[j:j + batch_size]
+            cell_batch_input_ids = row_input_ids[j:j + batch_size]
+            cell_batch_size = len(cell_batch_input_ids)
+
+            cell_predictions.extend(
+                inference_loop(model, cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size)
+            )

        batch_predictions = []
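With the cache sized for `batch_size`, the single `inference_loop` call over every row query had to go; the last hunk replaces it with fixed-size chunks so the decoder never sees more rows than the cache holds. The slicing itself is the standard chunking pattern; a standalone toy version (the counts are invented, and the short final chunk is what `pad_to_batch_size` pads back up when `TABLE_REC_STATIC_CACHE` is enabled):

    row_queries = list(range(19))  # stand-in for row query items
    batch_size = 8
    for j in range(0, len(row_queries), batch_size):
        chunk = row_queries[j:j + batch_size]
        print(len(chunk))  # 8, 8, 3 -- final chunk is smaller than batch_size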