@@ -1,11 +1,10 @@
 import logging
-import operator
-from typing import Callable, List, Optional, Set, Tuple
+from typing import Callable, List, Set, Tuple
 
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx import GraphModule, Node
-from torch.fx.subgraph_rewriter import Match
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
@@ -25,7 +24,7 @@ def __init__(
         self.subgraph_nodes = subgraph_nodes
         self.input_nodes = input_nodes
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return (
             f"ComplexOpSubGraphInfo(anchor_nodes={[n.name for n in self.anchor_nodes]}, "
             f"subgraph={[n.name for n in self.subgraph_nodes]}, "
@@ -34,7 +33,7 @@ def __repr__(self):
 
 
 class ComplexOpDetector:
-    def __init__(self):
+    def __init__(self) -> None:
         pass
 
     def is_complex_dtype(self, node: Node) -> bool:
@@ -106,16 +105,18 @@ def find_complex_op_subgraphs(
 
 
 class ComplexGraphRewriter:
-    def __init__(self, gm: GraphModule, truncate_double: bool = False):
+    def __init__(self, gm: GraphModule, truncate_double: bool = False) -> None:
         self.gm = gm
         self.truncate_double = truncate_double
 
-    def extract_shape_dtype_device(self, input_node):
+    def extract_shape_dtype_device(
+        self, input_node: Node
+    ) -> Tuple[Tuple[int, ...], torch.dtype, torch.device]:
         if input_node.op == "placeholder":
             tensor_val = input_node.meta["val"]
 
         elif input_node.op == "get_attr":
-            tensor_val = self.get_attr_tensor(input_node.target)
+            tensor_val = self.get_attr_tensor(input_node.target)  # type: ignore
 
         else:
             raise ValueError(f"Unsupported node type: {input_node.op}")
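
For reference, the mapping this method targets is the one view_as_real applies to a complex tensor's metadata: the dtype drops to the matching real dtype and a trailing dimension of size 2 is appended. Illustrative snippet using plain torch, not part of the diff:

import torch

c = torch.zeros(3, 4, dtype=torch.complex64)
r = torch.view_as_real(c)   # real view sharing storage with c
print(r.shape, r.dtype)     # torch.Size([3, 4, 2]) torch.float32
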
@@ -134,7 +135,7 @@ def extract_shape_dtype_device(self, input_node):
 
         return new_node_shape, new_node_dtype, device
 
-    def get_attr_tensor(self, target):
+    def get_attr_tensor(self, target):  # type: ignore
         # Check if target is param or buffer
         if target in dict(self.gm.named_parameters()):
             return self.gm.get_parameter(target)
@@ -145,7 +146,7 @@ def get_attr_tensor(self, target):
                 f"Attribute {target} not found in gm parameters or buffers."
             )
 
-    def replace_input_node(self, input_node):
+    def replace_input_node(self, input_node: Node) -> None:
         modified = False
         logger.debug(f"Replacing input node: {input_node.name}")
         new_shape, new_dtype, device = self.extract_shape_dtype_device(input_node)
@@ -160,10 +161,8 @@ def replace_input_node(self, input_node):
 
         elif input_node.op == "get_attr":
             new_attr_name = input_node.target + "_reshaped"
-            from torch._subclasses.fake_tensor import unset_fake_temporarily
-
             with unset_fake_temporarily():
-                original_tensor = self.get_attr_tensor(input_node.target)
+                original_tensor = self.get_attr_tensor(input_node.target)  # type: ignore
                 stacked_tensor = torch.stack(
                     [original_tensor.real, original_tensor.imag], dim=-1
                 )
@@ -181,7 +180,7 @@ def replace_input_node(self, input_node):
         self.gm.graph.erase_node(input_node)
         clean_up_graph_after_modifications(self.gm)
 
-    def rewrite_subgraph_nodes(self, subgraphs):
+    def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None:
         modified = False
         for subgraph in subgraphs:
             for input_node in subgraph.input_nodes:
@@ -196,11 +195,20 @@ def rewrite_subgraph_nodes(self, subgraphs):
                 elif node.target == torch.ops.aten.mul.Tensor:
                     # this is complex mul where inputs = a+ib and output = c+id.
                     # complex mul returns (ac - bd) + (ad + bc)i
-                    # which is then view_as_real as (ac-bd), ad+bc stacked along the last dimension with last dimension size 2
+                    # which is then view_as_real as (ac-bd), (ad+bc) stacked along the last dimension with last dimension size 2
+                    x_placeholder_or_func = (
+                        True if node.args[0].op != "get_attr" else False
+                    )
+                    y_placeholder_or_func = (
+                        True if node.args[1].op != "get_attr" else False
+                    )
+
                     replaced_nodes = []
-                    original_mul, replacement = complex_mul_replacement()
+                    original_mul, replacement = complex_mul_replacement(
+                        x_placeholder_or_func, y_placeholder_or_func
+                    )
 
-                    def match_complex_mul(
+                    def match_complex_mul(  # type: ignore[no-untyped-def]
                         match: torch.fx.subgraph_rewriter.Match,
                         original_graph,
                         pattern_graph,
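
The arithmetic the replacement pattern encodes can be checked eagerly: for complex x and y, view_as_real(x * y) matches the manual (ac - bd, ad + bc) computation on the real/imag-stacked views. Illustrative check, not part of the diff:

import torch

x = torch.randn(2, 3, dtype=torch.complex64)
y = torch.randn(2, 3, dtype=torch.complex64)
xs, ys = torch.view_as_real(x), torch.view_as_real(y)      # (..., 2) stacked real/imag
real = xs[..., 0] * ys[..., 0] - xs[..., 1] * ys[..., 1]   # ac - bd
imag = xs[..., 0] * ys[..., 1] + xs[..., 1] * ys[..., 0]   # ad + bc
manual = torch.stack([real, imag], dim=-1)
assert torch.allclose(manual, torch.view_as_real(x * y), atol=1e-6)
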
@@ -233,7 +241,7 @@ def match_complex_mul(
             self.gm.graph.lint()
             self.gm.recompile()
 
-    def propagate_metadata(self):
+    def propagate_metadata(self) -> None:
         fake_inputs = []
         from torch._subclasses.fake_tensor import FakeTensorMode
         from torch.fx.passes.fake_tensor_prop import FakeTensorProp
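
propagate_metadata re-runs shape inference with fake tensors so each node's meta["val"] reflects the new real-valued layouts; the underlying FakeTensorProp idiom looks roughly like the sketch below (assumes real example inputs and a hypothetical helper name, not part of the diff):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def repropagate_shapes(gm: torch.fx.GraphModule, example_inputs):  # hypothetical helper
    fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
    fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
    FakeTensorProp(gm, mode=fake_mode).propagate(*fake_inputs)
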
@@ -260,7 +268,34 @@ def propagate_metadata(self):
         ).propagate(*fake_inputs)
 
 
-def complex_mul_replacement() -> Tuple[
+def extract_real_imag(input, placeholder_or_func: bool = True):  # type: ignore
+    """Extract real and imaginary parts from a tensor.
+    This function handles different tensor types based on whether they are placeholder/function
+    tensors or get_attr tensors. For placeholder/function tensors, it uses select operations,
+    while for get_attr tensors, it uses indexing.
+    Args:
+        input: Input tensor to extract real and imaginary parts from
+        placeholder_or_func: Boolean flag indicating if the input is a placeholder/function tensor (True)
+            or a get_attr tensor (False). Defaults to True.
+    Returns:
+        Tuple of (real_part, imaginary_part) where both parts have the same type as the input
+    Note:
+        - When placeholder_or_func=True: Uses torch.ops.aten.select.int operations
+        - When placeholder_or_func=False: Uses tensor indexing [..., 0] and [..., 1]
+    """
+    if placeholder_or_func:
+        # For ITensor, use select operations
+        real_part = torch.ops.aten.select.int(input, -1, 0)
+        imag_part = torch.ops.aten.select.int(input, -1, 1)
+        return real_part, imag_part
+    else:
+        # For get_attr, use indexing
+        return input[..., 0], input[..., 1]
+
+
+def complex_mul_replacement(
+    x_placeholder_or_func: bool = True, y_placeholder_or_func: bool = True
+) -> Tuple[
     Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
     Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
 ]:
@@ -280,9 +315,8 @@ def original_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
     # Replacement function: manual complex multiplication on real/imag stacked tensors
     def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        x_real = torch.ops.aten.select.int(x, -1, 0)
-        x_imag = torch.ops.aten.select.int(x, -1, 1)  # x is reshape tensor
-        y_real, y_imag = y[..., 0], y[..., 1]  # y is frozen param
+        x_real, x_imag = extract_real_imag(x, x_placeholder_or_func)
+        y_real, y_imag = extract_real_imag(y, y_placeholder_or_func)
 
         real_part1 = torch.ops.aten.mul.Tensor(x_real, y_real)
         real_part2 = torch.ops.aten.mul.Tensor(x_imag, y_imag)
@@ -304,10 +338,18 @@ def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 
 # This lowering pass is used to detect and rewrite complex subgraphs in the graph
-# This lowering pass works for complex tensor in mul which are parameter or buffers in the graph
 def complex_graph_detection(
     gm: GraphModule, settings: CompilationSettings
-) -> List[ComplexSubGraphInfo]:
+) -> GraphModule:
+    """Detect and rewrite complex subgraphs in the graph.
+    This lowering pass is used to detect and rewrite complex subgraphs in the graph.
+    This lowering pass works for complex tensor in mul which are parameter or buffers in the graph.
+    Args:
+        gm: The GraphModule to process
+        settings: Compilation settings
+    Returns:
+        The modified GraphModule with complex subgraphs rewritten
+    """
     complex_op_detector = ComplexOpDetector()
     complex_subgraphs = complex_op_detector.find_complex_op_subgraphs(
         gm, anchor_target=torch.ops.aten.view_as_real.default
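
Assuming a graph whose aten.mul consumes a complex activation and a complex frozen parameter feeding view_as_real, the pass is invoked like any other dynamo lowering pass and returns the rewritten module. Usage sketch, not part of the diff:

from torch_tensorrt.dynamo._settings import CompilationSettings

gm = complex_graph_detection(gm, CompilationSettings())
# Complex get_attr inputs are replaced by real-valued "<name>_reshaped" attributes with a
# trailing dimension of size 2, and the mul is rewritten via complex_mul_replacement.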