[IR][fix] Save value info for initializers #1552

Open
wants to merge 10 commits into main
57 changes: 39 additions & 18 deletions onnxscript/ir/serde.py
@@ -567,38 +567,55 @@ def _deserialize_graph(
for info, value in zip(proto.input, inputs):
deserialize_value_info_proto(info, value)

# Build the value info dictionary to allow for quick lookup for this graph scope
value_info = {info.name: info for info in proto.value_info}

# Initialize the values dictionary for this graph scope with the inputs and initializers
values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]

# Enter the graph scope by pushing the values for this scope to the stack
scoped_values.append(values)

initializer_values = []
for tensor in initializer_tensors:
if tensor.name in values:
for i, tensor in enumerate(initializer_tensors):
initializer_name = tensor.name
if not initializer_name:
logger.warning(
"Initializer tensor must have a name but the %s-th initializer does not. Skipping this initializer.",
i,
Contributor comment: Should we indicate its index? Like f"initializer[{i}]"
)
continue
if initializer_name in values:
# The initializer is for an input
initializer_value = values[tensor.name]
initializer_value = values[initializer_name]
initializer_value.const_value = tensor
else:
# The initializer is for some other value. Create this value first
initializer_value = _core.Value(
None,
index=None,
name=tensor.name,
# TODO(justinchuby): Fix type hinting for shape and dtype
shape=tensor.shape, # type: ignore
type=_core.TensorType(tensor.dtype),
name=initializer_name,
# Do not include shape or type as we need to respect the ONNX file
# if the shape or type is not provided as ValueInfoProto
# The shape/type information will be filled in in the subsequent ValueInfoProto
# deserialization step
const_value=tensor,
)
values[tensor.name] = initializer_value # type: ignore[index]
if initializer_name in value_info:
# This is where we fill in the shape and type information for the initializer
deserialize_value_info_proto(value_info[initializer_name], initializer_value)
values[initializer_name] = initializer_value # type: ignore[index]
initializer_values.append(initializer_value)

# Add ValueInfos for this graph scope
value_info = {info.name: info for info in proto.value_info}

# Deserialize nodes with all known values
nodes = [_deserialize_node(node, scoped_values, value_info) for node in proto.node]

# Fill in values for graph outputs
outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output]

# Exit the graph scope by popping the values for this scope from the stack
scoped_values.pop()

return _core.Graph(
inputs,
outputs,
@@ -1159,17 +1176,21 @@ def serialize_graph_into(
graph_proto.doc_string = from_.doc_string
for input_ in from_.inputs:
serialize_value_into(graph_proto.input.add(), input_)
input_names = {input_.name for input_ in from_.inputs}
# TODO(justinchuby): Support sparse_initializer
for initializer in from_.initializers.values():
if initializer.const_value is None:
for value in from_.initializers.values():
if _should_create_value_info_for_value(value) and value.name not in input_names:
# Serialize information about all initializers into value_info,
# except for those that are also graph inputs
serialize_value_into(graph_proto.value_info.add(), value)
if value.const_value is None:
# Skip initializers without constant values
logger.warning(
"Initializer '%s' does not have a constant value set.", initializer.name
)
logger.warning("Initializer '%s' does not have a constant value set.", value.name)
continue
# Make sure the tensor's name is the same as the value's name
initializer.const_value.name = initializer.name
serialize_tensor_into(graph_proto.initializer.add(), from_=initializer.const_value)
# TODO(#1554): Handle tensor alias better
value.const_value.name = value.name
serialize_tensor_into(graph_proto.initializer.add(), from_=value.const_value)
for node in from_:
serialize_node_into(graph_proto.node.add(), from_=node)
for node_output in node.outputs:
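For context, here is a minimal sketch of the round-trip behavior the serde.py changes aim for, written with the public onnx helpers rather than the onnxscript serializer itself; the graph, names, and shapes are made up for illustration. An initializer that is not a graph input gets a matching ValueInfoProto, so its shape and type are kept in the model and can be restored when the model is deserialized.

```python
# Minimal sketch with plain onnx protos; illustrates the intent of the change,
# not the onnxscript implementation.
import numpy as np
from onnx import TensorProto, helper, numpy_helper

# Hypothetical initializer that is NOT listed as a graph input.
weight = numpy_helper.from_array(np.ones((4, 3), dtype=np.float32), name="weight")

graph = helper.make_graph(
    nodes=[helper.make_node("Identity", ["weight"], ["out"])],
    name="g",
    inputs=[],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [4, 3])],
    initializer=[weight],
    # The serializer in this PR also emits value_info for such initializers,
    # so shape/type annotations travel alongside the raw tensor bytes.
    value_info=[helper.make_tensor_value_info("weight", TensorProto.FLOAT, [4, 3])],
)
model = helper.make_model(graph)

# On the deserialization side, the PR builds this lookup and, when an entry
# exists for an initializer, uses it to fill in the Value's shape and type
# instead of deriving them from the tensor.
value_info = {info.name: info for info in model.graph.value_info}
assert "weight" in value_info
```

The deserializer thus respects the shape/type recorded in the ONNX file's value_info rather than reconstructing it from the tensor, matching the comments in the diff above.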
35 changes: 18 additions & 17 deletions onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
@@ -7,22 +7,23 @@
import numpy as np
import onnx

from onnxscript import ir
from onnxscript.rewriter import _ir_utils, pattern

torch_module_op = pattern.torch_module_op

logger = logging.getLogger(__name__)


def check_if_simulated_instance_norm_is_used(
def _simulated_instance_norm(
context,
input_x,
adjusted_input_shape,
original_input_shape,
weight_for_norm,
bias_for_norm,
weight_full,
bias_full,
input_x: ir.Value,
adjusted_input_shape: ir.Value,
original_input_shape: ir.Value,
weight_for_norm: ir.Value,
bias_for_norm: ir.Value,
weight_full: ir.Value,
bias_full: ir.Value,
**_,
) -> bool:
"""Check if the simulated instance normalization is used.
@@ -40,16 +41,16 @@ def check_if_simulated_instance_norm_is_used(
6. original_input_shape is the same as input_x shape.

Returns:
bool: True if the simulated instance normalization is used, False otherwise.
True if the simulated instance normalization is used, False otherwise.
"""
weight_for_norm_prop = _ir_utils.propagate_const_value(weight_for_norm)
weight_for_norm_const_value = weight_for_norm_prop.const_value
_ir_utils.propagate_const_value(weight_for_norm)
weight_for_norm_const_value = weight_for_norm.const_value
if weight_for_norm_const_value is None:
return False
weight_for_norm = weight_for_norm_const_value.numpy()

bias_for_norm_prop = _ir_utils.propagate_const_value(bias_for_norm)
bias_for_norm_const_value = bias_for_norm_prop.const_value
_ir_utils.propagate_const_value(bias_for_norm)
bias_for_norm_const_value = bias_for_norm.const_value
if bias_for_norm_const_value is None:
return False
bias_for_norm = bias_for_norm_const_value.numpy()
@@ -59,7 +60,7 @@ def check_if_simulated_instance_norm_is_used(
if not np.all(bias_for_norm == 0):
return False

input_rank_minus_one = len(input_x.shape) - 1
input_rank_minus_one = input_x.shape.rank() - 1
weight_full_rank = len(weight_full.shape)
bias_full_rank = len(bias_full.shape)
if weight_full_rank != input_rank_minus_one or bias_full_rank != input_rank_minus_one:
@@ -76,7 +77,7 @@ def check_if_simulated_instance_norm_is_used(
if not all(dim == 1 for dim in bias_full_shape[1:]):
return False

adjusted_input_shape = _ir_utils.propagate_const_value(adjusted_input_shape)
_ir_utils.propagate_const_value(adjusted_input_shape)
adjusted_input_shape_const_value = adjusted_input_shape.const_value

g = weight_for_norm.shape[0]
@@ -87,7 +88,7 @@ def check_if_simulated_instance_norm_is_used(
return False

# NOTE: Restrict the rule to only support constant shape
original_input_shape = _ir_utils.propagate_const_value(original_input_shape)
_ir_utils.propagate_const_value(original_input_shape)
original_input_shape_const_value = original_input_shape.const_value
if (
original_input_shape_const_value is None
@@ -151,7 +152,7 @@ def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, ep
instance_norm_to_group_norm_rule = pattern.RewriteRule(
instance_simulates_group_normalization_pattern,
group_normalization,
check_if_simulated_instance_norm_is_used,
_simulated_instance_norm,
)

# NOTE: instance_norm_to_group_norm_rule is subset of instance_norm_to_group_norm_with_silu_rule,
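One behavioral detail the diff relies on: the checker now calls `_ir_utils.propagate_const_value` purely for its side effect of populating `const_value` on the passed `ir.Value`, and then reads the constant off the original object, instead of rebinding the local name to the helper's return value. Below is a toy sketch of that before/after call pattern; `FakeValue` and the simulated helper are stand-ins for illustration, not the onnxscript API.

```python
# Toy illustration only: FakeValue stands in for onnxscript's ir.Value, and this
# propagate_const_value simulates the "populate in place" behavior.
from __future__ import annotations

from dataclasses import dataclass

import numpy as np


@dataclass
class FakeValue:
    name: str
    const_value: np.ndarray | None = None


def propagate_const_value(value: FakeValue) -> FakeValue:
    # The real helper inspects the producing Constant node; here we simply
    # fill in const_value as a side effect and return the same object.
    if value.const_value is None:
        value.const_value = np.zeros((2,), dtype=np.float32)
    return value


weight_for_norm = FakeValue("weight_for_norm")

# Before this PR the local name was rebound to the return value:
#     weight_for_norm_prop = propagate_const_value(weight_for_norm)
#     const = weight_for_norm_prop.const_value
# After this PR the helper is called for its side effect, and const_value is
# read from the original value, which stays typed as an ir.Value:
propagate_const_value(weight_for_norm)
const = weight_for_norm.const_value
assert const is not None
```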