@@ -76,8 +76,6 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
7676 # Flag for checking if weights are offloaded
7777 self ._is_weights_offloaded : bool = False
7878
79- self ._original_model_name : Optional [str ] = None
80-
8179 # Apply the transformations
8280 any_transformed = False
8381 for transform in self ._pytorch_transforms :
@@ -89,100 +87,16 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
8987 else :
9088 logger .info (f"Pytorch transforms applied to model: { self .model_name } " )
9189
92- def _offload_model_weights (self , offload_pt_weights ) -> None :
90+ def _offload_model_weights (self , offload_pt_weights ) -> bool :
9391 """Clear PyTorch model weights to reduce memory usage after ONNX export."""
9492
9593 if offload_pt_weights and not self ._is_weights_offloaded :
9694 try :
97- if self ._original_model_name is None :
98- self ._original_model_name = self .model_name
99-
100- cached_methods = {}
101- methods_to_cache = constants .CACHE_MODULES
102- self ._is_weights_offloaded = True
103- for method_name in methods_to_cache :
104- if hasattr (self .model , method_name ):
105- method = getattr (self .model , method_name )
106- if callable (method ):
107- cached_methods [method_name ] = method
108-
109- # Clear tensor storage and replace with empty shell
110- for param in self .model .parameters ():
111- if hasattr (param , "data" ) and hasattr (param .data , "storage" ):
112- param .data .storage ().resize_ (0 )
113-
114- for buffer in self .model .buffers ():
115- if hasattr (buffer , "data" ) and hasattr (buffer .data , "storage" ):
116- buffer .data .storage ().resize_ (0 )
117-
118- # Clear module dictionaries and hooks
119- for module in self .model .modules ():
120- if hasattr (module , "_parameters" ):
121- module ._parameters .clear ()
122- if hasattr (module , "_buffers" ):
123- module ._buffers .clear ()
124-
125- # Clear hooks
126- for hook_dict in [
127- getattr (module , "_forward_hooks" , {}),
128- getattr (module , "_forward_pre_hooks" , {}),
129- getattr (module , "_backward_hooks" , {}),
130- getattr (module , "_state_dict_hooks" , {}),
131- getattr (module , "_load_state_dict_pre_hooks" , {}),
132- ]:
133- hook_dict .clear ()
134-
135- # Replace with minimal shell for compatibility
136- class ModelShell :
137- def __init__ (self , config , original_class_name , cached_methods = None ):
138- self .config = config
139- self .qaic_config = None
140- self .device = torch .device ("meta" )
141- self ._cached_methods = cached_methods or {}
142- self ._original_class_name = original_class_name
143-
144- # Create a mock class with the original name
145- self ._mock_class = type (original_class_name , (), {})
146-
147- @property
148- def __class__ (self ):
149- """Override __class__ to return mock class with original name"""
150- return self ._mock_class
151-
152- def __getattr__ (self , name ):
153- if name in self ._cached_methods :
154- return self ._cached_methods [name ]
155- raise AttributeError (f"'ModelShell' object has no attribute '{ name } '" )
156-
157- def parameters (self ):
158- return iter ([])
159-
160- def named_parameters (self ):
161- return iter ([])
162-
163- def buffers (self ):
164- return iter ([])
165-
166- def named_buffers (self ):
167- return iter ([])
168-
169- def modules (self ):
170- return iter ([self ])
171-
172- def state_dict (self ):
173- return {}
174-
175- def to (self , device ):
176- return self
177-
178- def eval (self ):
179- return self
180-
181- config = getattr (self .model , "config" , None )
182- original_class_name = self .model .__class__ .__name__
183- self .model = ModelShell (config , original_class_name , cached_methods )
95+ meta_model = self .model .to ("meta" )
96+ del self .model
97+ gc .collect ()
98+ self .model = meta_model
18499 return True
185-
186100 except Exception as e :
187101 logger .warning (f"Weight clearing failed, continuing: { e } " )
188102 return False
@@ -359,11 +273,6 @@ def _export(
359273
360274 model = onnx .load (tmp_onnx_path , load_external_data = False )
361275 # Clear temporary references
362- example_inputs .clear ()
363- input_names .clear ()
364-
365- # Force garbage collection
366- gc .collect ()
367276 transform_kwargs = {
368277 "onnx_base_dir" : str (tmp_onnx_dir ),
369278 "model_name" : self .model_name ,
@@ -381,6 +290,8 @@ def _export(
381290 logger .info ("ONNX transforms applied" )
382291
383292 onnx .save (model , onnx_path )
293+ del model
294+ gc .collect ()
384295 logger .info ("Transformed ONNX saved" )
385296
386297 except Exception as e :
@@ -389,12 +300,6 @@ def _export(
389300
390301 finally :
391302 shutil .rmtree (tmp_onnx_dir , ignore_errors = True )
392- # Clear external data from memory and cache after all transforms and saving
393- # Make sure model exists before trying to clean it up
394- if "model" in locals ():
395- BaseOnnxTransform ._cleanup_external_data_and_cache (model )
396- BaseOnnxTransform ._cleanup_memory ()
397- logger .info ("Cleanup complete." )
398303
399304 if use_onnx_subfunctions :
400305 undo_torch_patches ()
0 commit comments