Commit

Correct dtype handling for compressed torch.nn.Parameter (openvinotoolkit#2477)

### Changes

Correct dtype handling for compressed torch.nn.Parameter

### Reason for changes

Weight compression for `CompVis/ldm-super-resolution-4x-openimages`

### Related tickets

PR: openvinotoolkit/openvino.genai#232

### Tests

test_get_dtype_attribute_of_parameter
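
For reference, a minimal sketch of the scenario the new test covers (it mirrors `test_get_dtype_attribute_of_parameter` in the diff below; the toy model and shapes are illustrative, not taken from the affected pipeline): after INT8 weight compression the parameter data is stored as `torch.uint8`, and reading `.dtype` on the traced parameter should keep reporting that storage dtype both before and after a forward pass.

```python
import torch

from nncf import compress_weights
from nncf.torch import wrap_model


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # float32 weight that weight compression replaces with uint8 storage
        self.weight = torch.nn.Parameter(torch.ones(3, 3, dtype=torch.float32))

    def forward(self, x):
        # reading .dtype on the (possibly compressed) parameter is the code path being fixed
        return x.to(self.weight.dtype) @ self.weight


dummy_input = torch.randint(0, 10, (3, 3))
wrapped = wrap_model(TinyModel(), example_input=dummy_input, trace_parameters=True)
compressed = compress_weights(wrapped)

assert compressed.weight.dtype == torch.uint8  # storage dtype after compression
compressed(dummy_input)                        # a forward pass must not change it
assert compressed.weight.dtype == torch.uint8
```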
alexsu52 authored Feb 20, 2024
1 parent c8000ca commit b5e0a05
Showing 5 changed files with 92 additions and 54 deletions.
1 change: 1 addition & 0 deletions nncf/torch/__init__.py
@@ -48,6 +48,7 @@
 # listed below for importing convenience

 from nncf.torch.model_creation import create_compressed_model
+from nncf.torch.model_creation import is_wrapped_model
 from nncf.torch.model_creation import wrap_model
 from nncf.torch.checkpoint_loading import load_state
 from nncf.torch.initialization import register_default_init_args
114 changes: 60 additions & 54 deletions nncf/torch/dynamic_graph/patch_pytorch.py
@@ -12,7 +12,7 @@
 import functools
 import inspect
 from contextlib import contextmanager
-from typing import List, Tuple
+from typing import List

 import torch
 import torch.utils.cpp_extension
@@ -33,13 +33,15 @@
 from nncf.torch.dynamic_graph.wrappers import wrap_operator


-def get_namespaces_to_patch(namespace_target: NamespaceTarget) -> Tuple[object, ...]:
+def get_namespaces_to_patch(namespace_target: NamespaceTarget) -> object:
     if namespace_target == NamespaceTarget.TORCH_NN_FUNCTIONAL:
-        return (torch.nn.functional,)
+        return torch.nn.functional
     if namespace_target == NamespaceTarget.TORCH_TENSOR:
-        return (TracedTensor, TracedParameter)
+        return TracedTensor
+    if namespace_target == NamespaceTarget.TORCH_NN_PARAMETER:
+        return TracedParameter
     if namespace_target == NamespaceTarget.TORCH:
-        return (torch,)
+        return torch
     raise nncf.ValidationError("{} namespace wasn't found in {}".format(namespace_target, NamespaceTarget))


@@ -48,6 +50,8 @@ def get_namespace_to_extract_functions_from(namespace_target: NamespaceTarget) -
         return torch.nn.functional
     if namespace_target == NamespaceTarget.TORCH_TENSOR:
         return torch.Tensor
+    if namespace_target == NamespaceTarget.TORCH_NN_PARAMETER:
+        return torch.nn.Parameter
     if namespace_target == NamespaceTarget.TORCH:
         return torch._C._VariableFunctions
     raise nncf.ValidationError("{} namespace wasn't found in {}".format(namespace_target, NamespaceTarget))
@@ -120,52 +124,54 @@ class FunctionsToPatchWithoutTracing:


 class MagicFunctionsToPatch:
+    TENSOR_MAGIC_FUNCTIONS = [
+        "__abs__",
+        "__add__",
+        "__and__",
+        "__div__",
+        "__eq__",
+        "__floordiv__",
+        "__ge__",
+        "__getitem__",
+        "__gt__",
+        "__iadd__",
+        "__iand__",
+        "__idiv__",
+        "__ifloordiv__",
+        "__imul__",
+        "__invert__",
+        "__ior__",
+        "__ipow__",
+        "__isub__",
+        "__itruediv__",
+        "__ixor__",
+        "__le__",
+        "__lt__",
+        "__matmul__",
+        "__mod__",
+        "__mul__",
+        "__ne__",
+        "__neg__",
+        "__or__",
+        "__pow__",
+        "__radd__",
+        "__rand__",
+        "__rdiv__",
+        "__rfloordiv__",
+        "__rmatmul__",
+        "__rmul__",
+        "__ror__",
+        "__rpow__",
+        "__rsub__",
+        "__rtruediv__",
+        "__rxor__",
+        "__sub__",
+        "__truediv__",
+        "__xor__",
+    ]
     MAGIC_FUNCTIONS_TO_PATCH = {
-        NamespaceTarget.TORCH_TENSOR: [
-            "__abs__",
-            "__add__",
-            "__and__",
-            "__div__",
-            "__eq__",
-            "__floordiv__",
-            "__ge__",
-            "__getitem__",
-            "__gt__",
-            "__iadd__",
-            "__iand__",
-            "__idiv__",
-            "__ifloordiv__",
-            "__imul__",
-            "__invert__",
-            "__ior__",
-            "__ipow__",
-            "__isub__",
-            "__itruediv__",
-            "__ixor__",
-            "__le__",
-            "__lt__",
-            "__matmul__",
-            "__mod__",
-            "__mul__",
-            "__ne__",
-            "__neg__",
-            "__or__",
-            "__pow__",
-            "__radd__",
-            "__rand__",
-            "__rdiv__",
-            "__rfloordiv__",
-            "__rmatmul__",
-            "__rmul__",
-            "__ror__",
-            "__rpow__",
-            "__rsub__",
-            "__rtruediv__",
-            "__rxor__",
-            "__sub__",
-            "__truediv__",
-            "__xor__",
-        ]
+        NamespaceTarget.TORCH_TENSOR: TENSOR_MAGIC_FUNCTIONS,
+        NamespaceTarget.TORCH_NN_PARAMETER: TENSOR_MAGIC_FUNCTIONS + ["get_dtype"],
     }


@@ -369,17 +375,17 @@ def patch_torch_operators():
     for namespace, function_names in functions_to_patch.items():
         for function_name in function_names:
             op_info = PatchedOperatorInfo(function_name, namespace)
-            for patched_namespace in get_namespaces_to_patch(namespace):
-                patch_namespace_opname(patched_namespace, op_info)
+            patched_namespace = get_namespaces_to_patch(namespace)
+            patch_namespace_opname(patched_namespace, op_info)

     # Patch operators without tracing so that
     # both they and any internal calls to otherwise traced functions do not appear into the model graph.

     for namespace, function_names in functions_to_patch_without_tracing.items():
         for function_name in function_names:
             op_info = PatchedOperatorInfo(function_name, namespace, skip_trace=True)
-            for patched_namespace in get_namespaces_to_patch(namespace):
-                patch_namespace_opname(patched_namespace, op_info)
+            patched_namespace = get_namespaces_to_patch(namespace)
+            patch_namespace_opname(patched_namespace, op_info)

     # Patch __repr__ twice in 'torch.Tensor' and 'TracedTensor'.
     # This is done to not add operations behind print() operator for the both TracedTensor and torch.Tensor.
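For readers unfamiliar with the patching machinery, here is a schematic sketch of the pattern the loop above applies (now once per namespace object instead of per tuple element): replace a named attribute on a namespace object — a module, or a class such as `TracedParameter` — with a wrapper that records the call before delegating to the original. The helper below is invented for illustration and is not NNCF's actual `patch_namespace_opname`/`wrap_operator` implementation.

```python
from typing import Callable


def patch_opname(namespace: object, op_name: str, recorder: Callable[[str], None]) -> None:
    """Replace `namespace.op_name` with a wrapper that records the call, then delegates."""
    original = getattr(namespace, op_name)

    def wrapped(*args, **kwargs):
        recorder(op_name)  # e.g. add a node to the traced graph
        return original(*args, **kwargs)

    setattr(namespace, op_name, wrapped)


class Point:
    def __init__(self, x: float) -> None:
        self.x = x

    def __add__(self, other: "Point") -> "Point":
        return Point(self.x + other.x)


calls = []
patch_opname(Point, "__add__", calls.append)
print((Point(1.0) + Point(2.0)).x)  # 3.0
print(calls)                        # ['__add__'] -- the magic-function call was intercepted
```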
1 change: 1 addition & 0 deletions nncf/torch/dynamic_graph/structs.py
@@ -20,6 +20,7 @@ class NamespaceTarget(Enum):

     TORCH_NN_FUNCTIONAL = "torch.nn.functional"
     TORCH_TENSOR = "torch.tensor"
+    TORCH_NN_PARAMETER = "torch.nn.parameter"
     TORCH = "torch"
     EXTERNAL = "external_function"

9 changes: 9 additions & 0 deletions nncf/torch/dynamic_graph/trace_tensor.py
@@ -152,6 +152,15 @@ def name(self):
     def is_reused(self):
         return self.tracing_attrs[TracedParameter.IS_REUSED]

+    def get_dtype(self):
+        # Type of self is TracedParameter or TracedTensor
+        return super(self.__class__, self).__getattribute__("dtype")
+
+    def __getattribute__(self, name):
+        if name == "dtype":
+            return self.get_dtype()
+        return super().__getattribute__(name)
+
     @staticmethod
     def from_torch_parameter(tensor: torch.nn.Parameter, name: str, is_reused: bool) -> "TracedParameter":
         """
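A standalone sketch of the idea behind the `get_dtype`/`__getattribute__` pair added above (hypothetical class, not NNCF code): a plain attribute read cannot be patched by name, but once `.dtype` is routed through a regular method, that method can be wrapped like any other patched operator — matching the `get_dtype` entry added to `MAGIC_FUNCTIONS_TO_PATCH` for `TORCH_NN_PARAMETER`.

```python
import torch


class RoutedParameter(torch.nn.Parameter):
    """Routes .dtype reads through get_dtype() so the method can be wrapped/patched."""

    def get_dtype(self) -> torch.dtype:
        # Bypass our own __getattribute__ to read the real storage dtype.
        return super(RoutedParameter, self).__getattribute__("dtype")

    def __getattribute__(self, name):
        if name == "dtype":
            return self.get_dtype()
        return super().__getattribute__(name)


p = RoutedParameter(torch.ones(2, 2, dtype=torch.float32), requires_grad=False)
print(p.dtype)  # torch.float32, resolved via get_dtype()
```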
21 changes: 21 additions & 0 deletions tests/torch/ptq/test_weights_compression.py
@@ -220,3 +220,24 @@ def test_raise_error_with_not_int8_asym(mode):
     wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True)
     with pytest.raises(AttributeError):
         compress_weights(wrapped_model, mode=mode)
+
+
+class DTypeModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(size=(3, 3), dtype=torch.float32))
+
+    def forward(self, x):
+        x = x.to(self.weight.dtype)
+        x = x @ self.weight
+        return x
+
+
+def test_get_dtype_attribute_of_parameter():
+    model = DTypeModel()
+    dummy_input = torch.randint(0, 10, [3, 3])
+    wrapped_model = wrap_model(model, example_input=dummy_input, trace_parameters=True)
+    compressed_model = compress_weights(wrapped_model)
+    assert compressed_model.weight.dtype == torch.uint8
+    compressed_model(dummy_input)
+    assert compressed_model.weight.dtype == torch.uint8
