Skip to content

Commit

Permalink
ConstantLayerAttributes
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Feb 6, 2025
1 parent 57959bc commit da5f64b
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 30 deletions.
35 changes: 22 additions & 13 deletions nncf/experimental/torch2/function_hook/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@
om.PTDepthwiseConv3dSubtype,
)

CONV_TRANSPOSE_METATYPES = (
om.PTConvTranspose1dMetatype,
om.PTConvTranspose2dMetatype,
om.PTConvTranspose3dMetatype,
)


class ExtractedFunc(nn.Module):
"""
Expand All @@ -58,9 +52,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.fn(x, **self.kwargs)


def apply_args_defaults(
def apply_args_to_kwargs(
args: Sequence[Any], kwargs: Dict[str, Any], indexed_args: List[Tuple[int, str]]
) -> Dict[str, Any]:
"""
Applies the given arguments and keyword arguments to a dictionary of keyword arguments.
:param args: The positional arguments.
:param kwargs: The keyword arguments.
:param indexed_args: The list of pairs of indexes and names.
:return: A dictionary of keyword arguments with the applied arguments and keyword arguments.
"""
args_dict: Dict[str, Any] = dict()
for idx, arg_name in indexed_args:
if idx < len(args):
Expand Down Expand Up @@ -91,8 +94,10 @@ def extract_bn(model: nn.Module, graph: PTNNCFGraph, node: NNCFNode) -> Extracte
running_mean = get_const_data_on_port(model, graph, node, 3)
running_var = get_const_data_on_port(model, graph, node, 4)

bn_kwargs = apply_args_defaults(
layer_attr.op_args, layer_attr.op_kwargs, [(6, "momentum"), (7, "eps"), (8, "cudnn_enabled")]
bn_kwargs = apply_args_to_kwargs(
layer_attr.op_args,
layer_attr.op_kwargs,
[(6, "momentum"), (7, "eps"), (8, "cudnn_enabled")],
)
bn_kwargs["weight"] = weight
bn_kwargs["bias"] = bias
Expand Down Expand Up @@ -120,7 +125,8 @@ def extract_conv(
"""
weight_node = get_const_node(input_node, 1, graph)
if weight_node is None:
raise nncf.InternalError(f"Weight node not found for {input_node}")
msg = "Weight node not found for {input_node}"
raise nncf.InternalError(msg)
weight = get_const_data(weight_node, model)

hook_storage = get_hook_storage(model)
Expand All @@ -138,8 +144,10 @@ def extract_conv(
msg = f"Expected PT2OpLayerAttributes for input_node.layer_attributes, actual: {type(layer_attrs)}"
raise nncf.InternalError(msg)

conv_kwargs = apply_args_defaults(
layer_attrs.op_args, layer_attrs.op_kwargs, [(3, "stride"), (4, "padding"), (5, "dilation"), (6, "groups")]
conv_kwargs = apply_args_to_kwargs(
layer_attrs.op_args,
layer_attrs.op_kwargs,
[(3, "stride"), (4, "padding"), (5, "dilation"), (6, "groups")],
)
conv_kwargs["weight"] = weight
conv_kwargs["bias"] = bias
Expand Down Expand Up @@ -178,7 +186,8 @@ def extract_model(
"""

if len(input_nodes) != 1 or len(output_nodes) != 1:
raise nncf.InternalError("input_nodes and output_nodes should contain only one node.")
msg = "input_nodes and output_nodes should contain only one node."
raise nncf.InternalError(msg)

input_node = graph.get_node_by_name(input_nodes[0])
output_node = graph.get_node_by_name(output_nodes[0])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.layer_attributes import ConstantLayerAttributes
from nncf.common.graph.layer_attributes import Dtype
from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph
from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta
Expand Down Expand Up @@ -157,7 +158,8 @@ def get_layer_attributes(
if isinstance(meta, FunctionMeta):
constant_port_ids = get_constant_port_ids(nx_graph, node)
return PT2OpLayerAttributes(meta.func, meta.args, meta.kwargs, constant_port_ids)

if isinstance(meta, ConstMeta):
return ConstantLayerAttributes(meta.name_in_model, meta.shape)
return None


Expand Down
17 changes: 3 additions & 14 deletions nncf/torch/model_graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
from nncf.torch.dynamic_graph.context import PreHookId
from nncf.torch.external_hook import ExternalOpCallHook
from nncf.torch.graph import operator_metatypes as om
Expand Down Expand Up @@ -126,10 +125,7 @@ def get_const_data(const_node: NNCFNode, model: nn.Module) -> torch.Tensor:
:param model: The NNCFNetwork object.
:return: A torch.Tensor object containing the constant value.
"""
if is_experimental_torch_tracing_enabled():
const_name = const_node.layer_name
else:
const_name = const_node.layer_attributes.name
const_name = const_node.layer_attributes.name
module_name, const_attr_name = split_const_name(const_name)
module = get_module_by_name(module_name, model)
data = getattr(module, const_attr_name)
Expand Down Expand Up @@ -265,10 +261,7 @@ def set_const_data(data: torch.Tensor, const_node: NNCFNode, model: NNCFNetwork)
:param const_node: The NNCF node representing the constant data.
:param model: The NNCF network model.
"""
if is_experimental_torch_tracing_enabled():
const_name = const_node.layer_name
else:
const_name = const_node.layer_attributes.name
const_name = const_node.layer_attributes.name
module_name, const_attr_name = split_const_name(const_name)
module = get_module_by_name(module_name, model)
const = getattr(module, const_attr_name)
Expand All @@ -289,15 +282,11 @@ def set_const_data_to_port_id(
:param const_port_id: The input port id of the node that receives the constant.
:param model: The NNCF network containing the module to be modified.
"""
# graph = model.nncf.get_graph()
const_node = get_const_node(node, port_id, graph)
if const_node is None:
msg = f"No found node with constant for {node.node_name} on {port_id} port"
raise nncf.InternalError(msg)
if is_experimental_torch_tracing_enabled():
const_name = const_node.layer_name
else:
const_name = const_node.layer_attributes.name
const_name = const_node.layer_attributes.name
module_name, const_attr_name = split_const_name(const_name)
module = get_module_by_name(module_name, model)
const = getattr(module, const_attr_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def check_bias(model: GraphModelWrapper, ref_bias: list):
# TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}"
return
raise ValueError("Not found node with bias")
msg = "Not found node with bias"
raise ValueError(msg)


@pytest.mark.cuda
Expand Down Expand Up @@ -87,4 +88,5 @@ def check_bias(model: GraphModelWrapper, ref_bias: list):
# TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}"
return
raise ValueError("Not found node with bias")
msg = "Not found node with bias"
raise ValueError(msg)

0 comments on commit da5f64b

Please sign in to comment.