
Commit 6bcc48a

committed

Rename the layers and fix the cases where the indices are non-continuous
1 parent f3cee0e commit 6bcc48a

File tree

  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed (+25, -19 lines)

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 25 additions & 19 deletions
@@ -90,16 +90,16 @@ def index(
     # if no, then we need to broadcast
 
     last_index = None
-    broadcast_shape_len = 0
     for i, ind in enumerate(index):
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
             # torch.nn.parameter.Parameter=> torch.Tensor
-            ind = get_trt_tensor(network, ind, f"parameter_to_fp32_tensor_{i}")
+            ind = get_trt_tensor(network, ind, name + f"_parameter_to_fp32_tensor_{i}")
             if last_index is not None:
-                if not (broadcastable(ind, last_index)):
-                    assert "The indices should be broadcastable"
+                assert broadcastable(
+                    ind, last_index
+                ), "The indices should be broadcastable!"
             last_index = ind
             tensor_indices.append(ind)
 
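Note on the assertion change in this hunk: the old code asserted a non-empty string literal, which is always truthy, so the check could never fail; the new two-argument form asserts the condition itself and attaches the message. A minimal, self-contained sketch of the difference (the is_broadcastable helper below is a simplified stand-in, not the repo's broadcastable):

    def is_broadcastable(a, b):
        # Simplified stand-in: shapes broadcast if trailing dims match or are 1.
        return all(x == y or x == 1 or y == 1 for x, y in zip(reversed(a), reversed(b)))

    a, b = (4, 3), (2, 3)  # not broadcastable in the leading dimension
    assert "The indices should be broadcastable"  # old form: a string is truthy, passes silently
    assert is_broadcastable(a, b), "The indices should be broadcastable!"  # new form: raises here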
@@ -129,7 +129,7 @@ def index(
 
     for i in range(rank):
         dim = input_shape[i]
-        dim_tensor = get_trt_tensor(network, dim, f"individual_dim_{i}")
+        dim_tensor = get_trt_tensor(network, dim, name + f"_individual_dim_{i}")
         # dim_tensor_list is a list of tensors
         dim_tensor_list.append(dim_tensor)
 
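The recurring name + ... change in these hunks gives every constant created by get_trt_tensor a prefix that is unique to this converter call. Without the prefix, two index ops lowered into the same TensorRT network would both try to register tensors named, for example, individual_dim_0. A rough sketch of the naming scheme (the op names are hypothetical):

    for op_name in ["index_1", "index_2"]:  # two aten.index ops in one network
        for i in range(2):
            old_name = f"individual_dim_{i}"             # collides across the two ops
            new_name = op_name + f"_individual_dim_{i}"  # unique per op
            print(old_name, "->", new_name)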
@@ -166,8 +166,8 @@ def index(
 
     concat_tensor_layer = network.add_concatenation(
         [
-            get_trt_tensor(network, mult_d0, "d0_shape"),
-            get_trt_tensor(network, mult_d1, "d1_shape"),
+            get_trt_tensor(network, mult_d0, name + "_d0_shape"),
+            get_trt_tensor(network, mult_d1, name + "_d1_shape"),
         ]
     )
     set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
@@ -182,15 +182,17 @@ def index(
     # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
     # // j dimension of input x.
     multiplier = get_trt_tensor(
-        network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+        network,
+        dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+        name + "_dim_last",
     )
     cum_adv_index = tensor_indices[adv_indx_count - 1]
     for i in range(adv_indx_count - 2, -1, -1):
         adv_index = convert_binary_elementwise(
             network,
             target,
             source_ir,
-            name + "index_intermediate",
+            name + f"_index_intermediate_{i}",
             trt.ElementWiseOperation.PROD,
             multiplier,
             tensor_indices[i],
@@ -199,7 +201,7 @@ def index(
             network,
             target,
             source_ir,
-            name + "index_sum_intermediate",
+            name + f"_index_sum_intermediate_{i}",
             trt.ElementWiseOperation.SUM,
             cum_adv_index,
             adv_index,
@@ -208,7 +210,7 @@ def index(
             network,
             target,
             source_ir,
-            name + "index_intermediate",
+            name + f"_index_intermediate_xj_{i}",
             trt.ElementWiseOperation.PROD,
             multiplier,
             dim_tensor_list[adv_indx_indices[i]],
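The loop modified above implements the linearization formula quoted in the comment: the advanced indices are folded into one flat index, cum_adv_index = sum_i(ind_i * prod_{j>i} x_j), which is then used to gather from a flattened view of the input. A small NumPy sketch of the same arithmetic, with toy shapes that are not from the repo:

    import numpy as np

    # Advanced indexing over the first two dims of a (4, 5, 3) tensor:
    # flat = ind0 * x_1 + ind1, where x_1 = 5 plays the role of "multiplier".
    x = np.arange(4 * 5 * 3).reshape(4, 5, 3)
    ind0 = np.array([0, 2, 3])
    ind1 = np.array([1, 4, 0])

    flat = ind0 * x.shape[1] + ind1            # cum_adv_index
    gathered = x.reshape(-1, 3)[flat]          # gather on the flattened view
    assert (gathered == x[ind0, ind1]).all()   # matches NumPy advanced indexing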
@@ -236,7 +238,9 @@ def index(
         == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
     ):
         _LOGGER.debug(f"The indices are continuous in this case")
-        concat_tensor_reshape.append(get_trt_tensor(network, -1, "dynamic_concat"))
+        concat_tensor_reshape.append(
+            get_trt_tensor(network, -1, name + "_dynamic_concat")
+        )
         for i in range(0, rank):
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
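In this continuous branch the reshape shape is assembled as [-1] followed by the untouched dimensions; -1 is the usual "infer this dimension" wildcard, so the advanced-index result occupies a single folded dimension that the later unfold shuffle expands again. A tiny illustration of the wildcard semantics, with toy shapes:

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    folded = x.reshape(-1, 4)      # -1 is inferred as 2 * 3 = 6
    assert folded.shape == (6, 4)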
@@ -295,7 +299,7 @@ def index(
         set_layer_name(
             concat_final_shape_layer,
             target,
-            name + "_index_concat_final_shape_layer",
+            name + "_index_continuous_concat_final_shape_layer",
             source_ir,
         )
         concat_final_tensor = concat_final_shape_layer.get_output(0)
@@ -312,17 +316,19 @@ def index(
         reshape_output = unfold_advanced_shuffle_layer.get_output(0)
 
     else:
-        concat_tensor = []
+        _LOGGER.debug(f"The indices are not continuous in this case")
+        concat_final_tensor = []
+        concat_final_tensor.append(cum_adv_index_shape_tensor)
         for i in range(0, rank):
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
-                concat_tensor.append(curr_dim)
+                concat_final_tensor.append(curr_dim)
 
-        concat_layer = network.add_concatenation(concat_tensor)
+        concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
         set_layer_name(
-            concat_layer,
+            concat_final_shape_layer,
             target,
-            name + "_index_concat_final_shape_layer",
+            name + "_index_non_continuous_concat_final_shape_layer",
             source_ir,
         )
         concat_final_tensor = concat_final_shape_layer.get_output(0)
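This else branch handles non-continuous advanced indices (for example, index tensors on dims 0 and 2 with a plain slice in between). PyTorch/NumPy place the advanced-index result in front of the remaining dimensions in that case, which is why the fixed code prepends cum_adv_index_shape_tensor before the untouched dims; the old code dropped it and built concat_layer, while the unchanged line below reads from concat_final_shape_layer, so the shape it assembled was never used. A short check of the placement rule:

    import torch

    x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
    idx = torch.tensor([0, 1, 1])
    out = x[idx, :, idx]        # advanced indices on dims 0 and 2, slice in between
    assert out.shape == (3, 3)  # (advanced-index result, remaining middle dim)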
@@ -332,7 +338,7 @@ def index(
         set_layer_name(
             reshape_layer,
             target,
-            name + "_index_shuffle_final_shape_layer",
+            name + "_index_non_continuous_shuffle_final_shape_layer",
             source_ir,
         )
         reshape_output = reshape_layer.get_output(0)
