@@ -90,16 +90,16 @@ def index(
     # if no, then we need to broadcast

     last_index = None
-    broadcast_shape_len = 0
     for i, ind in enumerate(index):
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
             # torch.nn.parameter.Parameter=> torch.Tensor
-            ind = get_trt_tensor(network, ind, f"parameter_to_fp32_tensor_{i}")
+            ind = get_trt_tensor(network, ind, name + f"_parameter_to_fp32_tensor_{i}")
             if last_index is not None:
-                if not (broadcastable(ind, last_index)):
-                    assert "The indices should be broadcastable"
+                assert broadcastable(
+                    ind, last_index
+                ), "The indices should be broadcastable!"
             last_index = ind
             tensor_indices.append(ind)

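Worth noting about the assertion change in this hunk: the removed `assert "The indices should be broadcastable"` asserts a bare non-empty string, which is always truthy, so the broadcastability check could never fail. The replacement asserts the actual predicate and attaches the string as the failure message. A minimal stand-alone illustration in plain Python (not converter code):

    # The old form never raises: a non-empty string literal is always truthy.
    assert "The indices should be broadcastable"

    # Asserting an actual predicate fails (and reports the message)
    # whenever the condition is False:
    try:
        assert 1 == 2, "The indices should be broadcastable!"
    except AssertionError as e:
        print(e)  # -> The indices should be broadcastable!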
@@ -129,7 +129,7 @@ def index(

     for i in range(rank):
         dim = input_shape[i]
-        dim_tensor = get_trt_tensor(network, dim, f"individual_dim_{i}")
+        dim_tensor = get_trt_tensor(network, dim, name + f"_individual_dim_{i}")
         # dim_tensor_list is a list of tensors
         dim_tensor_list.append(dim_tensor)

@@ -166,8 +166,8 @@ def index(

     concat_tensor_layer = network.add_concatenation(
         [
-            get_trt_tensor(network, mult_d0, "d0_shape"),
-            get_trt_tensor(network, mult_d1, "d1_shape"),
+            get_trt_tensor(network, mult_d0, name + "_d0_shape"),
+            get_trt_tensor(network, mult_d1, name + "_d1_shape"),
         ]
     )
     set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
@@ -182,15 +182,17 @@ def index(
     # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
     # // j dimension of input x.
     multiplier = get_trt_tensor(
-        network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+        network,
+        dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+        name + "_dim_last",
     )
     cum_adv_index = tensor_indices[adv_indx_count - 1]
     for i in range(adv_indx_count - 2, -1, -1):
         adv_index = convert_binary_elementwise(
             network,
             target,
             source_ir,
-            name + "index_intermediate",
+            name + f"_index_intermediate_{i}",
             trt.ElementWiseOperation.PROD,
             multiplier,
             tensor_indices[i],
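For reference, the PROD/SUM accumulation being renamed in this hunk implements the linearization formula from the code comment: the advanced indices are folded into one flat index, cum_adv_index = sum_i ind_i * prod_{j>i} x_j. A small NumPy sketch of that arithmetic (shapes and variable names are invented for illustration, not taken from the converter):

    import numpy as np

    x = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # x_0 = 2, x_1 = 3, x_2 = 4
    ind0 = np.array([0, 1])                    # advanced index for dim 0
    ind1 = np.array([2, 0])                    # advanced index for dim 1
    ind2 = np.array([1, 3])                    # advanced index for dim 2

    # cum_index = ind0 * (x_1 * x_2) + ind1 * x_2 + ind2  (row-major linearization)
    cum_index = ind0 * (3 * 4) + ind1 * 4 + ind2

    # Gathering from the flattened tensor reproduces advanced indexing.
    assert np.array_equal(x.reshape(-1)[cum_index], x[ind0, ind1, ind2])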
@@ -199,7 +201,7 @@ def index(
             network,
             target,
             source_ir,
-            name + "index_sum_intermediate",
+            name + f"_index_sum_intermediate_{i}",
             trt.ElementWiseOperation.SUM,
             cum_adv_index,
             adv_index,
@@ -208,7 +210,7 @@ def index(
             network,
             target,
             source_ir,
-            name + "index_intermediate",
+            name + f"_index_intermediate_xj_{i}",
             trt.ElementWiseOperation.PROD,
             multiplier,
             dim_tensor_list[adv_indx_indices[i]],
@@ -236,7 +238,9 @@ def index(
         == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
     ):
         _LOGGER.debug(f"The indices are continuous in this case")
-        concat_tensor_reshape.append(get_trt_tensor(network, -1, "dynamic_concat"))
+        concat_tensor_reshape.append(
+            get_trt_tensor(network, -1, name + "_dynamic_concat")
+        )
         for i in range(0, rank):
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
@@ -295,7 +299,7 @@ def index(
         set_layer_name(
             concat_final_shape_layer,
             target,
-            name + "_index_concat_final_shape_layer",
+            name + "_index_continuous_concat_final_shape_layer",
             source_ir,
         )
         concat_final_tensor = concat_final_shape_layer.get_output(0)
@@ -312,17 +316,19 @@ def index(
         reshape_output = unfold_advanced_shuffle_layer.get_output(0)

     else:
-        concat_tensor = []
+        _LOGGER.debug(f"The indices are not continuous in this case")
+        concat_final_tensor = []
+        concat_final_tensor.append(cum_adv_index_shape_tensor)
         for i in range(0, rank):
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
-                concat_tensor.append(curr_dim)
+                concat_final_tensor.append(curr_dim)

-        concat_layer = network.add_concatenation(concat_tensor)
+        concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
         set_layer_name(
-            concat_layer,
+            concat_final_shape_layer,
             target,
-            name + "_index_concat_final_shape_layer",
+            name + "_index_non_continuous_concat_final_shape_layer",
             source_ir,
         )
         concat_final_tensor = concat_final_shape_layer.get_output(0)
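As background for the else branch in this hunk (the non-continuous case), PyTorch's advanced-indexing rules move the broadcasted index dimensions to the front of the result when the advanced indices are separated by a slice, which is consistent with concatenating cum_adv_index_shape_tensor before the remaining dimensions here. A short PyTorch example of those semantics (shapes chosen only for illustration, not converter code):

    import torch

    x = torch.randn(4, 5, 6)
    idx = torch.tensor([0, 1])

    # Continuous advanced indices (dims 0 and 1): the index dims stay in place.
    print(x[idx, idx, :].shape)  # torch.Size([2, 6])

    # Non-continuous advanced indices (dims 0 and 2, separated by a slice):
    # the index dimension moves to the front of the result.
    print(x[idx, :, idx].shape)  # torch.Size([2, 5])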
@@ -332,7 +338,7 @@ def index(
         set_layer_name(
             reshape_layer,
             target,
-            name + "_index_shuffle_final_shape_layer",
+            name + "_index_non_continuous_shuffle_final_shape_layer",
             source_ir,
         )
         reshape_output = reshape_layer.get_output(0)