Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ynnpack/subgraph/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ cc_library(
"fusion.cc",
"get_tensor_shape.cc",
"reduce.cc",
"reduce.h",
"runtime.cc",
"slinky.cc",
"stack.cc",
Expand Down
40 changes: 39 additions & 1 deletion ynnpack/subgraph/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "ynnpack/kernels/ternary/ternary.h"
#include "ynnpack/subgraph/dot.h"
#include "ynnpack/subgraph/elementwise.h"
#include "ynnpack/subgraph/reduce.h"
#include "ynnpack/subgraph/stencil_copy.h"
#include "ynnpack/subgraph/subgraph.h"

Expand Down Expand Up @@ -415,6 +416,42 @@ bool rewrite_transpose_stencil_copy(ynn_subgraph& subgraph, ynn_node& node,
return true;
}

// Rewrites ynn_reduce_sum of x*x to ynn_reduce_sum_squared of x.
bool rewrite_reduce_sum_of_squared(ynn_subgraph& subgraph, ynn_node& node,
                                   subgraph_analysis& analysis) {
  // Only plain sum reductions are candidates for this rewrite.
  const ynn_node::reduce* reduce_op = std::get_if<ynn_node::reduce>(&node.op);
  if (!reduce_op || reduce_op->op != ynn_reduce_sum) {
    return false;
  }

  // The reduced value must be produced by a node we can see...
  auto it = analysis.producers.find(node.inputs[0]);
  if (it == analysis.producers.end()) {
    return false;
  }
  ynn_node* producer = it->second;
  // ...whose result is consumed only by this reduction, so fusing it away
  // doesn't break another consumer.
  if (analysis.consumers[producer->outputs[0]].size() != 1) {
    return false;
  }

  // The producer must compute x*x, i.e. a multiply of a value by itself.
  if (!is_binary_node(*producer, ynn_binary_multiply) ||
      producer->inputs[0] != producer->inputs[1]) {
    return false;
  }

  // Only rewrite the floating point flavors of this reduction.
  const ynn_value& x = subgraph.value(producer->inputs[0]);
  const bool is_float_type = x.type == ynn_type_fp16 ||
                             x.type == ynn_type_fp32 ||
                             x.type == ynn_type_bf16;
  if (!is_float_type) {
    return false;
  }

  YNN_LOG_DEBUG() << "Rewriting reduce_sum(x*x) to reduce_sum_squared(x)";
  ynn::define_reduce(subgraph, node, ynn_reduce_sum_squared, reduce_op->k_dims,
                     x.id, node.inputs[1], node.outputs[0],
                     reduce_op->keep_dims);
  return true;
}

} // namespace

