59 changes: 35 additions & 24 deletions onnxscript/optimizer/_constant_folding.py
@@ -1039,24 +1039,29 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                 e,
             )
 
-    def new_constant(self, node: ir.Node, value) -> ir.Node | None:
-        irvalue = node.outputs[0]
-        if not isinstance(value, np.ndarray):
+    def new_initializer(self, node: ir.Node, array) -> ir.Value | None:
+        original_value = node.outputs[0]
+        if not isinstance(array, np.ndarray):
             # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
             # So, a constant-value of type sequence is not folded, but it can be used
             # to optimize subsequent operations when possible.
             logger.info(
                 "Skip storing constant folded value %s due to unsupported type %s.",
-                irvalue.name,
-                type(value),
+                original_value.name,
+                type(array),
             )
             return None
 
-        tensor = ir.tensor(value)
-        tensor.name = irvalue.name
-        irvalue.const_value = tensor
+        tensor = ir.tensor(array)
+        tensor.name = original_value.name
+        initializer = ir.Value(
+            name=original_value.name,
+            type=ir.TensorType(ir.DataType(tensor.dtype)),
+            shape=tensor.shape,  # type: ignore[arg-type]
+            const_value=tensor,
+        )
 
-        if value.size > self.output_size_limit:
+        if array.size > self.output_size_limit:
             # Handle examples like Transpose(weight) to be folded even if the size is large,
             # as long as weight has no other uses. This won't increase model size.
             removed_input_size = 0
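A rough standalone sketch of what the rewritten helper now produces: the folded array is wrapped in an IR tensor and exposed as an ir.Value suitable for registration as a graph initializer, instead of being stashed on the node output and later wrapped in a Constant node. The import path and the example array name below are assumptions for illustration; the ir.tensor, ir.Value, ir.TensorType, and ir.DataType calls mirror the ones in the hunk above.

import numpy as np
from onnxscript import ir  # assumed import path; the optimizer module uses this `ir` API

# Hypothetical folded result, standing in for the array returned by evaluation.
folded = np.arange(6, dtype=np.float32).reshape(2, 3)

# Wrap the array in an IR tensor and build a Value that carries it as const_value,
# mirroring new_initializer above.
tensor = ir.tensor(folded)
tensor.name = "folded_output"  # hypothetical value name
initializer = ir.Value(
    name=tensor.name,
    type=ir.TensorType(ir.DataType(tensor.dtype)),
    shape=tensor.shape,
    const_value=tensor,
)
print(initializer.name, initializer.shape)  # inspect the sketched initializer

Keeping the folded result as an initializer-style Value rather than a Constant node means the constant is carried as graph data, which is what register_initializer later relies on in process_node.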
@@ -1065,25 +1070,23 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
                     array = _get_numpy_value(input)
                     if array is not None:
                         removed_input_size += array.size
-            increased_size = value.size - removed_input_size
+            increased_size = array.size - removed_input_size
             if increased_size > 0:
                 logger.info(
                     "Skip storing constant folded nvalue %s due to large size %s.",
-                    irvalue.name,
-                    value.size,
+                    original_value.name,
+                    array.size,
                 )
                 return None
 
         logger.debug(
-            "New constant for value %s dtype: %s shape: %s",
-            irvalue.name,
-            value.dtype,
-            value.shape,
+            "New Initializer for value %s dtype: %s shape: %s",
+            original_value.name,
+            array.dtype,
+            array.shape,
         )
 
-        attributes = ir.convenience.convert_attributes({"value": tensor})
-        node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
-        return node
+        return initializer
 
     def process_node(self, node: ir.Node) -> Replacement | None:
         """Process a node and return a Replacement if the node can be replaced."""
@@ -1109,7 +1112,13 @@ def process_node(self, node: ir.Node) -> Replacement | None:
         self._do_inference(node)
 
         if node.domain not in self._opset_imports:
+            logger.debug(
+                "Skipping constant folding for node %r due to missing opset import for domain %r.",
+                node.name,
+                node.domain,
+            )
             return None
+
         version = self._opset_imports[node.domain]
         op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
         for optimizer in op_optimizers:
@@ -1153,7 +1162,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
             )
             return None
 
-        # Ensure all node inputs are constants
+        # Ensure all node inputs are constants or initializers
         if any(x.const_value is None for x in node.inputs if x is not None):
             return None
 
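The guard above only lets folding proceed when every present input already carries a concrete tensor via const_value, which after this change covers values backed by Constant nodes as well as graph initializers. A minimal sketch of the same condition as a hypothetical standalone helper, assuming the module's existing ir import:

def all_inputs_have_const_value(node: ir.Node) -> bool:
    # Hypothetical helper mirroring the guard in process_node above.
    # Optional (None) inputs are skipped, as in the diff.
    return all(x.const_value is not None for x in node.inputs if x is not None)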
@@ -1227,10 +1236,13 @@ def convert(av):
         if outputs is None:
             return None
         if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
-            replacement = self.new_constant(node, outputs)
-            if replacement is None:
+            new_initializer_value = self.new_initializer(node, outputs)
+            if new_initializer_value is None:
                 return None
-            return Replacement(replacement.outputs, [replacement])
+            # Add the new initializer to the graph
+            assert node.graph is not None
+            node.graph.register_initializer(new_initializer_value)
+            return Replacement([new_initializer_value], [])
         else:
             logger.warning(
                 "Skipping constant folding for op %s with multiple outputs.", node.op_type
@@ -1244,7 +1256,6 @@ def replace_node(
 
         # Record the names of the values that has contributed to the replacement
         _record_contributing_values(node, replacement)
-
         ir.convenience.replace_nodes_and_values(
             root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
         )