@@ -113,102 +113,53 @@ def execute(self, requests):
          be the same as `requests`
        """

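+        # Collect TOKENS_BATCH and SEQUENCE_LENGTH from every request up
+        # front so that all sequences can be decoded with a single batched
+        # tokenizer call instead of one decode per sequence.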
+        tokens_batch = []
+        sequence_lengths = []
+        for idx, request in enumerate(requests):
+            for input_tensor in request.inputs():
+                if input_tensor.name() == "TOKENS_BATCH":
+                    tokens_batch.append(input_tensor.as_numpy())
+                elif input_tensor.name() == "SEQUENCE_LENGTH":
+                    sequence_lengths.append(input_tensor.as_numpy())
+                else:
+                    raise ValueError(f"unknown input {input_tensor.name()}")
+
+        # batch decode
+        list_of_tokens = []
+        req_idx_offset = 0
+        req_idx_offsets = [req_idx_offset]
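+        # req_idx_offsets records where each request's sequences start and
+        # end inside the flat list_of_tokens, so the decoded strings can be
+        # mapped back to their originating request later.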
+        for idx, token_batch in enumerate(tokens_batch):
+            for batch_idx, beam_tokens in enumerate(token_batch):
+                for beam_idx, tokens in enumerate(beam_tokens):
+                    seq_len = sequence_lengths[idx][batch_idx][beam_idx]
+                    # Exclude fake ids in multimodal models
+                    fake_id_len = 0
+                    for i in range(seq_len):
+                        if tokens[i] < self.tokenizer.vocab_size:
+                            fake_id_len = i
+                            break
+                    list_of_tokens.append(tokens[fake_id_len:seq_len])
+                    req_idx_offset += 1
+
+            req_idx_offsets.append(req_idx_offset)
+
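+        # Decode every sequence in one batched call; this amortizes the
+        # tokenizer overhead across the whole batch.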
+        all_outputs = self.tokenizer.batch_decode(
+            list_of_tokens, skip_special_tokens=self.skip_special_tokens)
+
+        # construct responses
        responses = []
-
-        # Every Python backend must iterate over everyone of the requests
-        # and create a pb_utils.InferenceResponse for each of them.
        for idx, request in enumerate(requests):
-            # Get input tensors
-            tokens_batch = pb_utils.get_input_tensor_by_name(
-                request, 'TOKENS_BATCH').as_numpy()
-
-            # Get sequence length
-            sequence_lengths = pb_utils.get_input_tensor_by_name(
-                request, 'SEQUENCE_LENGTH').as_numpy()
-
-            # Get cum log probs
-            cum_log_probs = pb_utils.get_input_tensor_by_name(
-                request, 'CUM_LOG_PROBS')
-
-            # Get sequence length
-            output_log_probs = pb_utils.get_input_tensor_by_name(
-                request, 'OUTPUT_LOG_PROBS')
-
-            # Get context logits
-            context_logits = pb_utils.get_input_tensor_by_name(
-                request, 'CONTEXT_LOGITS')
-
-            # Get generation logits
-            generation_logits = pb_utils.get_input_tensor_by_name(
-                request, 'GENERATION_LOGITS')
-
-            # Get the batch index
-            batch_index = pb_utils.get_input_tensor_by_name(
-                request, 'BATCH_INDEX')
-
-            # Reshape Input
-            # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]])
-            # tokens_batch = tokens_batch.T
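+            # Slice this request's decoded strings out of the flat
+            # all_outputs list using the offsets recorded earlier.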
+            req_outputs = [
+                x.encode('utf8') for x in
+                all_outputs[req_idx_offsets[idx]:req_idx_offsets[idx + 1]]
+            ]

-            # Postprocessing output data.
-            outputs = self._postprocessing(tokens_batch, sequence_lengths)
-
-            # Create output tensors. You need pb_utils.Tensor
-            # objects to create pb_utils.InferenceResponse.
            output_tensor = pb_utils.Tensor(
                'OUTPUT',
-                np.array(outputs).astype(self.output_dtype))
+                np.array(req_outputs).astype(self.output_dtype))

-            outputs = []
-            outputs.append(output_tensor)
-
-            if cum_log_probs:
-                out_cum_log_probs = pb_utils.Tensor('OUT_CUM_LOG_PROBS',
-                                                    cum_log_probs.as_numpy())
-                outputs.append(out_cum_log_probs)
-            else:
-                out_cum_log_probs = pb_utils.Tensor(
-                    'OUT_CUM_LOG_PROBS', np.array([[0.0]], dtype=np.float32))
-                outputs.append(out_cum_log_probs)
-
-            if output_log_probs:
-                out_output_log_probs = pb_utils.Tensor(
-                    'OUT_OUTPUT_LOG_PROBS', output_log_probs.as_numpy())
-                outputs.append(out_output_log_probs)
-            else:
-                out_output_log_probs = pb_utils.Tensor(
-                    'OUT_OUTPUT_LOG_PROBS',
-                    np.array([[[0.0]]], dtype=np.float32))
-                outputs.append(out_output_log_probs)
-
-            if context_logits:
-                out_context_logits = pb_utils.Tensor('OUT_CONTEXT_LOGITS',
-                                                     context_logits.as_numpy())
-                outputs.append(out_context_logits)
-            else:
-                out_context_logits = pb_utils.Tensor(
-                    'OUT_CONTEXT_LOGITS', np.array([[[0.0]]],
-                                                   dtype=np.float32))
-                outputs.append(out_context_logits)
-
-            if generation_logits:
-                out_generation_logits = pb_utils.Tensor(
-                    'OUT_GENERATION_LOGITS', generation_logits.as_numpy())
-                outputs.append(out_generation_logits)
-            else:
-                out_generation_logits = pb_utils.Tensor(
-                    'OUT_GENERATION_LOGITS',
-                    np.array([[[[0.0]]]], dtype=np.float32))
-                outputs.append(out_generation_logits)
-
-            if batch_index:
-                out_batch_index = pb_utils.Tensor('OUT_BATCH_INDEX',
-                                                  batch_index.as_numpy())
-                outputs.append(out_batch_index)
-            else:
-                out_batch_index = pb_utils.Tensor(
-                    'OUT_BATCH_INDEX', np.array([[0]], dtype=np.int32))
-                outputs.append(out_batch_index)
+            outputs = [output_tensor]

            # Create InferenceResponse. You can set an error here in case
            # there was a problem with handling this inference request.
@@ -220,7 +171,6 @@ def execute(self, requests):
            inference_response = pb_utils.InferenceResponse(
                output_tensors=outputs)
            responses.append(inference_response)
-
        # You should return a list of pb_utils.InferenceResponse. Length
        # of this list must match the length of `requests` list.
        return responses
@@ -231,20 +181,3 @@ def finalize(self):
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')
-
-    def _postprocessing(self, tokens_batch, sequence_lengths):
-        outputs = []
-        for batch_idx, beam_tokens in enumerate(tokens_batch):
-            for beam_idx, tokens in enumerate(beam_tokens):
-                seq_len = sequence_lengths[batch_idx][beam_idx]
-                # Exclude fake ids in multimodal models
-                fake_id_len = 0
-                for i in range(seq_len):
-                    if tokens[i] < len(self.tokenizer.vocab):
-                        fake_id_len = i
-                        break
-                output = self.tokenizer.decode(
-                    tokens[fake_id_len:seq_len],
-                    skip_special_tokens=self.skip_special_tokens)
-                outputs.append(output.encode('utf8'))
-        return outputs