@@ -202,34 +202,29 @@ def slice_scatter_decomposition(
202202 start = get_positive_dim (start , input_tensor .shape [dim ])
203203 if end is None : # Ensure end is int
204204 end = dim_size
205- end = get_positive_dim (end , input_tensor .shape [dim ])
205+ end = (
206+ get_positive_dim (end , input_tensor .shape [dim ]) if isinstance (end , int ) else end
207+ )
206208 if step is None :
207209 step = 1
208210
209- src_dim = src_tensor .shape
210211 # step == 0 is not a valid torch case
211- # also src_dim should be equal to slice dimension
212-
213212 if start == 0 and end == dim_size and step == 1 :
214213 return src_tensor
215214
216215 # Ensure start, end, and step are all integers
217- assert isinstance (start , int ), "start must be an integer"
218- assert isinstance (end , int ), "end must be an integer"
219- assert isinstance (step , int ), "step must be an integer"
220-
221- cat_tensors = []
222- index_tensor_shape = []
223- for i , src_each_dim in enumerate (list (src_dim )):
224- if i != dim :
225- index_tensor_shape .append (src_each_dim )
226- for index in range (start , end , step ):
227- cat_tensors .append (index * torch .ones (index_tensor_shape , dtype = torch .int64 ))
228- index_tensor = torch .stack (cat_tensors , dim )
229- index_tensor = index_tensor .to (device_input_tensor )
230- index_tensor_64 = index_tensor .to (torch .int64 )
231- output_tensor = torch .scatter (input_tensor , dim , index_tensor_64 , src_tensor )
232- return output_tensor
216+ assert isinstance (start , (int , torch .SymInt )), "start must be an int or SymInt"
217+ assert isinstance (end , (int , torch .SymInt )), "end must be an int or SymInt"
218+ assert isinstance (step , (int , torch .SymInt )), "step must be an int or SymInt"
219+
220+ indices = torch .arange (
221+ start , end , step , device = device_input_tensor , dtype = torch .int64
222+ )
223+ index_tensor = indices .view (
224+ [- 1 if i == dim else 1 for i in range (input_tensor .dim ())]
225+ )
226+ index_tensor = index_tensor .expand_as (src_tensor )
227+ return torch .scatter (input_tensor , dim , index_tensor , src_tensor )
233228
234229
235230@register_torch_trt_decomposition (
0 commit comments