ynn_status ynn_subgraph::fusion() {
Expand All @@ -429,7 +466,8 @@ ynn_status ynn_subgraph::fusion() {
rewrite_clamp(*this, node, analysis) ||
rewrite_convert_to_quantize(*this, node, analysis) ||
remove_broadcast(*this, node, analysis) ||
rewrite_transpose_stencil_copy(*this, node, analysis);
rewrite_transpose_stencil_copy(*this, node, analysis) ||
rewrite_reduce_sum_of_squared(*this, node, analysis);
}

return ynn_status_success;
Expand Down
169 changes: 91 additions & 78 deletions ynnpack/subgraph/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "ynnpack/base/base.h"
#include "ynnpack/base/type.h"
#include "ynnpack/include/ynnpack.h"
#include "ynnpack/subgraph/reduce.h"
#include "ynnpack/subgraph/runtime.h"
#include "ynnpack/subgraph/slinky.h"
#include "ynnpack/subgraph/subgraph.h"
Expand Down Expand Up @@ -250,90 +251,22 @@ uint32_t get_reduce_identity_value(ynn_subgraph_t subgraph,

} // namespace

extern "C" {

ynn_status ynn_define_reduce(ynn_subgraph_t subgraph,
enum ynn_reduce_operator op, size_t num_axes,
const int32_t* axes, uint32_t input_a_id,
uint32_t input_b_id, uint32_t* output_id,
uint32_t flags) {
// Validate arguments.
assert(subgraph);
assert(subgraph->is_valid_value(input_a_id));
const ynn_value& a = subgraph->value(input_a_id);

ynn_node::reduce reduce;
reduce.op = op;
for (size_t i = 0; i < num_axes; ++i) {
reduce.k_dims[axis_to_slinky_dim(a.rank(), axes[i])] = true;
}
reduce.keep_dims = flags & YNN_NODE_FLAG_KEEP_DIMS;

assert(output_id);
if (*output_id == YNN_INVALID_VALUE_ID) {
// Make the output for this reduction.
ynn_type output_type = get_accumulator_type(op, a.type);
ynn_value& output = subgraph->new_internal_value(output_type);
uint32_t reduce_size_id = YNN_INVALID_VALUE_ID;
switch (op) {
case ynn_reduce_sum:
case ynn_reduce_sum_squared:
if (a.zero_point_id != YNN_INVALID_VALUE_ID) {
// When computing a sum, the zero point gets multiplied by the number
// of elements in the reduction.
ynn_define_get_tensor_shape(subgraph, num_axes, axes, ynn_type_int32,
/*rank=*/0, input_a_id, &reduce_size_id,
/*flags=*/YNN_NODE_FLAG_RESHAPE_1D);
ynn_define_binary(subgraph, ynn_binary_multiply, a.zero_point_id,
reduce_size_id, &output.zero_point_id,
/*flags=*/0);
}
output.scale_id = a.scale_id;
break;
case ynn_reduce_max:
case ynn_reduce_min:
case ynn_reduce_min_max:
output.zero_point_id = a.zero_point_id;
output.scale_id = a.scale_id;
break;
default:
YNN_UNREACHABLE;
}

*output_id = output.id;
}

// Propagate shape
ynn_value& output = subgraph->value(*output_id);
output.extents = a.extents;
for (int i = static_cast<int>(output.extents.size()) - 1; i >= 0; --i) {
if (reduce.k_dims[i]) {
if (reduce.keep_dims) {
output.extents[i] = {};
} else {
output.extents.erase(output.extents.begin() + i);
}
}
}

if (op == ynn_reduce_min_max) {
// This reduction adds a dimension for the min/max index.
output.extents.push_back(2);
}

if (input_b_id == YNN_INVALID_VALUE_ID) {
input_b_id = get_reduce_identity_value(subgraph, output, op);
}
void define_reduce(ynn_subgraph& subgraph, ynn_node& node,
ynn_reduce_operator op, const ynn::axes_set& k_dims,
uint32_t input_a_id, uint32_t input_b_id, uint32_t output_id,
bool keep_dims) {
assert(subgraph.is_valid_value(input_a_id));
assert(subgraph.is_valid_value(output_id));
const ynn_value& a = subgraph.value(input_a_id);
const ynn_value& output = subgraph.value(output_id);

// Get the reduce kernel we are going to use.
unary_reduce_kernel_fn kernel = get_reduce_kernel(op, a.type, output.type);
assert(kernel);

// Make the node.
ynn_node node;
node.inputs = {input_a_id, input_b_id};
node.outputs = {*output_id};
node.op = std::move(reduce);
node.outputs = {output_id};
node.op = ynn_node::reduce{k_dims, op, keep_dims};

node.create = [kernel](const ynn_node& node, ynn_runtime& runtime) {
const ynn_node::reduce& op = std::get<ynn_node::reduce>(node.op);
Expand Down Expand Up @@ -404,6 +337,86 @@ ynn_status ynn_define_reduce(ynn_subgraph_t subgraph,
runtime.funcs.push_back(std::move(func));
return ynn_status_success;
};
}

extern "C" {

// Public C API entry point: defines a reduction node in `subgraph`, creating
// the output value (and its quantization parameters) if the caller did not
// supply one, propagating the output shape, and defaulting the init value.
ynn_status ynn_define_reduce(ynn_subgraph_t subgraph,
                             enum ynn_reduce_operator op, size_t num_axes,
                             const int32_t* axes, uint32_t input_a_id,
                             uint32_t input_b_id, uint32_t* output_id,
                             uint32_t flags) {
  // Validate arguments.
  assert(subgraph);
  assert(subgraph->is_valid_value(input_a_id));
  const ynn_value& a = subgraph->value(input_a_id);

  // Translate the caller's axis list into a per-dim bitset in slinky dim
  // order (axis_to_slinky_dim also handles negative axes relative to rank).
  ynn::axes_set k_dims;
  for (size_t i = 0; i < num_axes; ++i) {
    k_dims[axis_to_slinky_dim(a.rank(), axes[i])] = true;
  }
  bool keep_dims = flags & YNN_NODE_FLAG_KEEP_DIMS;

  assert(output_id);
  if (*output_id == YNN_INVALID_VALUE_ID) {
    // Make the output for this reduction.
    ynn_type output_type = get_accumulator_type(op, a.type);
    ynn_value& output = subgraph->new_internal_value(output_type);
    uint32_t reduce_size_id = YNN_INVALID_VALUE_ID;
    switch (op) {
      case ynn_reduce_sum:
      case ynn_reduce_sum_squared:
        if (a.zero_point_id != YNN_INVALID_VALUE_ID) {
          // When computing a sum, the zero point gets multiplied by the number
          // of elements in the reduction. Compute that element count as a
          // rank-0 tensor and multiply it into the output zero point.
          ynn_define_get_tensor_shape(subgraph, num_axes, axes, ynn_type_int32,
                                      /*rank=*/0, input_a_id, &reduce_size_id,
                                      /*flags=*/YNN_NODE_FLAG_RESHAPE_1D);
          ynn_define_binary(subgraph, ynn_binary_multiply, a.zero_point_id,
                            reduce_size_id, &output.zero_point_id,
                            /*flags=*/0);
        }
        output.scale_id = a.scale_id;
        break;
      case ynn_reduce_max:
      case ynn_reduce_min:
      case ynn_reduce_min_max:
        // Min/max select existing elements, so the input's quantization
        // parameters carry over unchanged.
        output.zero_point_id = a.zero_point_id;
        output.scale_id = a.scale_id;
        break;
      default:
        YNN_UNREACHABLE;
    }

    *output_id = output.id;
  }

  // Propagate shape: each reduced dim is either reset to a default extent
  // (keep_dims) or erased. Iterate backwards so erasing a dim doesn't shift
  // the indices still to be visited.
  ynn_value& output = subgraph->value(*output_id);
  output.extents = a.extents;
  for (int i = static_cast<int>(output.extents.size()) - 1; i >= 0; --i) {
    if (k_dims[i]) {
      if (keep_dims) {
        output.extents[i] = {};
      } else {
        output.extents.erase(output.extents.begin() + i);
      }
    }
  }

  if (op == ynn_reduce_min_max) {
    // This reduction adds a dimension for the min/max index.
    output.extents.push_back(2);
  }

  if (input_b_id == YNN_INVALID_VALUE_ID) {
    // No explicit init value supplied: fall back to the operator's identity
    // value for this output.
    input_b_id = get_reduce_identity_value(subgraph, output, op);
  }

  // Make the node.
  ynn_node node;
  define_reduce(*subgraph, node, op, k_dims, input_a_id, input_b_id, *output_id,
                keep_dims);
  subgraph->add_node(std::move(node));
  return ynn_status_success;
}
Expand Down
23 changes: 23 additions & 0 deletions ynnpack/subgraph/reduce.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2025 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#ifndef XNNPACK_YNNPACK_SUBGRAPH_REDUCE_H_
#define XNNPACK_YNNPACK_SUBGRAPH_REDUCE_H_

#include <cstdint>

#include "ynnpack/include/ynnpack.h"
#include "ynnpack/subgraph/subgraph.h"

namespace ynn {

// Initializes `node` as a reduction applying `op` over the dims set in
// `k_dims`, reading value `input_a_id` (the data) and `input_b_id` (the
// init/accumulator value) and writing value `output_id`. All three value ids
// must already be valid in `subgraph`; the caller is responsible for adding
// `node` to the subgraph afterwards.
// NOTE: `axes_set` needs no `ynn::` qualification inside this namespace.
void define_reduce(ynn_subgraph& subgraph, ynn_node& node,
                   ynn_reduce_operator op, const axes_set& k_dims,
                   uint32_t input_a_id, uint32_t input_b_id, uint32_t output_id,
                   bool keep_dims);

}  // namespace ynn

#endif  // XNNPACK_YNNPACK_SUBGRAPH_REDUCE_H_
Loading