@@ -502,22 +502,40 @@ def index_put_converter(
     K = len(I)
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
-        index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
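+        # Static index sizes come straight from the tensor shape; a DYNAMIC_DIM size
+        # is fetched at runtime via get_shape, so N below may be a TRTTensor.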
+        index_shapes = []  # [tensor.shape[0] for tensor in indices if tensor is not None]
+        for idx_tensor in indices:
+            if idx_tensor is not None:
+                if idx_tensor.shape[0] != DYNAMIC_DIM:
+                    index_shapes.append(idx_tensor.shape[0])
+                else:
+                    index_shapes.append(
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + "idx_shape_dim_0",
+                            idx_tensor,
+                            0,
+                        )
+                    )
         N = max(index_shapes) if index_shapes else 1
     else:
         N = 1
 
     # Compute shapes and volume for the free dimensions
     F_shapes = [input_tensor.shape[i] for i in F]
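+    # trt.volume below needs concrete sizes, so dynamic free (non-indexed) dims
+    # are rejected up front.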
+    assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported"
     F_volume = trt.volume(F_shapes) if F_shapes else 1
 
     # Process indexed dimensions (I)
     I_tensors = []
     for i in I:
         idx = indices[i]
         assert idx is not None
-        idx_reshaped = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1)
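+        # unsqueeze does not bake in idx.shape[0] the way the old reshape did, so
+        # index tensors with a dynamic first dimension can still be converted.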
+        idx_reshaped = impl.unsqueeze.unsqueeze(
+            ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1
         )
         expanded_idx = impl.slice.expand(
             ctx,
@@ -539,46 +557,50 @@ def index_put_converter(
         )
         arange_tensors.append(arange_tensor)
 
-    meshgrid_tensors = []
-    for i, arange in enumerate(arange_tensors):
-        reshape_shape = [1] * len(F)
-        reshape_shape[i] = F_shapes[i]
-        arange_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_arange_F_{F[i]}",
-            arange,
-            tuple(reshape_shape),
-        )
-        expanded_arange = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_arange_F_{F[i]}",
-            arange_reshaped,
-            tuple(F_shapes),
-        )
-        meshgrid_tensors.append(expanded_arange)
-
-    meshgrid_stacked = impl.cat.cat(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_stack_meshgrid",
-        [
-            impl.shuffle.reshape(
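+    # Each arange covers one free dim; broadcasting them to F_shapes below builds
+    # a meshgrid of all free-dim coordinate combinations.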
+    if len(arange_tensors) == 1:
+        # No need to stack
+        meshgrid_stacked = arange_tensors[0]
+    else:
+        meshgrid_tensors = []
+        for i, arange in enumerate(arange_tensors):
+            reshape_shape = [1] * len(F)
+            reshape_shape[i] = F_shapes[i]
+            arange_reshaped = impl.shuffle.reshape(
                 ctx,
                 target,
                 source_ir,
-                f"{name}_reshape_mesh_{i}",
-                t,
-                (*F_shapes, 1),
+                f"{name}_reshape_arange_F_{F[i]}",
+                arange,
+                tuple(reshape_shape),
             )
-            for i, t in enumerate(meshgrid_tensors)
-        ],
-        dim=-1,
-    )
+            expanded_arange = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_arange_F_{F[i]}",
+                arange_reshaped,
+                tuple(F_shapes),
+            )
+            meshgrid_tensors.append(expanded_arange)
+
+        meshgrid_stacked = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_stack_meshgrid",
+            [
+                impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_mesh_{i}",
+                    t,
+                    (*F_shapes, 1),
+                )
+                for i, t in enumerate(meshgrid_tensors)
+            ],
+            dim=-1,
+        )
     meshgrid_reshaped = impl.shuffle.reshape(
         ctx,
         target,
@@ -603,21 +625,15 @@ def index_put_converter(
 
     # Combine all indexed dimensions (I)
     if K > 0:
-        I_combined = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_cat_I",
-            [
-                impl.shuffle.reshape(
-                    ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
-                )
-                for i, t in enumerate(I_tensors)
-            ],
-            dim=2,
-        )
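+        # Keep the reshaped (N, F_volume, 1) index tensors in a plain list; they are
+        # picked out by position below instead of being sliced from a concatenation.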
+        I_combined = [
+            impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
+            )
+            for i, t in enumerate(I_tensors)
+        ]
     else:
-        I_combined = None
+        I_combined = []
 
     # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded
     ii_list = []
@@ -626,24 +642,12 @@ def index_put_converter(
     for dim in range(rank):
         unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
         if dim in I:
-            start = [0, 0, i_idx]
-            shape = [N, F_volume, 1]
-            stride = [1, 1, 1]
-            idx_tensor = impl.slice.slice(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_I_dim_{unique_suffix}",
-                I_combined,
-                start,
-                shape,
-                stride,
-            )
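+            # Direct list indexing replaces the slice over the old concatenated tensor.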
+            idx_tensor = I_combined[i_idx]
             ii_list.append(idx_tensor)
             i_idx += 1
         else:
             start = [0, 0, f_idx]
-            shape = [N, F_volume, 1]
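+            # -1 marks the leading dim when N is only known at runtime (TRTTensor).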
+            shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1]
             stride = [1, 1, 1]
             mesh_tensor = impl.slice.slice(
                 ctx,
@@ -662,20 +666,24 @@ def index_put_converter(
     indices_cat = impl.cat.cat(
         ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2
     )
+
+    # Flatten the indices_cat to (N * F_volume, rank)
     indices_cat = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
         f"{name}_reshape_indices_cat",
         indices_cat,
-        (N * F_volume, rank),
+        (-1, rank),
     )
 
     if not isinstance(values, TRTTensor):
         values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0)
 
     # Define the expected shape based on (N,) + F_shapes
-    expected_shape = (N,) + tuple(F_shapes)
+    expected_shape = (
+        (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes)
+    )
 
     # Broadcast 'values' to match the expected shape
     if len(values.shape) == 0 or values.shape == (1,):  # Scalar case
@@ -773,16 +781,51 @@ def index_put_converter(
         source_ir,
         f"{name}_flatten_values",
         values_expanded,
-        (N * F_volume,),
+        (-1,),
     )
-
     indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
-    # Perform Scatter ND operation
-    scatter_layer = ctx.net.add_scatter(
-        input_tensor,
-        indices_cat,
-        flattened_values,
-        trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD,
-    )
-    set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
-    return scatter_layer.get_output(0)
+    if accumulate:
+        zero_tensor = impl.full.full(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_zero_tensor",
+            [
+                get_shape(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"input_tensor_shape_dim_{i}",
+                    input_tensor,
+                    i,
+                )
+                for i in range(len(input_tensor.shape))
+            ],
+            0.0,
+            dtype=input_tensor.dtype,
+        )
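+        # ScatterMode.ND overwrites rather than accumulates, so accumulate=True is
+        # emulated by scattering into zeros and then adding the original input
+        # (this presumably assumes indices do not repeat, since duplicates would
+        # not be summed by the scatter itself).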
+        # Perform Scatter ND operation
+        scatter_layer = ctx.net.add_scatter(
+            zero_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+
+        scatter_out = scatter_layer.get_output(0)
+        result = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor
+        )
+        return result
+
+    else:
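+        # Non-accumulating path: scatter the flattened updates directly into the input.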
+        scatter_layer = ctx.net.add_scatter(
+            input_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+        scatter_out = scatter_layer.get_output(0)
+        return scatter_out