@@ -571,22 +571,40 @@ def index_put_converter(
     K = len(I)
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
-        index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = (
+            []
+        )  # [tensor.shape[0] for tensor in indices if tensor is not None]
+        for idx_tensor in indices:
+            if idx_tensor is not None:
+                if idx_tensor.shape[0] != DYNAMIC_DIM:
+                    index_shapes.append(idx_tensor.shape[0])
+                else:
+                    index_shapes.append(
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + "idx_shape_dim_0",
+                            idx_tensor,
+                            0,
+                        )
+                    )
         N = max(index_shapes) if index_shapes else 1
     else:
         N = 1
 
     # Compute shapes and volume for the free dimensions
     F_shapes = [input_tensor.shape[i] for i in F]
+    assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported"
     F_volume = trt.volume(F_shapes) if F_shapes else 1
 
     # Process indexed dimensions (I)
     I_tensors = []
     for i in I:
         idx = indices[i]
         assert idx is not None
-        idx_reshaped = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1)
+        idx_reshaped = impl.unsqueeze.unsqueeze(
+            ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1
         )
         expanded_idx = impl.slice.expand(
             ctx,
@@ -608,46 +626,50 @@ def index_put_converter(
         )
         arange_tensors.append(arange_tensor)
 
-    meshgrid_tensors = []
-    for i, arange in enumerate(arange_tensors):
-        reshape_shape = [1] * len(F)
-        reshape_shape[i] = F_shapes[i]
-        arange_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_arange_F_{F[i]}",
-            arange,
-            tuple(reshape_shape),
-        )
-        expanded_arange = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_arange_F_{F[i]}",
-            arange_reshaped,
-            tuple(F_shapes),
-        )
-        meshgrid_tensors.append(expanded_arange)
-
-    meshgrid_stacked = impl.cat.cat(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_stack_meshgrid",
-        [
-            impl.shuffle.reshape(
+    if len(arange_tensors) == 1:
+        # No need to stack
+        meshgrid_stacked = arange_tensors[0]
+    else:
+        meshgrid_tensors = []
+        for i, arange in enumerate(arange_tensors):
+            reshape_shape = [1] * len(F)
+            reshape_shape[i] = F_shapes[i]
+            arange_reshaped = impl.shuffle.reshape(
                 ctx,
                 target,
                 source_ir,
-                f"{name}_reshape_mesh_{i}",
-                t,
-                (*F_shapes, 1),
+                f"{name}_reshape_arange_F_{F[i]}",
+                arange,
+                tuple(reshape_shape),
             )
-            for i, t in enumerate(meshgrid_tensors)
-        ],
-        dim=-1,
-    )
+            expanded_arange = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_arange_F_{F[i]}",
+                arange_reshaped,
+                tuple(F_shapes),
+            )
+            meshgrid_tensors.append(expanded_arange)
+
+        meshgrid_stacked = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_stack_meshgrid",
+            [
+                impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_mesh_{i}",
+                    t,
+                    (*F_shapes, 1),
+                )
+                for i, t in enumerate(meshgrid_tensors)
+            ],
+            dim=-1,
+        )
     meshgrid_reshaped = impl.shuffle.reshape(
         ctx,
         target,
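
The restructured block above adds a fast path: with a single free dimension there is only one arange, so "stacking" it into a meshgrid would be a no-op wrapped in extra layers. A rough NumPy analogue of what the meshgrid branch computes (example sizes only):

```python
import numpy as np

F_shapes = [2, 3]  # example sizes of the free (non-indexed) dimensions
aranges = [np.arange(s) for s in F_shapes]

if len(aranges) == 1:
    mesh = aranges[0]                        # shortcut: nothing to combine
else:
    grids = []
    for i, a in enumerate(aranges):
        shape = [1] * len(F_shapes)
        shape[i] = F_shapes[i]
        grids.append(np.broadcast_to(a.reshape(shape), F_shapes))
    mesh = np.stack(grids, axis=-1)          # shape (*F_shapes, len(F_shapes))

assert mesh.shape == (2, 3, 2)               # mesh[i, j] == (i, j)
```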
@@ -672,21 +694,15 @@ def index_put_converter(
 
     # Combine all indexed dimensions (I)
     if K > 0:
-        I_combined = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_cat_I",
-            [
-                impl.shuffle.reshape(
-                    ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
-                )
-                for i, t in enumerate(I_tensors)
-            ],
-            dim=2,
-        )
+
+        I_combined = [
+            impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
+            )
+            for i, t in enumerate(I_tensors)
+        ]
     else:
-        I_combined = None
+        I_combined = []
 
     # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded
     ii_list = []
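
The hunk above drops a concat/slice round-trip: the per-dimension index tensors used to be concatenated into one `(N, F_volume, K)` tensor only to be sliced apart again below, so keeping them in a Python list is equivalent and cheaper. A NumPy sketch of the two layouts:

```python
import numpy as np

N, F_volume = 4, 6
I_tensors = [np.zeros((N, F_volume)), np.ones((N, F_volume))]

# old: concatenate, then slice column i back out
combined = np.concatenate([t.reshape(N, F_volume, 1) for t in I_tensors], axis=2)
old_col0 = combined[:, :, 0:1]

# new: just keep the reshaped tensors in a list
I_combined = [t.reshape(N, F_volume, 1) for t in I_tensors]
assert np.array_equal(old_col0, I_combined[0])
```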
@@ -695,24 +711,12 @@ def index_put_converter(
     for dim in range(rank):
         unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
         if dim in I:
-            start = [0, 0, i_idx]
-            shape = [N, F_volume, 1]
-            stride = [1, 1, 1]
-            idx_tensor = impl.slice.slice(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_I_dim_{unique_suffix}",
-                I_combined,
-                start,
-                shape,
-                stride,
-            )
+            idx_tensor = I_combined[i]
             ii_list.append(idx_tensor)
             i_idx += 1
         else:
             start = [0, 0, f_idx]
-            shape = [N, F_volume, 1]
+            shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1]
             stride = [1, 1, 1]
             mesh_tensor = impl.slice.slice(
                 ctx,
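
In the hunk above, the static slice shape `[N, F_volume, 1]` is only valid when `N` is a Python int; when `N` is a `TRTTensor` (a runtime shape value from `get_shape`), `-1` is used as the placeholder for that axis. A plain-Python sketch of the selection, with a non-int standing in for `TRTTensor`:

```python
def slice_shape(N, F_volume):
    # Stand-in for `isinstance(N, TRTTensor)` in the converter.
    is_dynamic = not isinstance(N, int)
    return [-1, F_volume, 1] if is_dynamic else [N, F_volume, 1]

assert slice_shape(4, 6) == [4, 6, 1]        # static N
assert slice_shape(object(), 6) == [-1, 6, 1]  # dynamic N
```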
@@ -731,20 +735,24 @@ def index_put_converter(
     indices_cat = impl.cat.cat(
         ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2
     )
+
+    # Flatten the indices_cat to (N * F_volume, rank)
     indices_cat = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
         f"{name}_reshape_indices_cat",
         indices_cat,
-        (N * F_volume, rank),
+        (-1, rank),
     )
 
     if not isinstance(values, TRTTensor):
         values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0)
 
     # Define the expected shape based on (N,) + F_shapes
-    expected_shape = (N,) + tuple(F_shapes)
+    expected_shape = (
+        (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes)
+    )
 
     # Broadcast 'values' to match the expected shape
     if len(values.shape) == 0 or values.shape == (1,):  # Scalar case
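
Both changes above follow the same rule: a reshape target containing `N * F_volume` or `(N,)` cannot be computed in Python once `N` is a runtime value, so a single `-1` lets the missing extent be inferred from the total element count. NumPy reshape shows the same inference:

```python
import numpy as np

N, F_volume, rank = 4, 6, 3
indices_cat = np.zeros((N, F_volume, rank))
flat = indices_cat.reshape(-1, rank)   # the -1 is inferred as N * F_volume
assert flat.shape == (N * F_volume, rank)
```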
@@ -842,16 +850,51 @@ def index_put_converter(
         source_ir,
         f"{name}_flatten_values",
         values_expanded,
-        (N * F_volume,),
+        (-1,),
     )
-
     indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
-    # Perform Scatter ND operation
-    scatter_layer = ctx.net.add_scatter(
-        input_tensor,
-        indices_cat,
-        flattened_values,
-        trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD,
-    )
-    set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
-    return scatter_layer.get_output(0)
+    if accumulate:
+        zero_tensor = impl.full.full(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_zero_tensor",
+            [
+                get_shape(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"input_tensor_shape_dim_{i}",
+                    input_tensor,
+                    i,
+                )
+                for i in range(len(input_tensor.shape))
+            ],
+            0.0,
+            dtype=input_tensor.dtype,
+        )
+        # Perform Scatter ND operation
+        scatter_layer = ctx.net.add_scatter(
+            zero_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+
+        scatter_out = scatter_layer.get_output(0)
+        result = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor
+        )
+        return result
+
+    else:
+        scatter_layer = ctx.net.add_scatter(
+            input_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+        scatter_out = scatter_layer.get_output(0)
+        return scatter_out
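
The rewrite above replaces `trt.ScatterMode.ND_ELEMENTWISE_ADD` on the `accumulate=True` path with a plain ND scatter into a zero tensor followed by an elementwise add of the original input, presumably so the destination's shape can be assembled from runtime `get_shape` values like the rest of this diff. The identity it relies on, assuming each flattened index slot is written at most once (a plain ND scatter does not sum duplicate indices), in a small PyTorch sketch:

```python
import torch

inp = torch.arange(6.0).reshape(2, 3)
idx = torch.tensor([[0, 1], [1, 2]])   # two unique (row, col) coordinates
vals = torch.tensor([10.0, 20.0])

zero = torch.zeros_like(inp)
zero[idx[:, 0], idx[:, 1]] = vals      # scatter ND into zeros
result = zero + inp                    # add back the original input

expected = inp.clone()
expected[tuple(idx.t())] += vals       # reference: index_put_ with accumulate
assert torch.equal(result, expected)
```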