Skip to content

Commit ff82942

Browse files
Refactor logic for deleting PyTorch weights
Signed-off-by: abhishek-singh591 <sabhis@qti.qualcomm.com>
1 parent 9066b0c commit ff82942

File tree

7 files changed

+169
-297
lines changed

7 files changed

+169
-297
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 7 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
7676
# Flag for checking if weights are offloaded
7777
self._is_weights_offloaded: bool = False
7878

79-
self._original_model_name: Optional[str] = None
80-
8179
# Apply the transformations
8280
any_transformed = False
8381
for transform in self._pytorch_transforms:
@@ -89,100 +87,16 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
8987
else:
9088
logger.info(f"Pytorch transforms applied to model: {self.model_name}")
9189

92-
def _offload_model_weights(self, offload_pt_weights) -> None:
90+
def _offload_model_weights(self, offload_pt_weights) -> bool:
9391
"""Clear PyTorch model weights to reduce memory usage after ONNX export."""
9492

9593
if offload_pt_weights and not self._is_weights_offloaded:
9694
try:
97-
if self._original_model_name is None:
98-
self._original_model_name = self.model_name
99-
100-
cached_methods = {}
101-
methods_to_cache = constants.CACHE_MODULES
102-
self._is_weights_offloaded = True
103-
for method_name in methods_to_cache:
104-
if hasattr(self.model, method_name):
105-
method = getattr(self.model, method_name)
106-
if callable(method):
107-
cached_methods[method_name] = method
108-
109-
# Clear tensor storage and replace with empty shell
110-
for param in self.model.parameters():
111-
if hasattr(param, "data") and hasattr(param.data, "storage"):
112-
param.data.storage().resize_(0)
113-
114-
for buffer in self.model.buffers():
115-
if hasattr(buffer, "data") and hasattr(buffer.data, "storage"):
116-
buffer.data.storage().resize_(0)
117-
118-
# Clear module dictionaries and hooks
119-
for module in self.model.modules():
120-
if hasattr(module, "_parameters"):
121-
module._parameters.clear()
122-
if hasattr(module, "_buffers"):
123-
module._buffers.clear()
124-
125-
# Clear hooks
126-
for hook_dict in [
127-
getattr(module, "_forward_hooks", {}),
128-
getattr(module, "_forward_pre_hooks", {}),
129-
getattr(module, "_backward_hooks", {}),
130-
getattr(module, "_state_dict_hooks", {}),
131-
getattr(module, "_load_state_dict_pre_hooks", {}),
132-
]:
133-
hook_dict.clear()
134-
135-
# Replace with minimal shell for compatibility
136-
class ModelShell:
137-
def __init__(self, config, original_class_name, cached_methods=None):
138-
self.config = config
139-
self.qaic_config = None
140-
self.device = torch.device("meta")
141-
self._cached_methods = cached_methods or {}
142-
self._original_class_name = original_class_name
143-
144-
# Create a mock class with the original name
145-
self._mock_class = type(original_class_name, (), {})
146-
147-
@property
148-
def __class__(self):
149-
"""Override __class__ to return mock class with original name"""
150-
return self._mock_class
151-
152-
def __getattr__(self, name):
153-
if name in self._cached_methods:
154-
return self._cached_methods[name]
155-
raise AttributeError(f"'ModelShell' object has no attribute '{name}'")
156-
157-
def parameters(self):
158-
return iter([])
159-
160-
def named_parameters(self):
161-
return iter([])
162-
163-
def buffers(self):
164-
return iter([])
165-
166-
def named_buffers(self):
167-
return iter([])
168-
169-
def modules(self):
170-
return iter([self])
171-
172-
def state_dict(self):
173-
return {}
174-
175-
def to(self, device):
176-
return self
177-
178-
def eval(self):
179-
return self
180-
181-
config = getattr(self.model, "config", None)
182-
original_class_name = self.model.__class__.__name__
183-
self.model = ModelShell(config, original_class_name, cached_methods)
95+
meta_model = self.model.to("meta")
96+
del self.model
97+
gc.collect()
98+
self.model = meta_model
18499
return True
185-
186100
except Exception as e:
187101
logger.warning(f"Weight clearing failed, continuing: {e}")
188102
return False
@@ -359,11 +273,6 @@ def _export(
359273

360274
model = onnx.load(tmp_onnx_path, load_external_data=False)
361275
# Clear temporary references
362-
example_inputs.clear()
363-
input_names.clear()
364-
365-
# Force garbage collection
366-
gc.collect()
367276
transform_kwargs = {
368277
"onnx_base_dir": str(tmp_onnx_dir),
369278
"model_name": self.model_name,
@@ -381,6 +290,8 @@ def _export(
381290
logger.info("ONNX transforms applied")
382291

383292
onnx.save(model, onnx_path)
293+
del model
294+
gc.collect()
384295
logger.info("Transformed ONNX saved")
385296

386297
except Exception as e:
@@ -389,12 +300,6 @@ def _export(
389300

390301
finally:
391302
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
392-
# Clear external data from memory and cache after all transforms and saving
393-
# Make sure model exists before trying to clean it up
394-
if "model" in locals():
395-
BaseOnnxTransform._cleanup_external_data_and_cache(model)
396-
BaseOnnxTransform._cleanup_memory()
397-
logger.info("Cleanup complete.")
398303

399304
if use_onnx_subfunctions:
400305
undo_torch_patches()

0 commit comments

Comments (0)