Merged
@@ -842,7 +842,7 @@ def loss_fn_wrap(Z, T, use_gpu, device):
args.print_freq = ld_nbatches
args.test_freq = 0

-del ld_model
+del(ld_model)

print(
"Saved at: epoch = {:d}/{:d}, batch = {:d}/{:d}, ntbatch = {:d}".format(
@@ -847,7 +847,7 @@ def loss_fn_wrap(Z, T, use_gpu, device):
args.print_freq = ld_nbatches
args.test_freq = 0

-del ld_model
+del(ld_model)

print(
"Saved at: epoch = {:d}/{:d}, batch = {:d}/{:d}, ntbatch = {:d}".format(
@@ -821,7 +821,7 @@ def run():
)
)
print("Testing state: accuracy = {:3.3f} %".format(ld_acc_test * 100))
-del ld_model
+del(ld_model)

ext_dist.barrier()
print("time/loss/accuracy (if enabled):")
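Side note on the repeated one-line change above: in Python, `del` is a statement, not a function, so `del(ld_model)` parses as `del` applied to a parenthesized name and behaves exactly like `del ld_model`. A self-contained check:

```python
# `del` is a statement: `del (x)` is the same statement as `del x`;
# the parentheses are only grouping around the name.
x = 1
del (x)
try:
    x  # evaluating the name now fails
except NameError:
    print("x was unbound by del")
```

Either spelling drops the reference to the loaded model so it can be garbage-collected before the next phase.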
10 changes: 9 additions & 1 deletion neural_compressor/adaptor/pytorch.py
@@ -2584,7 +2584,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
self.q_dataloader.batch(batch_size)
logger.info('Recovery `calibration.dataloader.batchsize` {} according \
to config.yaml' .format(batch_size))
-del init_model
+del(init_model)
with open(self.ipex_config_path, 'r') as f:
self.cfgs = json.load(f)
if self.version.release < Version("1.12.0").release:
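The version gate in the surrounding context leans on packaging's `Version.release`, which compares only the numeric release tuple and ignores pre-release or local build tags such as `+cpu` that `torch.__version__` often carries. A small illustration (the version string here is made up):

```python
from packaging.version import Version

torch_version = Version("1.11.0+cpu")        # local tag "+cpu" is PEP 440-valid
assert torch_version.release == (1, 11, 0)   # `.release` drops the local tag
if torch_version.release < Version("1.12.0").release:
    print("use the pre-1.12 IPEX config handling")
```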
@@ -2776,6 +2776,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx
try:
q_model = copy.deepcopy(model)
+q_model.fp32_model = model.fp32_model
except Exception as e: # pragma: no cover
logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(
repr(e)))
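The hunk above follows a copy-or-fall-back pattern: try `copy.deepcopy`, and if the model holds un-picklable state (locks, open handles, some CUDA objects), log a warning and quantize the original object in place. A standalone sketch of the same pattern; `deepcopy_or_inplace` and the logger setup are illustrative, not part of the patch:

```python
import copy
import logging

logger = logging.getLogger(__name__)

def deepcopy_or_inplace(model):
    """Return a deep copy of `model`, or `model` itself if copying fails.

    Hypothetical helper mirroring the try/except in `quantize`: when
    copy.deepcopy raises, the caller proceeds on the original object
    ("inplace") instead of aborting the run.
    """
    try:
        return copy.deepcopy(model)
    except Exception as e:
        logger.warning("Fail to deep copy the model due to {}, "
                       "inplace is used now.".format(repr(e)))
        return model
```

The added `q_model.fp32_model = model.fp32_model` line then re-points the copied wrapper at the original fp32 reference, presumably so a successful deepcopy does not leave two full copies of the float model in memory.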
@@ -2983,6 +2984,13 @@ def _pre_hook_for_qat(self, dataloader=None):
# so set it to None.
example_inputs = None

+# For export API, deepcopy fp32_model
+try:
+    self.model.fp32_model = copy.deepcopy(self.model.fp32_model)
+except Exception as e: # pragma: no cover
+    logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(
+        repr(e)))
+
if self.sub_module_list is None:
if self.version.release >= Version("1.13.0").release: # pragma: no cover
# pylint: disable=E1123
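This is where the fp32 snapshot is actually taken now: since `__init__` in torch_model.py (below) no longer copies the incoming model, the deepcopy happens lazily, right before QAT preparation mutates the module in place. A minimal sketch of why that ordering matters, assuming the legacy `torch.quantization` namespace (newer releases alias it to `torch.ao.quantization`) and a toy model:

```python
import copy
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
model.train()  # prepare_qat expects a model in training mode

# Snapshot the float weights before preparation rewrites the module
# in place (fake-quant and observer modules get inserted).
fp32_model = copy.deepcopy(model)

model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# `model` now carries QAT wrappers; `fp32_model` is still the clean copy
# that an export API can hand back as the original full-precision graph.
```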
17 changes: 11 additions & 6 deletions neural_compressor/model/torch_model.py
@@ -45,12 +45,7 @@ def __init__(self, model, **kwargs):
self.q_config = None
self._workspace_path = ''
self.is_quantized = False
-try:
-    self.fp32_model = copy.deepcopy(model)
-except Exception as e: # pragma: no cover
-    logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(
-        repr(e)))
-    self.fp32_model = model
+self.fp32_model = model
self.kwargs = kwargs if kwargs else None

def __repr__(self):
@@ -93,6 +88,16 @@ def model(self, model):
""" Setter to model """
self._model = model

+@property
+def fp32_model(self):
+    """ Getter to fp32 model """
+    return self._fp32_model
+
+@fp32_model.setter
+def fp32_model(self, fp32_model):
+    """ Setter to fp32 model """
+    self._fp32_model = fp32_model
+
def register_forward_pre_hook(self):
self.handles.append(
self._model.register_forward_pre_hook(self.generate_forward_pre_hook()))
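Net effect of the torch_model.py changes: `__init__` stops eagerly deep-copying the incoming model and simply shares the reference, while the new property pair exposes `fp32_model` as a real attribute that callers (like the adaptor code above) can later replace with an actual copy when isolation is needed. A stripped-down sketch of that shape; `ModelWrapper` is an illustrative stand-in, not the real class:

```python
import copy

class ModelWrapper:
    """Illustrative stand-in for the fp32_model handling in PyTorchModel."""

    def __init__(self, model):
        self._model = model
        # Cheap by default: share the reference instead of deep-copying
        # up front; callers that need isolation overwrite it later.
        self.fp32_model = model

    @property
    def fp32_model(self):
        return self._fp32_model

    @fp32_model.setter
    def fp32_model(self, fp32_model):
        self._fp32_model = fp32_model

wrapped = ModelWrapper(object())
assert wrapped.fp32_model is wrapped._model         # shared until copied
wrapped.fp32_model = copy.deepcopy(wrapped._model)  # now isolated
```

This trades the old fail-safe copy at construction time for a lighter default, at the cost of callers having to remember to copy before they mutate.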